# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

"""Evaluate expressive S2ST generation output.

Computes (a) S2TT BLEU on predicted text vs. reference text and (b) S2ST
ASR-BLEU by transcribing the generated audio with Whisper and scoring the
transcripts against the references.
"""

import logging
from pathlib import Path
from typing import Optional, Tuple

import pandas as pd
from fire import Fire
from sacrebleu.metrics.base import Score, Signature
from whisper import Whisper

from fairseq2.typing import Device
from seamless_communication.cli.eval_utils.compute_metrics import (
    compute_corpus_metric_score,
    init_whisper_model,
    transcribe_series,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def compute_asr_bleu(
    audio_paths_series: pd.Series,
    ref_text_series: pd.Series,
    lang: str,
    asr_model: Whisper,
    whisper_normalize_text: bool = True,
    beam_size: int = 1,
    temperature: float = 0.0,
    return_transcriptions: bool = True,
) -> Tuple[Score, Signature, Optional[pd.DataFrame]]:
    """Transcribe generated audio with Whisper and score it against references.

    Args:
        audio_paths_series: Paths of the audio files to transcribe.
        ref_text_series: Reference target-language texts, index-aligned with
            ``audio_paths_series``.
        lang: Target language code; ``"cmn"`` additionally converts both
            hypotheses and references to simplified Chinese before scoring.
        asr_model: A loaded Whisper model.
        whisper_normalize_text: Apply Whisper text normalization when scoring.
        beam_size: Whisper decoding beam size.
        temperature: Whisper decoding temperature.
        return_transcriptions: When True, also return a per-row DataFrame with
            columns ``audio``, ``transcript`` and ``reference``.

    Returns:
        ``(asr_bleu, asr_bleu_signature, transcript_df)`` where
        ``transcript_df`` is ``None`` if ``return_transcriptions`` is False.

    Raises:
        ImportError: If ``lang == "cmn"`` and ``chinese_converter`` is not
            installed.
    """
    if lang == "cmn":
        # Fail fast — before the (slow) transcription pass — if the optional
        # dependency needed for Chinese normalization is missing.
        try:
            import chinese_converter
        except ImportError:
            raise ImportError(
                "Please install chinese_converter: pip install chinese_converter"
            )

    audio_transcriptions = transcribe_series(
        audio_paths_series,
        asr_model,
        audio_lang=lang,
        beam_size=beam_size,
        temperature=temperature,
    )

    if lang == "cmn":
        # Convert both sides to simplified Chinese before scoring.  Use
        # ``Series.apply`` so the original index is preserved; rebuilding with
        # ``pd.Series(list(...))`` resets the index and would misalign the
        # columns of the transcript DataFrame assembled below.
        audio_transcriptions = audio_transcriptions.apply(
            chinese_converter.to_simplified
        )
        ref_text_series = ref_text_series.apply(chinese_converter.to_simplified)

    asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
        audio_transcriptions, ref_text_series, lang, whisper_normalize_text
    )
    # Record the ASR decoding configuration in the metric signature so the
    # reported score is reproducible.
    asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
    asr_bleu_signature.info["whisper_asr_temperature"] = temperature
    asr_bleu_signature.info["whisper_asr_language"] = lang

    transcript_df = None
    if return_transcriptions:
        transcript_df = pd.concat(
            [audio_paths_series, audio_transcriptions, ref_text_series],
            axis=1,
            keys=["audio", "transcript", "reference"],
        )
    return asr_bleu, asr_bleu_signature, transcript_df


