
Using refine() or align() with transcribe_any() #419

Open
tmoroney opened this issue Nov 26, 2024 · 3 comments
tmoroney commented Nov 26, 2024

I am using the transcribe_any() method with MLX-Whisper on Mac, and I was wondering if it is possible to use refine() or align() when using transcribe_any() with a custom model. If not, are there any other methods of improving the timestamp accuracy that I am not currently using? It doesn't seem like there are, but I thought I'd ask anyway.

My current code:

import mlx_whisper
import stable_whisper

def inference(audio, **kwargs) -> dict:
    output = mlx_whisper.transcribe(
        audio,
        path_or_hf_repo=kwargs["model"],
        word_timestamps=True,
        verbose=True,
        task=kwargs["task"]
    )
    return output

def transcribe_audio(audio_file, kwargs, max_words, max_chars):
    print("Starting transcription...")
    whisper_result = stable_whisper.transcribe_any(inference, audio_file, inference_kwargs=kwargs, vad=True)
    whisper_result.split_by_length(max_words=max_words, max_chars=max_chars)
    return whisper_result.to_dict()
tmoroney changed the title from "Use refine() with transcribe_any()" to "Use refine() or align() with transcribe_any()" on Nov 26, 2024
tmoroney changed the title from "Use refine() or align() with transcribe_any()" to "Using refine() or align() with transcribe_any()" on Nov 26, 2024
jianfch (Owner) commented Nov 27, 2024

refine() and align() are supported for custom models.
Gap Adjustment was added recently. It might help if you're looking for accurate segment timestamps. Note that the default settings may be far from optimal depending on the audio you're working with.
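
For reference, with a model loaded through this library, both methods are called directly on the model object. A minimal sketch (the 'base' checkpoint and 'audio.mp3' are placeholders):

import stable_whisper

model = stable_whisper.load_model('base')

# transcribe, then refine the word timestamps of the result in place
result = model.transcribe('audio.mp3')
model.refine('audio.mp3', result)

# or align a known transcript to the audio
result = model.align('audio.mp3', 'the full transcript text', language='en')

Both calls assume the model object follows the standard Whisper class interface.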

tmoroney (Author) commented Nov 27, 2024

Do I need to change the way I am loading the model in order to use refine() or align()? I do not have a model object, so I am not able to call model.align().

jianfch (Owner) commented Nov 27, 2024

It is less straightforward than transcribe_any() because custom models have different signatures and use different frameworks. There are portions of code in refine() and align() that are written for a model with a specific class signature, so you will need to reimplement those portions to handle the custom model in order to use them.
For refine():

def get_prob():
    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)
    with torch.no_grad():
        curr_mel_segment = mel_segment if prob_indices else orig_mel_segment
        if single_batch:
            logits = torch.cat(
                [model(_mel.unsqueeze(0), tokens.unsqueeze(0)) for _mel in curr_mel_segment]
            )
        else:
            logits = model(curr_mel_segment, tokens.unsqueeze(0))
    sampled_logits = logits[:, len(tokenizer.sot_sequence):, : tokenizer.eot]
    token_probs = sampled_logits.softmax(dim=-1)
    text_token_probs = token_probs[:, np.arange(len(text_tokens)), text_tokens]
    token_positions = token_probs[:, np.arange(len(text_tokens))]
    if logits.shape[0] != 1 and prob_indices is not None:
        indices1 = np.arange(len(prob_indices))
        text_token_probs = text_token_probs[prob_indices, indices1]
        token_positions = token_positions[prob_indices, indices1]
    else:
        text_token_probs.squeeze_(0)
    text_token_probs = text_token_probs.tolist()
    token_positions = (
        token_positions.sort().indices == tokens[len(tokenizer.sot_sequence) + 1:-1][:, None]
    ).nonzero()[:, -1].tolist()
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
    word_probabilities = np.array([
        text_token_probs[j - 1] if is_end_ts else text_token_probs[i]
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ])
    token_positions = [
        token_positions[j - 1] if is_end_ts else token_positions[i]
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]
    return word_probabilities, token_positions
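
In this snippet, the model-specific pieces are the forward call model(mel, tokens), which must return decoder logits, and the model.device attribute. A rough sketch of the interface a custom model would have to mimic (the wrapper class and its forward_logits method are illustrative assumptions, not part of the library):

import torch

class CustomModelWrapper:
    # Illustrative shim: exposes only what get_prob() touches.
    def __init__(self, custom_model, device='cpu'):
        self.custom_model = custom_model
        self.device = torch.device(device)

    def __call__(self, mel, tokens):
        # Must return decoder logits shaped (batch, len(tokens), vocab_size)
        # as a torch.Tensor on self.device; forward_logits is a placeholder
        # for whatever forward pass the custom framework provides.
        return self.custom_model.forward_logits(mel, tokens)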

For align():
def timestamp_words():
    if curr_split_indices:
        temp_split_indices = [0] + curr_split_indices
        if temp_split_indices[-1] < len(curr_words):
            temp_split_indices.append(len(curr_words))
        temp_segments = [
            dict(
                seek=time_offset,
                tokens=(curr_words[i:j], curr_word_tokens[i:j])
            )
            for i, j in zip(temp_split_indices[:-1], temp_split_indices[1:])
        ]
    else:
        temp_segments = [dict(seek=time_offset, tokens=(curr_words, curr_word_tokens))]
    sample_padding = max(N_SAMPLES - audio_segment.shape[-1], 0)
    mel_segment = log_mel_spectrogram(audio_segment, model.dims.n_mels, padding=sample_padding)
    mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(device=model.device)
    add_word_timestamps_stable(
        segments=temp_segments,
        model=model,
        tokenizer=tokenizer,
        mel=mel_segment,
        num_samples=segment_samples,
        split_callback=(lambda x, _: x),
        prepend_punctuations=prepend_punctuations,
        append_punctuations=append_punctuations,
        gap_padding=gap_padding if presplit else None,
        extra_models=extra_models,
        pad_first_seg=pad_first_seg,
        dynamic_heads=dynamic_heads
    )
    if len(temp_segments) == 1:
        return temp_segments[0]
    return dict(words=[w for seg in temp_segments for w in seg['words']])

There may be a few other lines in refine() and align() that will also need to be changed.
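
For the align() snippet, the model is additionally expected to expose model.dims.n_mels alongside model.device, and add_word_timestamps_stable() goes further than logits: it derives word timings from the model's cross-attention weights, so that part of the pipeline would also need a custom-model equivalent. Extending the illustrative wrapper from above (using SimpleNamespace for dims is an assumption about the attribute lookup, not library API):

from types import SimpleNamespace

class AlignableModelWrapper(CustomModelWrapper):
    # Illustrative only: adds the attributes timestamp_words() reads.
    def __init__(self, custom_model, n_mels=80, device='cpu'):
        super().__init__(custom_model, device=device)
        # log_mel_spectrogram(audio, model.dims.n_mels, ...) reads this;
        # 80 matches most Whisper checkpoints (large-v3 uses 128).
        self.dims = SimpleNamespace(n_mels=n_mels)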
