From bb8c47519d5badf4bbd8a12e482fa596e90a4e7b Mon Sep 17 00:00:00 2001 From: Alexander Kuznetsov Date: Fri, 1 Nov 2024 20:16:57 +0300 Subject: [PATCH 1/6] Add compression_ratio_hallucination_threshold Add compression_ratio_hallucination_threshold to Discard High Compression Ratio Segments in transcribe() https://github.com/openai/whisper/discussions/2420 --- whisper/transcribe.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 8eb6a7185..3c510e71c 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -42,11 +42,11 @@ def transcribe( verbose: Optional[bool] = None, temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold: Optional[float] = 2.4, + compression_ratio_halucination_threshold: Optional[float] = 3, logprob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, - carry_initial_prompt: bool = False, word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", @@ -76,6 +76,9 @@ def transcribe( compression_ratio_threshold: float If the gzip compression ratio is above this value, treat as failed + compression_ratio_halcination_threshold: float + If the gzip compression ratio is above this value after all attempts to decode, treat as a halucination and skip + logprob_threshold: float If the average log probability over sampled tokens is below this value, treat as failed @@ -205,7 +208,7 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold ): - needs_fallback = True # too repetitive + needs_fallback = True # too repetitive <-- We can inprove it... 
if ( logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold @@ -216,6 +219,13 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: and decode_result.no_speech_prob > no_speech_threshold ): needs_fallback = False # silence + if ( + compression_ratio_halucination_threshold is not None + and decode_result.compression_ratio > compression_ratio_halucination_threshold + and t == temperatures[-1] + ): + # Discard the segment + continue # Skip to the next segment if not needs_fallback: break From 80ddd07c2808b3109c2fd034d67cebdfc91df5cd Mon Sep 17 00:00:00 2001 From: Alexander Kuznetsov Date: Fri, 1 Nov 2024 20:27:00 +0300 Subject: [PATCH 2/6] Update transcribe.py Typos --- whisper/transcribe.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 3c510e71c..f25dd3501 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -42,7 +42,7 @@ def transcribe( verbose: Optional[bool] = None, temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold: Optional[float] = 2.4, - compression_ratio_halucination_threshold: Optional[float] = 3, + compression_ratio_hallucination_threshold: Optional[float] = 3, logprob_threshold: Optional[float] = -1.0, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, @@ -76,7 +76,7 @@ def transcribe( compression_ratio_threshold: float If the gzip compression ratio is above this value, treat as failed - compression_ratio_halcination_threshold: float + compression_ratio_hallucination_threshold: float If the gzip compression ratio is above this value after all attempts to decode, treat as a halucination and skip logprob_threshold: float @@ -106,11 +106,6 @@ def transcribe( "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those word correctly. 
- carry_initial_prompt: bool - If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal - `decode()` call. If there is not enough context space at the start of the prompt, it is - left-sliced to make space. - decode_options: dict Keyword arguments to construct `DecodingOptions` instances @@ -220,12 +215,14 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: ): needs_fallback = False # silence if ( - compression_ratio_halucination_threshold is not None - and decode_result.compression_ratio > compression_ratio_halucination_threshold + compression_ratio_hallucination_threshold is not None + and decode_result.compression_ratio > compression_ratio_hallucination_threshold and t == temperatures[-1] ): # Discard the segment continue # Skip to the next segment + + if not needs_fallback: break @@ -243,11 +240,9 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: all_segments = [] prompt_reset_since = 0 - remaining_prompt_length = model.dims.n_text_ctx // 2 - 1 if initial_prompt is not None: initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) all_tokens.extend(initial_prompt_tokens) - remaining_prompt_length -= len(initial_prompt_tokens) else: initial_prompt_tokens = [] @@ -293,13 +288,7 @@ def new_segment( segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) - if carry_initial_prompt: - nignored = max(len(initial_prompt_tokens), prompt_reset_since) - remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:] - decode_options["prompt"] = initial_prompt_tokens + remaining_prompt - else: - decode_options["prompt"] = all_tokens[prompt_reset_since:] - + decode_options["prompt"] = all_tokens[prompt_reset_since:] result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) @@ -553,8 +542,6 @@ def valid_model_name(name): parser.add_argument("--suppress_tokens", 
type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") - parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text") - parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") From 482a5b89d8db1984d05a6d35d4875e65350a54af Mon Sep 17 00:00:00 2001 From: Alexander Kuznetsov Date: Fri, 1 Nov 2024 20:33:52 +0300 Subject: [PATCH 3/6] Update transcribe.py --- whisper/transcribe.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index f25dd3501..06d61c96e 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -76,8 +76,8 @@ def transcribe( compression_ratio_threshold: float If the gzip compression ratio is above this value, treat as failed - compression_ratio_hallucination_threshold: float - If the gzip compression ratio is above this value after all attempts to decode, treat as a halucination and skip + compression_ratio_halcination_threshold: float + If the gzip compression ratio is above this value after all attempts to decode, treat as a hallucination and skip logprob_threshold: float If the average log probability over sampled tokens is below this value, treat as failed @@ -106,6 +106,11 @@ def transcribe( "prompt-engineer" a context 
for transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those word correctly. + carry_initial_prompt: bool + If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal + `decode()` call. If there is not enough context space at the start of the prompt, it is + left-sliced to make space. + decode_options: dict Keyword arguments to construct `DecodingOptions` instances @@ -221,8 +226,6 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: ): # Discard the segment continue # Skip to the next segment - - if not needs_fallback: break @@ -240,9 +243,11 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: all_segments = [] prompt_reset_since = 0 + remaining_prompt_length = model.dims.n_text_ctx // 2 - 1 if initial_prompt is not None: initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) all_tokens.extend(initial_prompt_tokens) + remaining_prompt_length -= len(initial_prompt_tokens) else: initial_prompt_tokens = [] @@ -288,7 +293,13 @@ def new_segment( segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) - decode_options["prompt"] = all_tokens[prompt_reset_since:] + if carry_initial_prompt: + nignored = max(len(initial_prompt_tokens), prompt_reset_since) + remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:] + decode_options["prompt"] = initial_prompt_tokens + remaining_prompt + else: + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) @@ -542,6 +553,8 @@ def valid_model_name(name): parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--initial_prompt", type=str, default=None, 
help="optional text to provide as a prompt for the first window.") + parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text") + parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") From df635793f15a6e9b5e2c6ff76717078ef47725b0 Mon Sep 17 00:00:00 2001 From: Alexander Kuznetsov Date: Fri, 1 Nov 2024 20:35:50 +0300 Subject: [PATCH 4/6] Update transcribe.py Typo --- whisper/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 06d61c96e..ef956371c 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -76,7 +76,7 @@ def transcribe( compression_ratio_threshold: float If the gzip compression ratio is above this value, treat as failed - compression_ratio_halcination_threshold: float + compression_ratio_hallucination_threshold: float If the gzip compression ratio is above this value after all attempts to decode, treat as a hallucination and skip logprob_threshold: float If the average log probability over sampled tokens is below this value, treat as failed From f3184f38297f6d1ac1ba9e1ed7e9106ac282e0fe Mon Sep 17 00:00:00 2001 From: Alexander Kuznetsov Date: Fri, 1 Nov 2024 20:44:24 +0300 Subject: [PATCH 5/6] Update transcribe.py Added support for None when hallucinating --- whisper/transcribe.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index ef956371c..ac62a4524 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -225,7 +225,7 @@ def decode_with_fallback(segment: 
torch.Tensor) -> DecodingResult: and t == temperatures[-1] ): # Discard the segment - continue # Skip to the next segment + return None # Skip to the next segment if not needs_fallback: break @@ -301,6 +301,14 @@ def new_segment( decode_options["prompt"] = all_tokens[prompt_reset_since:] result: DecodingResult = decode_with_fallback(mel_segment) + if result is None: + if verbose: + print( + f"Discarding segment {format_timestamp(time_offset)} - {format_timestamp(time_offset + segment_duration)} " + "due to high compression ratio." + ) + seek += segment_size # Move to the next segment + continue # Skip processing this segment tokens = torch.tensor(result.tokens) if no_speech_threshold is not None: From ce1b65e386b0038a13d3d3d29b083fce4e153081 Mon Sep 17 00:00:00 2001 From: Alexander Kuznetsov Date: Sun, 3 Nov 2024 13:02:15 +0300 Subject: [PATCH 6/6] Update transcribe.py Fixed comments --- whisper/transcribe.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index ac62a4524..4010a1635 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -47,6 +47,7 @@ def transcribe( no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, + carry_initial_prompt: bool = False, word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", @@ -208,7 +209,7 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold ): - needs_fallback = True # too repetitive <-- We can inprove it... 
+ needs_fallback = True # too repetitive if ( logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold @@ -220,9 +221,9 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: ): needs_fallback = False # silence if ( - compression_ratio_hallucination_threshold is not None - and decode_result.compression_ratio > compression_ratio_hallucination_threshold - and t == temperatures[-1] + compression_ratio_hallucination_threshold is not None + and decode_result.compression_ratio > compression_ratio_hallucination_threshold + and t == temperatures[-1] ): # Discard the segment return None # Skip to the next segment