def compute_s2st_quality_metrics(
    output_manifest_tsv_path: Path,
    output_path: Path,
    tgt_lang: str,
    device: Device,
    whisper_model_name: str = "large",
    whisper_normalize_text_output: bool = False,
    ref_text_col_name: str = "ref_tgt_text",
    pred_text_col_name: Optional[str] = "pred_tgt_text",
    pred_audio_col_name: str = "pred_tgt_audio",
) -> str:
    """Compute S2TT BLEU and S2ST ASR-BLEU for a generation manifest.

    Args:
        output_manifest_tsv_path: TSV manifest with reference text, predicted
            text and predicted-audio-path columns.
        output_path: Directory where metric JSON files and the transcript TSV
            are written (created if absent).
        tgt_lang: Target language code.
        device: Device on which to run the Whisper model.
        whisper_model_name: Whisper checkpoint name to load.
        whisper_normalize_text_output: Whisper-normalize the S2TT texts.
        ref_text_col_name: Manifest column holding reference text.
        pred_text_col_name: Manifest column holding predicted text; pass
            ``None`` to skip the S2TT stage.
        pred_audio_col_name: Manifest column holding predicted-audio paths.

    Returns:
        Filename (within ``output_path``) of the ASR-BLEU JSON report.
    """
    df = pd.read_csv(
        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
    )

    # ``exist_ok=True`` makes this idempotent; no ``exists()`` pre-check needed.
    output_path.mkdir(parents=True, exist_ok=True)

    metric = "bleu"

    # --- S2TT: score predicted text directly against the references. -------
    # ``pred_text_col_name`` is Optional, so skip this stage when no text
    # predictions are available instead of raising ``KeyError(None)``.
    if pred_text_col_name is not None:
        text_metric, text_metric_signature = compute_corpus_metric_score(
            hyp_text_series=df[pred_text_col_name],
            ref_text_series=df[ref_text_col_name],
            lang=tgt_lang,
            whisper_normalize_text=whisper_normalize_text_output,
            metric=metric,
        )
        text_metric_json = text_metric.format(
            signature=text_metric_signature.format(), is_json=True
        )

        filename = (
            "s2tt_bleu_normalized.json"
            if whisper_normalize_text_output
            else "s2tt_bleu.json"
        )
        cur_task = "S2TT"

        with open(output_path / filename, "w") as f:
            f.write(text_metric_json)

        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")

    # --- S2ST: transcribe generated audio and score the transcripts. -------
    whisper_model = init_whisper_model(device, whisper_model_name)
    (
        asr_bleu_normalized,
        asr_bleu_normalized_signature,
        transcripts_df,
    ) = compute_asr_bleu(
        audio_paths_series=df[pred_audio_col_name],
        ref_text_series=df[ref_text_col_name],
        lang=tgt_lang,
        asr_model=whisper_model,
        whisper_normalize_text=True,
    )
    # ``return_transcriptions`` defaults to True, so a DataFrame is returned.
    transcripts_df.to_csv(
        output_path / "whisper_audio_transcriptions.tsv",
        sep="\t",
        index=False,
        encoding="utf-8",
        escapechar="\\",
    )

    asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name

    asr_bleu_normalized_json = asr_bleu_normalized.format(
        signature=asr_bleu_normalized_signature.format(), is_json=True
    )
    filename = "s2st_asr_bleu_normalized.json"

    with open(output_path / filename, "w") as f:
        f.write(asr_bleu_normalized_json)

    logger.info(f"S2ST ASR Normalized BLEU:\n{asr_bleu_normalized_json}")

    return filename


def run_asr_bleu_expressive_model(
    generation_dir_path: str,
    generate_tsv_filename: str,
    tgt_lang: str,
) -> None:
    """CLI entry point: evaluate a generation directory on GPU."""
    # NOTE(review): reconstructed from a truncated patch hunk — the upstream
    # call may pass additional keyword arguments (e.g. ``pred_audio_col_name``)
    # past the last visible line; confirm against the original file.
    compute_s2st_quality_metrics(
        f"{generation_dir_path}/{generate_tsv_filename}",
        Path(generation_dir_path),
        tgt_lang,
        device=Device("cuda"),
        ref_text_col_name="tgt_text",
        pred_text_col_name="s2t_out",
    )


if __name__ == "__main__":
    # NOTE(review): the ``Fire`` import implies a fire-based CLI entry point;
    # this guard lies beyond the visible patch hunk — confirm upstream.
    Fire(run_asr_bleu_expressive_model)