updated Faster-Whisper and HF compatibility
-updated `align()` and `transcribe_stable()` to be compatible with models on the latest faster-whisper commit (#403)
-added `pipeline_kwargs` to `load_hf_whisper()` for passing extra arguments to `transformers.pipeline()`
-added `"large-v3-turbo"` and `"turbo"` to `HF_MODELS` for loading `"openai/whisper-large-v3-turbo"` from Hugging Face
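A quick usage sketch of the new options (a sketch, not part of this commit: it assumes the package's top-level `load_hf_whisper()` entry point, pipeline-init kwargs that transformers accepts for ASR, and a placeholder audio path):

    import stable_whisper

    # extra keyword arguments are forwarded to transformers.pipeline()
    model = stable_whisper.load_hf_whisper('large-v3', chunk_length_s=30, batch_size=8)
    result = model.transcribe('audio.mp3')  # 'audio.mp3' is an illustrative path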
jianfch committed Oct 11, 2024
1 parent df8dace commit 024d7dc
Showing 4 changed files with 38 additions and 10 deletions.
9 changes: 7 additions & 2 deletions stable_whisper/alignment.py
@@ -298,6 +298,7 @@ def align(
     remaining_len = sum(len(w) for w in words)
 
     if is_faster_model:
+        from .whisper_compatibility import is_faster_whisper_on_pt
 
         def timestamp_words():
             temp_segment = dict(
@@ -306,11 +307,15 @@ def timestamp_words():
                 end=round(segment_samples / model.feature_extractor.sampling_rate, 3),
                 tokens=[t for wt in curr_word_tokens for t in wt],
             )
-            features = model.feature_extractor(audio_segment.cpu().numpy())
+            is_on_pt = is_faster_whisper_on_pt()
+            if is_on_pt:
+                features = model.feature_extractor(audio_segment)
+            else:
+                features = model.feature_extractor(audio_segment.cpu().numpy())
             encoder_output = model.encode(features[:, : model.feature_extractor.nb_max_frames])
 
             model.add_word_timestamps(
-                segments=[temp_segment],
+                segments=[[temp_segment]] if is_on_pt else [temp_segment],
                 tokenizer=tokenizer,
                 encoder_output=encoder_output,
                 num_frames=round(segment_samples / model.feature_extractor.hop_length),
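The `[[temp_segment]]` nesting appears to mirror the newer faster-whisper `add_word_timestamps()` signature, which takes a list of segment groups rather than a flat segment list on the PyTorch branch; a shape-only sketch with hypothetical values:

    temp_segment = dict(start=0.0, end=2.5, tokens=[50364])  # hypothetical segment
    segments_ct2 = [temp_segment]    # older API: flat list of segment dicts
    segments_pt = [[temp_segment]]   # PyTorch branch: wrapped once more as a group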
13 changes: 13 additions & 0 deletions stable_whisper/whisper_compatibility.py
@@ -19,13 +19,26 @@
 )
 _required_whisper_ver = _COMPATIBLE_WHISPER_VERSIONS[-1]
 
+_faster_compatibility = {}
+
 if IS_WHISPER_AVAILABLE:
     import whisper.tokenizer
     _TOKENIZER_PARAMS = get_func_parameters(whisper.tokenizer.get_tokenizer)
 else:
     _TOKENIZER_PARAMS = ()
 
 
+def is_faster_whisper_on_pt() -> bool:
+    if 'is_on_pt' not in _faster_compatibility:
+        try:
+            requirements = importlib.metadata.distribution('faster-whisper').requires
+        except importlib.metadata.PackageNotFoundError:
+            _faster_compatibility["is_on_pt"] = False
+        else:
+            _faster_compatibility["is_on_pt"] = any(r.startswith('torch') for r in requirements)
+    return _faster_compatibility["is_on_pt"]
+
+
 def whisper_not_available(*args, **kwargs):
     raise ModuleNotFoundError("Please install Whisper: "
                               "'pip install -U openai-whisper'. "
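The new check keys off package metadata rather than imports, and caches the result in `_faster_compatibility`. A standalone sketch of the same probe (note `requires` can be None for packages with no declared dependencies, which this sketch guards against):

    import importlib.metadata

    try:
        requirements = importlib.metadata.distribution('faster-whisper').requires or []
    except importlib.metadata.PackageNotFoundError:
        requirements = []
    # the PyTorch fork of faster-whisper declares a torch requirement;
    # the CTranslate2 releases do not, so this prints False for them
    print(any(r.startswith('torch') for r in requirements))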
7 changes: 6 additions & 1 deletion stable_whisper/whisper_word_level/faster_whisper.py
@@ -9,7 +9,7 @@
 from ..utils import safe_print, isolate_useful_options
 from ..audio import audioloader_not_supported, convert_demucs_kwargs
 
-from ..whisper_compatibility import LANGUAGES
+from ..whisper_compatibility import LANGUAGES, is_faster_whisper_on_pt
 
 
 def faster_transcribe(
@@ -130,6 +130,11 @@ def faster_transcribe(
     denoiser, denoiser_options = convert_demucs_kwargs(
         denoiser, denoiser_options, demucs=demucs, demucs_options=demucs_options
     )
+    if is_faster_whisper_on_pt():
+        if 'audio_type' not in extra_options:
+            extra_options['audio_type'] = 'torch'
+        if 'model_sr' not in extra_options:
+            extra_options['model_sr'] = model.feature_extractor.sampling_rate
     if not isinstance(audio, (str, bytes)):
         if 'input_sr' not in extra_options:
             extra_options['input_sr'] = model.feature_extractor.sampling_rate
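The guards above are plain setdefault semantics: caller-supplied `extra_options` always win, and only missing keys receive backend defaults. An equivalent sketch (16000 stands in for `model.feature_extractor.sampling_rate`):

    extra_options = {'audio_type': 'numpy'}          # caller's explicit choice is kept
    extra_options.setdefault('audio_type', 'torch')  # no-op: key already present
    extra_options.setdefault('model_sr', 16000)      # filled in: key was missing
    print(extra_options)  # {'audio_type': 'numpy', 'model_sr': 16000}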
19 changes: 12 additions & 7 deletions stable_whisper/whisper_word_level/hf_whisper.py
@@ -21,6 +21,8 @@
     "large-v2": "openai/whisper-large-v2",
     "large-v3": "openai/whisper-large-v3",
     "large": "openai/whisper-large-v3",
+    "large-v3-turbo": "openai/whisper-large-v3-turbo",
+    "turbo": "openai/whisper-large-v3-turbo"
 }
 
 
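Either alias maps to the same checkpoint, so loading by the short name should be equivalent (a sketch assuming the package's top-level loader):

    import stable_whisper

    # both 'turbo' and 'large-v3-turbo' resolve to "openai/whisper-large-v3-turbo"
    model = stable_whisper.load_hf_whisper('turbo')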
@@ -34,7 +36,7 @@ def get_device(device: str = None) -> str:
     return 'cpu'
 
 
-def load_hf_pipe(model_name: str, device: str = None, flash: bool = False):
+def load_hf_pipe(model_name: str, device: str = None, flash: bool = False, **pipeline_kwargs):
     from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
     device = get_device(device)
     is_cpu = (device if isinstance(device, str) else getattr(device, 'type', None)) == 'cpu'
@@ -56,8 +58,8 @@ def load_hf_pipe(model_name: str, device: str = None, flash: bool = False):
     except ValueError:
         pass
 
-    pipe = pipeline(
-        "automatic-speech-recognition",
+    final_pipe_kwargs = dict(
+        task="automatic-speech-recognition",
         model=model,
         tokenizer=processor.tokenizer,
         feature_extractor=processor.feature_extractor,
@@ -66,15 +68,18 @@ def load_hf_pipe(model_name: str, device: str = None, flash: bool = False):
         torch_dtype=dtype,
         device=device,
     )
+    final_pipe_kwargs.update(**pipeline_kwargs)
+    pipe = pipeline(**final_pipe_kwargs)
 
     return pipe
 
 
 class WhisperHF:
 
-    def __init__(self, model_name: str, device: str = None, flash: bool = False, pipeline=None):
+    def __init__(self, model_name: str, device: str = None, flash: bool = False, pipeline=None, **pipeline_kwargs):
         self._model_name = model_name
-        self._pipe = load_hf_pipe(self._model_name, device, flash=flash) if pipeline is None else pipeline
+        self._pipe = load_hf_pipe(self._model_name, device, flash=flash, **pipeline_kwargs) if pipeline is None \
+            else pipeline
         self._model_name = getattr(self._pipe.model, 'name_or_path', self._model_name)
 
     @property
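Collecting the defaults in a dict before calling `pipeline()` lets caller-supplied `pipeline_kwargs` override any of them, since `dict.update()` gives the later keys precedence. A toy illustration:

    final_pipe_kwargs = dict(task='automatic-speech-recognition', torch_dtype='float16')
    pipeline_kwargs = dict(torch_dtype='float32', batch_size=8)  # override one, add one
    final_pipe_kwargs.update(**pipeline_kwargs)
    print(final_pipe_kwargs['torch_dtype'], final_pipe_kwargs['batch_size'])  # float32 8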
@@ -259,5 +264,5 @@ def transcribe(
     )
 
 
-def load_hf_whisper(model_name: str, device: str = None, flash: bool = False, pipeline=None):
-    return WhisperHF(model_name, device, flash=flash, pipeline=pipeline)
+def load_hf_whisper(model_name: str, device: str = None, flash: bool = False, pipeline=None, **pipeline_kwargs):
+    return WhisperHF(model_name, device, flash=flash, pipeline=pipeline, **pipeline_kwargs)
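One caveat implied by the constructor: when a prebuilt `pipeline` object is passed, `load_hf_pipe()` is never called, so `pipeline_kwargs` are silently ignored on that path. A sketch of the two call paths (model names illustrative):

    import stable_whisper
    from transformers import pipeline

    # path 1: kwargs reach transformers.pipeline()
    model = stable_whisper.load_hf_whisper('tiny', chunk_length_s=30)

    # path 2: the prebuilt pipeline wins; any pipeline_kwargs would be ignored here
    pipe = pipeline('automatic-speech-recognition', model='openai/whisper-tiny')
    model = stable_whisper.load_hf_whisper('tiny', pipeline=pipe)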
