Skip to content

Commit

Permalink
added generation parameters to WhisperHF.transcribe()
Browse files Browse the repository at this point in the history
-`WhisperHF.transcribe()` can now take generation parameters supported by `Transformers` (e.g. `temperature`, `num_beams`); note: some have same functionalities but different names from the normal `transcribe()`; for full list of parameters, see https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig or `transformers.generation.configuration_utils.GenerationConfig.__doc__`
  • Loading branch information
jianfch committed Feb 8, 2024
1 parent a684fb4 commit 133f323
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion stable_whisper/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.15.4"
__version__ = "2.15.5"
28 changes: 17 additions & 11 deletions stable_whisper/whisper_word_level/hf_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ..audio import convert_demucs_kwargs, prep_audio
from ..non_whisper import transcribe_any
from ..utils import isolate_useful_options


HF_MODELS = {
Expand Down Expand Up @@ -86,9 +87,11 @@ def _inner_transcribe(
task: str = None,
batch_size: int = 24,
word_timestamps=True,
verbose: Optional[bool] = False
verbose: Optional[bool] = False,
**kwargs
):
generate_kwargs = {'task': task or 'transcribe', 'language': language}
generate_kwargs.update(kwargs)
if verbose is not None:
print(f'Transcribing with Hugging Face Whisper ({self._model_name})...')
result = self._pipe(
Expand Down Expand Up @@ -186,31 +189,34 @@ def transcribe(
check_sorted: bool = True,
**options
):
transcribe_any_options = isolate_useful_options(options, transcribe_any, pop=True)
denoiser, denoiser_options = convert_demucs_kwargs(
denoiser, denoiser_options,
demucs=options.pop('demucs', None), demucs_options=options.pop('demucs_options', None)
demucs=transcribe_any_options.pop('demucs', None),
demucs_options=transcribe_any_options.pop('demucs_options', None)
)

if isinstance(audio, (str, bytes)):
audio = prep_audio(audio, sr=self.sampling_rate).numpy()
options['input_sr'] = self.sampling_rate
transcribe_any_options['input_sr'] = self.sampling_rate

if 'input_sr' not in options:
options['input_sr'] = self.sampling_rate
if 'input_sr' not in transcribe_any_options:
transcribe_any_options['input_sr'] = self.sampling_rate

if denoiser or only_voice_freq:
if 'audio_type' not in options:
options['audio_type'] = 'numpy'
if 'model_sr' not in options:
options['model_sr'] = self.sampling_rate
if 'audio_type' not in transcribe_any_options:
transcribe_any_options['audio_type'] = 'numpy'
if 'model_sr' not in transcribe_any_options:
transcribe_any_options['model_sr'] = self.sampling_rate

inference_kwargs = dict(
audio=audio,
language=language,
task=task,
batch_size=batch_size,
word_timestamps=word_timestamps,
verbose=verbose
verbose=verbose,
**options
)
return transcribe_any(
inference_func=self._inner_transcribe,
Expand All @@ -234,7 +240,7 @@ def transcribe(
only_ffmpeg=only_ffmpeg,
force_order=True,
check_sorted=check_sorted,
**options
**transcribe_any_options
)


Expand Down

0 comments on commit 133f323

Please sign in to comment.