Skip to content

Commit

Permalink
✨ add f5 test cases #176
Browse files Browse the repository at this point in the history
  • Loading branch information
zhzLuke96 committed Oct 17, 2024
1 parent fa57430 commit 330624a
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 84 deletions.
32 changes: 6 additions & 26 deletions tests/pipeline/test_chat_tts_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from modules.core.pipeline.factory import PipelineFactory
from modules.core.spk.TTSSpeaker import TTSSpeaker
from tests.pipeline.misc import load_audio, load_audio_wav, save_audio
from tests.pipeline.voice_clone_pipe import run_voice_clone_pipeline_test


@pytest.mark.pipe_chat_tts
Expand Down Expand Up @@ -40,30 +41,9 @@ def test_chat_tts_voice_clone_pipe():
voice_target_path = "./tests/test_inputs/chattts_out1.wav"
out_audio_path = "./tests/test_outputs/pipe_chat_tts_voice_clone_out1.wav"

voice_target = load_audio_wav(voice_target_path)
voice_spk = TTSSpeaker.from_ref_wav_bytes(
ref_wav=voice_target, text="这是一个测试文本。"
run_voice_clone_pipeline_test(
pipeline_func=PipelineFactory.create_cosyvoice_pipeline,
voice_target_path=voice_target_path,
voice_target_text="这是一个测试文本。",
out_audio_path=out_audio_path,
)

pipe0 = PipelineFactory.create_chattts_pipeline(
ctx=TTSPipelineContext(
text="你好,这里是音色克隆测试~",
tts_config=TTSConfig(
mid="chat-tts",
),
spk=voice_spk,
infer_config=InferConfig(eos=" ", sync_gen=True),
),
)

audio_sr, audio_data = pipe0.generate()
assert audio_data.dtype == np.float32
assert audio_data.size != 0
save_audio(
#
file_path=out_audio_path,
audio_data=audio_data,
sample_rate=audio_sr,
)
# 检查文件不为空
assert load_audio(out_audio_path)[1].size != 0
32 changes: 6 additions & 26 deletions tests/pipeline/test_cosyvoice_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,17 @@
from modules.core.pipeline.factory import PipelineFactory
from modules.core.spk.TTSSpeaker import TTSSpeaker
from tests.pipeline.misc import load_audio, load_audio_wav, save_audio
from tests.pipeline.voice_clone_pipe import run_voice_clone_pipeline_test


@pytest.mark.pipe_cosyvoice
def test_cosy_voice_clone_pipe():
voice_target_path = "./tests/test_inputs/chattts_out1.wav"
out_audio_path = "./tests/test_outputs/pipe_cosyvoice_voice_clone_out1.wav"

voice_target = load_audio_wav(voice_target_path)
voice_spk = TTSSpeaker.from_ref_wav_bytes(
ref_wav=voice_target, text="这是一个测试文本。"
run_voice_clone_pipeline_test(
pipeline_func=PipelineFactory.create_cosyvoice_pipeline,
voice_target_path=voice_target_path,
voice_target_text="这是一个测试文本。",
out_audio_path=out_audio_path,
)

pipe0 = PipelineFactory.create_cosyvoice_pipeline(
ctx=TTSPipelineContext(
text="你好,这里是音色克隆测试~",
tts_config=TTSConfig(
mid="cosy-voice",
),
infer_config=InferConfig(eos=" ", sync_gen=True),
spk=voice_spk,
),
)

audio_sr, audio_data = pipe0.generate()
assert audio_data.dtype == np.float32
assert audio_data.size != 0
save_audio(
#
file_path=out_audio_path,
audio_data=audio_data,
sample_rate=audio_sr,
)
# 检查文件不为空
assert load_audio(out_audio_path)[1].size != 0
17 changes: 17 additions & 0 deletions tests/pipeline/test_f5_tts_pipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from modules.core.pipeline.factory import PipelineFactory
from tests.pipeline.voice_clone_pipe import run_voice_clone_pipeline_test


@pytest.mark.pipe_f5_tts
def test_f5_voice_clone_pipe():
voice_target_path = "./tests/test_inputs/chattts_out1.wav"
out_audio_path = "./tests/test_outputs/pipe_f5_voice_clone_out1.wav"

run_voice_clone_pipeline_test(
pipeline_func=PipelineFactory.create_f5_tts_pipeline,
voice_target_path=voice_target_path,
voice_target_text="这是一个测试文本。",
out_audio_path=out_audio_path,
)
38 changes: 6 additions & 32 deletions tests/pipeline/test_fire_red_tts_pipe.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,17 @@
import numpy as np
import pytest

from modules.core.handler.datacls.audio_model import AdjustConfig
from modules.core.handler.datacls.tts_model import InferConfig, TTSConfig
from modules.core.pipeline.dcls import TTSPipelineContext
from modules.core.pipeline.factory import PipelineFactory
from modules.core.spk.TTSSpeaker import TTSSpeaker
from tests.pipeline.misc import load_audio, load_audio_wav, save_audio
from tests.pipeline.voice_clone_pipe import run_voice_clone_pipeline_test


@pytest.mark.pipe_fire_red_tts
def test_fire_red_tts_clone_pipe():
voice_target_path = "./tests/test_inputs/chattts_out1.wav"
out_audio_path = "./tests/test_outputs/pipe_fire_red_tts_voice_clone_out1.wav"

voice_target = load_audio_wav(voice_target_path)
voice_spk = TTSSpeaker.from_ref_wav_bytes(
ref_wav=voice_target, text="这是一个测试文本。"
run_voice_clone_pipeline_test(
pipeline_func=PipelineFactory.create_fire_red_tts_pipeline,
voice_target_path=voice_target_path,
voice_target_text="这是一个测试文本。",
out_audio_path=out_audio_path,
)

pipe0 = PipelineFactory.create_fire_red_tts_pipeline(
ctx=TTSPipelineContext(
text="你好,这里是音色克隆测试~",
tts_config=TTSConfig(
mid="fire-red-tts",
),
infer_config=InferConfig(eos=" ", sync_gen=True),
spk=voice_spk,
),
)

audio_sr, audio_data = pipe0.generate()
assert audio_data.dtype == np.float32
assert audio_data.size != 0
save_audio(
#
file_path=out_audio_path,
audio_data=audio_data,
sample_rate=audio_sr,
)
# 检查文件不为空
assert load_audio(out_audio_path)[1].size != 0
56 changes: 56 additions & 0 deletions tests/pipeline/voice_clone_pipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Callable
import numpy as np
import pytest

from modules.core.handler.datacls.audio_model import AdjustConfig
from modules.core.handler.datacls.tts_model import InferConfig, TTSConfig
from modules.core.pipeline.dcls import TTSPipelineContext
from modules.core.pipeline.factory import PipelineFactory
from modules.core.pipeline.pipeline import AudioPipeline
from modules.core.spk.TTSSpeaker import TTSSpeaker
from tests.pipeline.misc import load_audio, load_audio_wav, save_audio


# Pipeline function type, which returns a pipeline with a generate() method
PipelineFunc = Callable[
[TTSPipelineContext], "AudioPipeline"
] # Type hint for pipeline function


def run_voice_clone_pipeline_test(
pipeline_func: PipelineFunc,
voice_target_path: str,
voice_target_text: str,
out_audio_path: str,
text: str = "你好,这里是音色克隆测试~",
) -> None:
"""辅助函数:执行音色克隆测试并保存结果音频。
参数:
- pipeline_func: 用于创建 TTS 流水线的函数。
- voice_target_path: 目标语音的文件路径。
- voice_target_text: 目标语音的文本。
- out_audio_path: 输出音频的文件路径。
- text: 生成音频时使用的文本。
"""
voice_target: bytes = load_audio_wav(voice_target_path)
voice_spk: TTSSpeaker = TTSSpeaker.from_ref_wav_bytes(
ref_wav=voice_target, text=voice_target_text
)

pipe: AudioPipeline = pipeline_func(
ctx=TTSPipelineContext(
text=text,
tts_config=TTSConfig(mid="cosy-voice"),
infer_config=InferConfig(eos=" ", sync_gen=True),
spk=voice_spk,
),
)

audio_sr, audio_data = pipe.generate() # Tuple[int, np.ndarray]
assert audio_data.dtype == np.float32
assert audio_data.size != 0

save_audio(file_path=out_audio_path, audio_data=audio_data, sample_rate=audio_sr)
# 检查文件不为空
assert load_audio(out_audio_path)[1].size != 0

0 comments on commit 330624a

Please sign in to comment.