diff --git a/ddsp/core.py b/ddsp/core.py
index 9153b6a..e0f720b 100644
--- a/ddsp/core.py
+++ b/ddsp/core.py
@@ -24,7 +24,7 @@ def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
     return fft_size
 
 
-def meanfilter(signal, kernel_size):
+def mean_filter(signal, kernel_size):
     signal = signal.permute(0, 2, 1)
     signal = F.pad(signal, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
     ones_kernel = torch.ones(signal.size(1), 1, kernel_size, device=signal.device)
@@ -40,13 +40,6 @@ def upsample(signal, factor):
     return signal.permute(0, 2, 1)
 
 
-def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
-    n_harm = amplitudes.shape[-1]
-    pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
-    aa = (pitches < fmax).float() + 1e-7
-    return amplitudes * aa
-
-
 def crop_and_compensate_delay(audio, audio_size, ir_size,
                               padding = 'same',
                               delay_compensation = -1):
@@ -123,7 +116,7 @@ def fft_convolve(audio,
     audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size) # B, n_frames+1, 2*hop_size
 
     # Apply Bartlett (triangular) window
-    window = torch.bartlett_window(frame_size).to(audio_frames)
+    window = torch.bartlett_window(frame_size, device=audio_frames.device)
     audio_frames = audio_frames * window
 
     # Pad and FFT the audio and impulse responses.
@@ -145,97 +138,26 @@ def fft_convolve(audio,
     # Crop and shift the output audio.
     output_signal = crop_and_compensate_delay(output_signal[:,hop_size:], audio_size, ir_size)
     return output_signal
-
-cache_win={}
-def apply_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1)
-                                     window_size: int = 0,
-                                     causal: bool = False):
-    """Apply a window to an impulse response and put in causal form.
-    Args:
-      impulse_response: A series of impulse responses frames to window, of shape
-        [batch, n_frames, ir_size]. ---------> ir_size means size of filter_bank ??????
-
-      window_size: Size of the window to apply in the time domain. If window_size
-        is less than 1, it defaults to the impulse_response size.
-      causal: Impulse response input is in causal form (peak in the middle).
-    Returns:
-      impulse_response: Windowed impulse response in causal form, with last
-        dimension cropped to window_size if window_size is greater than 0 and less
-        than ir_size.
-    """
-
-    # If IR is in causal form, put it in zero-phase form.
-    if causal:
-        impulse_response = torch.fftshift(impulse_response, axes=-1)
-
-    # Get a window for better time/frequency resolution than rectangular.
-    # Window defaults to IR size, cannot be bigger.
-    ir_size = impulse_response.size(-1)
-    if (window_size <= 0) or (window_size > ir_size):
-        window_size = ir_size
-    crw = cache_win.get(window_size)
-    if crw is not None:
-        window = crw
-    else:
-        window = nn.Parameter(torch.hann_window(window_size), requires_grad = False).to(impulse_response)
-        cache_win[window_size] = window
-
-    # Zero pad the window and put in in zero-phase form.
-    padding = ir_size - window_size
-    if padding > 0:
-        half_idx = (window_size + 1) // 2
-        window = torch.cat([window[half_idx:],
-                            torch.zeros([padding]),
-                            window[:half_idx]], axis=0)
-    else:
-        window = window.roll(int(window.size(-1)//2), -1)
-
-    # Apply the window, to get new IR (both in zero-phase form).
-    window = window.unsqueeze(0)
-    impulse_response = impulse_response * window
-
-    # Put IR in causal form and trim zero padding.
-    if padding > 0:
-        first_half_start = (ir_size - (half_idx - 1)) + 1
-        second_half_end = half_idx + 1
-        impulse_response = torch.cat([impulse_response[..., first_half_start:],
-                                      impulse_response[..., :second_half_end]],
-                                     dim=-1)
-    else:
-        impulse_response = impulse_response.roll(int(impulse_response.size(-1)//2), -1)
-    return impulse_response
-
-
-def apply_dynamic_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
-                                             half_width_frames): # B, n_frames, 1
-    ir_size = impulse_response.size(-1) # 2*(n_mag-1) or 2*n_mag-1
-
-    window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
-    window = torch.clamp(window, min=-1, max=1)
-    window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
-
-    impulse_response = impulse_response.roll(int(ir_size // 2), -1)
-    impulse_response = impulse_response * window
-
-    return impulse_response
-
-
+
 def frequency_impulse_response(magnitudes,
                                hann_window = True,
                                half_width_frames = None):
     # Get the IR
     impulse_response = torch.fft.irfft(magnitudes) # B, n_frames, 2*(n_mags-1)
+    ir_size = impulse_response.size(-1)
+    impulse_response = impulse_response.roll(int(ir_size // 2), -1)
 
     # Window and put in causal form.
     if hann_window:
         if half_width_frames is None:
-            impulse_response = apply_window_to_impulse_response(impulse_response)
+            window = torch.hann_window(ir_size, device=impulse_response.device)
         else:
-            impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
-    else:
-        impulse_response = impulse_response.roll(int(impulse_response.size(-1) // 2), -1)
+            window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2, device=impulse_response.device) / half_width_frames
+            window = torch.clamp(window, min=-1, max=1)
+            window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
+        impulse_response *= window
 
     return impulse_response
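Note on the ddsp/core.py hunk above: the removed helpers `apply_window_to_impulse_response` and `apply_dynamic_window_to_impulse_response` (along with the module-level `cache_win` cache) are folded straight into `frequency_impulse_response`. The IR returned by `irfft` is rolled into causal form once, then multiplied either by a full-length Hann window or, when `half_width_frames` is given, by a per-frame raised-cosine window whose half-width varies over time. A minimal sketch of that dynamic window construction, assuming `half_width_frames` has shape `B x n_frames x 1` and is measured in samples:

```python
import numpy as np
import torch

def dynamic_window(ir_size, half_width_frames):
    # Normalized distance from the center tap, clamped to [-1, 1] so the
    # window is exactly zero beyond +/- half_width samples.
    t = torch.arange(-(ir_size // 2), (ir_size + 1) // 2,
                     device=half_width_frames.device) / half_width_frames
    t = torch.clamp(t, min=-1, max=1)
    return (1 + torch.cos(np.pi * t)) / 2  # raised cosine: B x n_frames x ir_size

w = dynamic_window(256, torch.full((2, 10, 1), 64.0))  # -> torch.Size([2, 10, 256])
```

Because the `(ir_size,)` ramp broadcasts against the `B x n_frames x 1` half-widths, one window is produced per frame without any Python-level loop.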
diff --git a/ddsp/vocoder.py b/ddsp/vocoder.py
index 32c8ef3..caa3a2d 100644
--- a/ddsp/vocoder.py
+++ b/ddsp/vocoder.py
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from librosa.filters import mel as librosa_mel_fn
 
 from .mel2control import Mel2Control
-from .core import frequency_filter, meanfilter, upsample, remove_above_fmax
+from .core import frequency_filter, mean_filter, upsample
 
 class DotDict(dict):
     def __getattr__(*args):
@@ -14,7 +14,8 @@ def __getattr__(*args):
     __setattr__ = dict.__setitem__
     __delattr__ = dict.__delitem__
 
-    
+
+
 def load_model(
         model_path,
         device='cpu'):
@@ -55,7 +56,8 @@ def load_model(
     model.load_state_dict(ckpt['model'])
     model.eval()
     return model, args
-    
+
+
 class Audio2Mel(torch.nn.Module):
     def __init__(
         self,
@@ -130,7 +132,8 @@ def forward(self, audio, keyshift=0, speed=1):
         # print('og_mel_spec:', log_mel_spec.shape)
         log_mel_spec = log_mel_spec.squeeze(2) # mono
         return log_mel_spec
-    
+
+
 class Sins(torch.nn.Module):
     def __init__(self, 
             sampling_rate,
@@ -162,7 +165,18 @@ def __init__(self, 
             self.mean_kernel_size = win_length // block_size
         else:
             self.mean_kernel_size = 1
-    
+
+    def fast_phase_gen(self, f0_frames):
+        n = torch.arange(self.block_size, device=f0_frames.device)
+        s0 = f0_frames / self.sampling_rate
+        ds0 = F.pad(s0[:, 1:, :] - s0[:, :-1, :], (0, 0, 0, 1))
+        rad = s0 * (n + 1) + 0.5 * ds0 * n * (n + 1) / self.block_size
+        rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
+        rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0_frames)
+        rad += F.pad(rad_acc[:, :-1, :], (0, 0, 1, 0))
+        phase = 2 * np.pi * rad.reshape(f0_frames.shape[0], -1, 1)
+        return phase
+
     def forward(self, 
                 mel_frames, 
                 f0_frames, 
@@ -174,14 +188,7 @@ def forward(self, 
             f0_frames: B x n_frames x 1
         '''
         # exciter phase
-        f0 = upsample(f0_frames, self.block_size)
-        if infer:
-            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
-        else:
-            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
-        x = x - torch.round(x)
-        x = x.to(f0)
-        phase = 2 * np.pi * x
+        phase = self.fast_phase_gen(f0_frames)
 
         # sinusoid exciter signal
         sinusoid = torch.sin(phase).squeeze(-1)
@@ -194,25 +201,22 @@
         # parameter prediction
         ctrls = self.mel2ctrl(mel_frames, sinusoid_frames, noise_frames)
         if self.mean_kernel_size > 1:
-            ctrls['amplitudes'] = meanfilter(ctrls['amplitudes'], self.mean_kernel_size)
-            ctrls['harmonic_phase'] = meanfilter(ctrls['harmonic_phase'], self.mean_kernel_size)
+            ctrls['amplitudes'] = mean_filter(ctrls['amplitudes'], self.mean_kernel_size)
+            ctrls['harmonic_phase'] = mean_filter(ctrls['harmonic_phase'], self.mean_kernel_size)
 
         src_allpass = torch.exp(1.j * np.pi * ctrls['harmonic_phase'])
         src_allpass = torch.cat((src_allpass, src_allpass[:,-1:,:]), 1)
         amplitudes_frames = torch.exp(ctrls['amplitudes'])/ 128
         noise_param = torch.exp(ctrls['noise_magnitude'] + 1.j * np.pi * ctrls['noise_phase']) / 128
 
-        # sinusoids exciter signal
+        # harmonic additive synthesis
         if infer and output_f0_frames is not None:
             f0_frames = output_f0_frames
-            f0 = upsample(f0_frames, self.block_size)
-            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
-            x = x - torch.round(x)
-            x = x.to(f0)
-            phase = 2 * np.pi * x
-        amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start = 1)
+            phase = self.fast_phase_gen(output_f0_frames)
         n_harmonic = amplitudes_frames.shape[-1]
-        level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
+        level_harmonic = torch.arange(1, n_harmonic + 1, device=phase.device)
+        mask = (f0_frames * level_harmonic < self.sampling_rate / 2).float() + 1e-7
+        amplitudes_frames *= mask
         sinusoids = 0.
         for n in range(( n_harmonic - 1) // max_upsample_dim + 1):
             start = n * max_upsample_dim
@@ -248,6 +252,7 @@ def forward(self, 
 
         return signal, sinusoids, (harmonic, noise)
 
+
 class CombSub(torch.nn.Module):
     def __init__(self, 
             sampling_rate,
@@ -278,7 +283,20 @@ def __init__(self, 
             self.mean_kernel_size = win_length // block_size
         else:
             self.mean_kernel_size = 1
-    
+
+    def fast_source_gen(self, f0_frames):
+        n = torch.arange(self.block_size, device=f0_frames.device)
+        s0 = f0_frames / self.sampling_rate
+        ds0 = F.pad(s0[:, 1:, :] - s0[:, :-1, :], (0, 0, 0, 1))
+        rad = s0 * (n + 1) + 0.5 * ds0 * n * (n + 1) / self.block_size
+        s0 = s0 + ds0 * n / self.block_size
+        rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
+        rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0_frames)
+        rad += F.pad(rad_acc[:, :-1, :], (0, 0, 1, 0))
+        rad -= torch.round(rad)
+        combtooth = torch.sinc(rad / (s0 + 1e-5)).reshape(f0_frames.shape[0], -1)
+        return combtooth
+
     def forward(self, 
                 mel_frames, 
                 f0_frames, 
@@ -289,17 +307,9 @@ def forward(self, 
             mel_frames: B x n_frames x n_mels
             f0_frames: B x n_frames x 1
         '''
-        # exciter phase
-        f0 = upsample(f0_frames, self.block_size)
-        if infer:
-            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
-        else:
-            x = torch.cumsum(f0 / self.sampling_rate, axis=1)
-        x = x - torch.round(x)
-        x = x.to(f0)
 
         # combtooth exciter signal
-        combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)).squeeze(-1)
+        combtooth = self.fast_source_gen(f0_frames)
         combtooth_frames = combtooth.unfold(1, self.block_size, self.block_size)
 
         # noise exciter signal
@@ -309,8 +319,8 @@ def forward(self, 
         # parameter prediction
         ctrls = self.mel2ctrl(mel_frames, combtooth_frames, noise_frames)
         if self.mean_kernel_size > 1:
-            ctrls['harmonic_magnitude'] = meanfilter(ctrls['harmonic_magnitude'], self.mean_kernel_size)
-            ctrls['harmonic_phase'] = meanfilter(ctrls['harmonic_phase'], self.mean_kernel_size)
+            ctrls['harmonic_magnitude'] = mean_filter(ctrls['harmonic_magnitude'], self.mean_kernel_size)
+            ctrls['harmonic_phase'] = mean_filter(ctrls['harmonic_phase'], self.mean_kernel_size)
 
         src_allpass = torch.exp(1.j * np.pi * ctrls['harmonic_phase'])
         src_allpass = torch.cat((src_allpass, src_allpass[:,-1:,:]), 1)
@@ -320,11 +330,7 @@ def forward(self, 
         # harmonic part filter (using dynamic-windowed LTV-FIR)
         if infer and output_f0_frames is not None:
             f0_frames = output_f0_frames
-            f0 = upsample(f0_frames, self.block_size)
-            x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
-            x = x - torch.round(x)
-            x = x.to(f0)
-            combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)).squeeze(-1)
+            combtooth = self.fast_source_gen(output_f0_frames)
         harmonic = frequency_filter(
             combtooth,
             torch.complex(src_param, torch.zeros_like(src_param)),
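The new `Sins.fast_phase_gen` and `CombSub.fast_source_gen` replace the removed upsample-then-`cumsum` exciters. Instead of integrating the upsampled f0 over the entire signal, which required a float64 `cumsum` at inference to avoid phase drift on long audio, they integrate a linearly ramping f0 inside each block in closed form and carry only the wrapped per-block phase increment across frames, so the running value stays bounded and float32 suffices. A standalone sketch of the same computation (a hypothetical free-function form of `Sins.fast_phase_gen`; `f0_frames` is `B x n_frames x 1` in Hz):

```python
import numpy as np
import torch
import torch.nn.functional as F

def fast_phase_gen(f0_frames, block_size, sampling_rate):
    n = torch.arange(block_size, device=f0_frames.device)
    s0 = f0_frames / sampling_rate                        # cycles per sample
    ds0 = F.pad(s0[:, 1:, :] - s0[:, :-1, :], (0, 0, 0, 1))
    # Closed-form integral of a frequency ramping linearly from s0 to s0 + ds0
    # over the block: phase in cycles at samples 1..block_size of each frame.
    rad = s0 * (n + 1) + 0.5 * ds0 * n * (n + 1) / block_size
    # Wrap each block's end phase to a fractional cycle and accumulate it
    # mod 1 across frames, so the accumulator never grows unboundedly.
    rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
    rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0_frames)
    rad += F.pad(rad_acc[:, :-1, :], (0, 0, 1, 0))        # offset each block
    return 2 * np.pi * rad.reshape(f0_frames.shape[0], -1, 1)

f0 = torch.full((1, 100, 1), 220.0)                       # constant 220 Hz track
sinusoid = torch.sin(fast_phase_gen(f0, 512, 44100))      # B x (n_frames * block_size) x 1
```

`CombSub.fast_source_gen` uses the same accumulator but additionally wraps the phase to the nearest cycle per sample (`rad -= torch.round(rad)`) and emits `torch.sinc(rad / s0)` with the per-sample instantaneous frequency `s0 + ds0 * n / block_size`, i.e. one sinc pulse per period of the interpolated f0; the `1e-5` in the denominator keeps unvoiced frames with `f0 = 0` finite.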