diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 8eb6a7185..4010a1635 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -42,6 +42,7 @@ def transcribe(
     verbose: Optional[bool] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
+    compression_ratio_hallucination_threshold: Optional[float] = 3,
     logprob_threshold: Optional[float] = -1.0,
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
@@ -76,6 +77,9 @@ def transcribe(
     compression_ratio_threshold: float
         If the gzip compression ratio is above this value, treat as failed
 
+    compression_ratio_hallucination_threshold: float
+        If the gzip compression ratio is above this value after all attempts to decode, treat as a hallucination and skip
+
     logprob_threshold: float
         If the average log probability over sampled tokens is below this value, treat as failed
 
@@ -216,6 +220,13 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
                 and decode_result.no_speech_prob > no_speech_threshold
             ):
                 needs_fallback = False  # silence
+            if (
+                compression_ratio_hallucination_threshold is not None
+                and decode_result.compression_ratio > compression_ratio_hallucination_threshold
+                and t == temperatures[-1]
+            ):
+                # Discard the segment
+                return None  # Skip to the next segment
             if not needs_fallback:
                 break
 
@@ -291,6 +302,14 @@ def new_segment(
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
             result: DecodingResult = decode_with_fallback(mel_segment)
+            if result is None:
+                if verbose:
+                    print(
+                        f"Discarding segment {format_timestamp(time_offset)} - {format_timestamp(time_offset + segment_duration)} "
+                        "due to high compression ratio."
+                    )
+                seek += segment_size  # Move to the next segment
+                continue  # Skip processing this segment
             tokens = torch.tensor(result.tokens)
 
             if no_speech_threshold is not None:
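
Below is a minimal usage sketch of the new parameter, assuming this patch is applied; the model name, audio path, and threshold values are placeholders for illustration, not part of the diff.

import whisper

# Load any Whisper checkpoint; "base" is only an example.
model = whisper.load_model("base")

# With the patch applied, a segment whose gzip compression ratio still exceeds
# compression_ratio_hallucination_threshold after the final temperature fallback
# is discarded as a likely hallucination instead of being kept in the transcript.
result = model.transcribe(
    "audio.wav",  # placeholder input file
    compression_ratio_threshold=2.4,  # existing retry threshold
    compression_ratio_hallucination_threshold=3,  # new discard threshold added by this patch
    verbose=True,  # prints a message for each discarded segment
)

print(result["text"])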