diff --git a/ddsp/vocoder.py b/ddsp/vocoder.py index f934807..64bd153 100644 --- a/ddsp/vocoder.py +++ b/ddsp/vocoder.py @@ -666,7 +666,11 @@ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_d # combtooth exciter signal combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)) - combtooth = combtooth.squeeze(-1) + combtooth = combtooth.squeeze(-1) + if combtooth.shape[-1] > self.win_length // 2: + pad_mode = 'reflect' + else: + pad_mode = 'constant' combtooth_stft = torch.stft( combtooth, n_fft = self.win_length, @@ -674,7 +678,8 @@ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_d hop_length = self.block_size, window = self.window, center = True, - return_complex = True) + return_complex = True, + pad_mode = pad_mode) # noise exciter signal noise = torch.randn_like(combtooth) @@ -685,7 +690,8 @@ def forward(self, units_frames, f0_frames, volume_frames, spk_id=None, spk_mix_d hop_length = self.block_size, window = self.window, center = True, - return_complex = True) + return_complex = True, + pad_mode = pad_mode) # apply the filters signal_stft = combtooth_stft * src_filter.permute(0, 2, 1) + noise_stft * noise_filter.permute(0, 2, 1)