Skip to content

Commit

Permalink
✨ webui in tts_pipeline
Browse files Browse the repository at this point in the history
- add infer_config.sync_gen to support track_tqdm
- remove legacy code

refs #91 #90
  • Loading branch information
zhzLuke96 committed Jul 13, 2024
1 parent 112a90f commit ea39d6b
Show file tree
Hide file tree
Showing 19 changed files with 79 additions and 431 deletions.
2 changes: 0 additions & 2 deletions modules/core/handler/AudioHandler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import base64
import io
import struct
import wave
from typing import AsyncGenerator, Generator

import numpy as np
from fastapi import Request
from pydub import AudioSegment

from modules.core.handler.encoder.StreamEncoder import StreamEncoder
from modules.core.handler.encoder.WavFile import WAVFileBytes
Expand Down
3 changes: 3 additions & 0 deletions modules/core/handler/datacls/chattts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ class InferConfig(BaseModel):
stream_chunk_size: int = 96

no_cache: bool = False

# Enable synchronous generation (mainly for gradio, so track_tqdm can observe progress)
sync_gen: bool = False
14 changes: 9 additions & 5 deletions modules/core/models/tts/ChatTtsModel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Generator
from typing import Any, Generator, Union

import numpy as np

from modules.core.models.TTSModel import TTSModel
from modules.core.models.zoo.ChatTTS import ChatTTS, load_chat_tts, unload_chat_tts
from modules.core.models.zoo.ChatTTSInfer import ChatTTSInfer
from modules.core.models.zoo.InerCache import InferCache
from modules.core.models.tts.InerCache import InferCache
from modules.core.pipeline.dcls import TTSPipelineContext
from modules.core.pipeline.pipeline import TTSSegment
from modules.core.pipeline.processor import NP_AUDIO
Expand All @@ -16,7 +16,7 @@ class ChatTTSModel(TTSModel):
model_id = "chat-tts"

def __init__(self) -> None:
super().__init__("chat-tts-4w")
super().__init__("chat-tts")
self.chat: ChatTTS = None

def load(self, context: TTSPipelineContext) -> ChatTTS:
Expand Down Expand Up @@ -75,11 +75,15 @@ def get_cache_kwargs(self, segments: list[TTSSegment], context: TTSPipelineConte

def get_cache(
self, segments: list[TTSSegment], context: TTSPipelineContext
) -> list[NP_AUDIO]:
) -> Union[list[NP_AUDIO], None]:
no_cache = context.infer_config.no_cache
if no_cache:
return None

is_random_generate = context.infer_config.seed == -1
if is_random_generate:
return None

kwargs = self.get_cache_kwargs(segments=segments, context=context)

if InferCache.get_cache_val(model_id=self.model_id, **kwargs):
Expand All @@ -102,7 +106,7 @@ def set_cache(

def generate_batch_base(
self, segments: list[TTSSegment], context: TTSPipelineContext, stream=False
) -> list[NP_AUDIO] | Generator[list[NP_AUDIO], Any, None]:
) -> Union[list[NP_AUDIO], Generator[list[NP_AUDIO], Any, None]]:
cached = self.get_cache(segments=segments, context=context)
if cached is not None:
if not stream:
Expand Down
File renamed without changes.
13 changes: 13 additions & 0 deletions modules/core/models/zoo/ChatTTSInfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from modules import config
from modules.ChatTTS.ChatTTS.core import Chat
from modules.ChatTTS.ChatTTS.model import GPT
from modules.core.models import zoo
from modules.utils.monkey_tqdm import disable_tqdm


Expand Down Expand Up @@ -48,6 +49,9 @@ def __init__(self, instance: Chat) -> None:
self.instance = instance
ChatTTSInfer.current_infer = self

if zoo.zoo_config.debug_generate:
self.logger.setLevel(logging.DEBUG)

def get_tokenizer(self) -> LlamaTokenizer:
    """Return the tokenizer bundled with the loaded Chat instance's pretrained models."""
    models = self.instance.pretrain_models
    return models["tokenizer"]

Expand Down Expand Up @@ -102,6 +106,14 @@ def _infer(
# smooth_decoding = stream
smooth_decoding = False

self.logger.debug(
f"Start infer: stream={stream}, skip_refine_text={skip_refine_text}, refine_text_only={refine_text_only}, use_decoder={use_decoder}, smooth_decoding={smooth_decoding}"
)
self.logger.debug(
f"params_refine_text={params_refine_text}, params_infer_code={params_infer_code}"
)
self.logger.debug(f"Text: {text}")

with torch.no_grad():

if not skip_refine_text:
Expand Down Expand Up @@ -131,6 +143,7 @@ def _infer(
wavs = self._decode_to_wavs(result, length, use_decoder)
yield wavs
else:
# NOTE: 貌似没什么用...?
# smooth_decoding 即使用了滑动窗口的解码,每次都保留上一段的隐藏状态一起解码,并且保留上一段的音频长度用于截取
@dataclass(repr=False, eq=False)
class WavWindow:
Expand Down
1 change: 1 addition & 0 deletions modules/core/models/zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import ChatTTS
from . import zoo_config
1 change: 1 addition & 0 deletions modules/core/models/zoo/zoo_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# When True, ChatTTSInfer switches its logger to DEBUG level so per-call
# generation parameters and input text are logged during inference.
debug_generate = False
2 changes: 1 addition & 1 deletion modules/core/pipeline/dcls.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TTSPipelineContext:
text: Optional[str] = None
ssml: Optional[str] = None

spk: Speaker = None
spk: Optional[Speaker] = None
tts_config: ChatTTSConfig = ChatTTSConfig()
infer_config: InferConfig = InferConfig()
adjust_config: AdjustConfig = AdjustConfig()
Expand Down
11 changes: 11 additions & 0 deletions modules/core/pipeline/generate/BatchSynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
self.generator = BatchGenerate(
buckets=self.buckets, context=context, model=model
)
self.context = context

self.thread1 = None

Expand All @@ -40,9 +41,19 @@ def read(self):
return self.streamer.read()

def start_generate(self):
    """Kick off audio generation.

    Dispatches on ``infer_config.sync_gen``: synchronous mode blocks the
    calling thread (mainly so gradio's ``track_tqdm`` can observe progress),
    otherwise generation runs on a background thread.
    """
    if self.context.infer_config.sync_gen:
        self.start_generate_sync()
    else:
        self.start_generate_async()

def start_generate_async(self):
    """Run the generator on a daemon thread.

    Idempotent: if a worker thread was already started, returns ``None``
    without spawning another. Otherwise stores and returns the new thread.
    """
    if self.thread1 is not None:
        return
    worker = threading.Thread(target=self.generator.generate, daemon=True)
    worker.start()
    self.thread1 = worker
    return worker

def start_generate_sync(self):
    """Run generation to completion on the calling thread (blocks until done)."""
    self.generator.generate()
227 changes: 0 additions & 227 deletions modules/generate_audio.py

This file was deleted.

Loading

0 comments on commit ea39d6b

Please sign in to comment.