Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
CNChTu authored Jan 31, 2024
1 parent 57e91bf commit 03fe7f5
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions diffusion/unit2mel.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def load_model_vocoder_from_combo(combo_model_path, device='cpu', loaded_vocoder
if '_version_' in read_dict.keys():
raise ValueError(" [X] 这是新版本的模型, 请在新仓库中使用")
# 检查是否有键名“comb_diff_model”, 如果有为带级联训练的模型
if 'comb_diff_model' in read_dict.keys():
is_combo_diff_model = True
if '_is_comb_diff_model' in read_dict.keys():
if read_dict['_is_comb_diff_model']:
is_combo_diff_model = True
else:
is_combo_diff_model = False
# 如果打包了声码器, 则从权重中加载声码器
Expand Down Expand Up @@ -254,12 +255,12 @@ def __init__(
if not isinstance(naive_fn, DotDict):
assert isinstance(naive_fn, dict)
naive_fn = DotDict(naive_fn)
self.naive_stack = Unit2MelNaiveV2ForDiff(
input_channel=n_hidden,
out_dims=out_dims,
net_fn=naive_fn
)
self.naive_proj = nn.Linear(out_dims, n_hidden)
self.naive_stack = Unit2MelNaiveV2ForDiff(
input_channel=n_hidden,
out_dims=out_dims,
net_fn=naive_fn
)
self.naive_proj = nn.Linear(out_dims, n_hidden)

if denoise_fn.type == 'WaveNet':
# catch None
Expand Down Expand Up @@ -295,7 +296,7 @@ def __init__(
gradient_checkpointing=self.gradient_checkpointing
)

elif denoise_fn.type == 'NaiveV2Diff':
elif (denoise_fn.type == 'NaiveV2Diff') or (denoise_fn.type == 'LYNXNetDiff'):
# catch None
self.cn_layers = denoise_fn.cn_layers if (denoise_fn.cn_layers is not None) else 20
self.cn_chans = denoise_fn.cn_chans if (denoise_fn.cn_chans is not None) else 384
Expand Down

0 comments on commit 03fe7f5

Please sign in to comment.