Skip to content

Commit

Permalink
optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlllc committed Aug 22, 2024
1 parent a6326cd commit f875fbf
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 128 deletions.
98 changes: 10 additions & 88 deletions ddsp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
return fft_size


def meanfilter(signal, kernel_size):
def mean_filter(signal, kernel_size):
signal = signal.permute(0, 2, 1)
signal = F.pad(signal, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
ones_kernel = torch.ones(signal.size(1), 1, kernel_size, device=signal.device)
Expand All @@ -40,13 +40,6 @@ def upsample(signal, factor):
return signal.permute(0, 2, 1)


def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
n_harm = amplitudes.shape[-1]
pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
aa = (pitches < fmax).float() + 1e-7
return amplitudes * aa


def crop_and_compensate_delay(audio, audio_size, ir_size,
padding = 'same',
delay_compensation = -1):
Expand Down Expand Up @@ -123,7 +116,7 @@ def fft_convolve(audio,
audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size) # B, n_frames+1, 2*hop_size

# Apply Bartlett (triangular) window
window = torch.bartlett_window(frame_size).to(audio_frames)
window = torch.bartlett_window(frame_size, device=audio_frames.device)
audio_frames = audio_frames * window

# Pad and FFT the audio and impulse responses.
Expand All @@ -145,97 +138,26 @@ def fft_convolve(audio,
# Crop and shift the output audio.
output_signal = crop_and_compensate_delay(output_signal[:,hop_size:], audio_size, ir_size)
return output_signal

cache_win={}
def apply_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1)
window_size: int = 0,
causal: bool = False):
"""Apply a window to an impulse response and put in causal form.
Args:
impulse_response: A series of impulse responses frames to window, of shape
[batch, n_frames, ir_size]. ---------> ir_size means size of filter_bank ??????
window_size: Size of the window to apply in the time domain. If window_size
is less than 1, it defaults to the impulse_response size.
causal: Impulse response input is in causal form (peak in the middle).
Returns:
impulse_response: Windowed impulse response in causal form, with last
dimension cropped to window_size if window_size is greater than 0 and less
than ir_size.
"""

# If IR is in causal form, put it in zero-phase form.
if causal:
impulse_response = torch.fftshift(impulse_response, axes=-1)

# Get a window for better time/frequency resolution than rectangular.
# Window defaults to IR size, cannot be bigger.
ir_size = impulse_response.size(-1)
if (window_size <= 0) or (window_size > ir_size):
window_size = ir_size
crw = cache_win.get(window_size)
if crw is not None:
window = crw
else:
window= nn.Parameter(torch.hann_window(window_size), requires_grad = False).to(impulse_response)
cache_win[window_size] = window

# Zero pad the window and put in in zero-phase form.
padding = ir_size - window_size
if padding > 0:
half_idx = (window_size + 1) // 2
window = torch.cat([window[half_idx:],
torch.zeros([padding]),
window[:half_idx]], axis=0)
else:
window = window.roll(int(window.size(-1)//2), -1)

# Apply the window, to get new IR (both in zero-phase form).
window = window.unsqueeze(0)
impulse_response = impulse_response * window

# Put IR in causal form and trim zero padding.
if padding > 0:
first_half_start = (ir_size - (half_idx - 1)) + 1
second_half_end = half_idx + 1
impulse_response = torch.cat([impulse_response[..., first_half_start:],
impulse_response[..., :second_half_end]],
dim=-1)
else:
impulse_response = impulse_response.roll(int(impulse_response.size(-1)//2), -1)

return impulse_response


def apply_dynamic_window_to_impulse_response(impulse_response, # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
half_width_frames): # B,n_frames, 1
ir_size = impulse_response.size(-1) # 2*(n_mag -1) or 2*n_mag-1

window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
window = torch.clamp(window, min=-1, max=1)
window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag -1) or 2*n_mag-1

impulse_response = impulse_response.roll(int(ir_size // 2), -1)
impulse_response = impulse_response * window

return impulse_response



def frequency_impulse_response(magnitudes,
hann_window = True,
half_width_frames = None):

# Get the IR
impulse_response = torch.fft.irfft(magnitudes) # B, n_frames, 2*(n_mags-1)
ir_size = impulse_response.size(-1)
impulse_response = impulse_response.roll(int(ir_size // 2), -1)

# Window and put in causal form.
if hann_window:
if half_width_frames is None:
impulse_response = apply_window_to_impulse_response(impulse_response)
window = torch.hann_window(ir_size, device=impulse_response.device)
else:
impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
else:
impulse_response = impulse_response.roll(int(impulse_response.size(-1) // 2), -1)
window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2, device=impulse_response.device) / half_width_frames
window = torch.clamp(window, min=-1, max=1)
window = (1 + torch.cos(np.pi * window)) / 2 # B, n_frames, 2*(n_mag -1) or 2*n_mag-1
impulse_response *= window

return impulse_response

Expand Down
86 changes: 46 additions & 40 deletions ddsp/vocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn.functional as F
from librosa.filters import mel as librosa_mel_fn
from .mel2control import Mel2Control
from .core import frequency_filter, meanfilter, upsample, remove_above_fmax
from .core import frequency_filter, mean_filter, upsample

class DotDict(dict):
def __getattr__(*args):
Expand All @@ -14,7 +14,8 @@ def __getattr__(*args):

__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__



def load_model(
model_path,
device='cpu'):
Expand Down Expand Up @@ -55,7 +56,8 @@ def load_model(
model.load_state_dict(ckpt['model'])
model.eval()
return model, args



class Audio2Mel(torch.nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -130,7 +132,8 @@ def forward(self, audio, keyshift=0, speed=1):
# print('og_mel_spec:', log_mel_spec.shape)
log_mel_spec = log_mel_spec.squeeze(2) # mono
return log_mel_spec



class Sins(torch.nn.Module):
def __init__(self,
sampling_rate,
Expand Down Expand Up @@ -162,7 +165,18 @@ def __init__(self,
self.mean_kernel_size = win_length // block_size
else:
self.mean_kernel_size = 1


def fast_phase_gen(self, f0_frames):
n = torch.arange(self.block_size, device=f0_frames.device)
s0 = f0_frames / self.sampling_rate
ds0 = F.pad(s0[:, 1:, :] - s0[:, :-1, :], (0, 0, 0, 1))
rad = s0 * (n + 1) + 0.5 * ds0 * n * (n + 1) / self.block_size
rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0_frames)
rad += F.pad(rad_acc[:, :-1, :], (0, 0, 1, 0))
phase = 2 * np.pi * rad.reshape(f0_frames.shape[0], -1, 1)
return phase

def forward(self,
mel_frames,
f0_frames,
Expand All @@ -174,14 +188,7 @@ def forward(self,
f0_frames: B x n_frames x 1
'''
# exciter phase
f0 = upsample(f0_frames, self.block_size)
if infer:
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
else:
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
x = x - torch.round(x)
x = x.to(f0)
phase = 2 * np.pi * x
phase = self.fast_phase_gen(f0_frames)

# sinusoid exciter signal
sinusoid = torch.sin(phase).squeeze(-1)
Expand All @@ -194,25 +201,22 @@ def forward(self,
# parameter prediction
ctrls = self.mel2ctrl(mel_frames, sinusoid_frames, noise_frames)
if self.mean_kernel_size > 1:
ctrls['amplitudes'] = meanfilter(ctrls['amplitudes'], self.mean_kernel_size)
ctrls['harmonic_phase'] = meanfilter(ctrls['harmonic_phase'], self.mean_kernel_size)
ctrls['amplitudes'] = mean_filter(ctrls['amplitudes'], self.mean_kernel_size)
ctrls['harmonic_phase'] = mean_filter(ctrls['harmonic_phase'], self.mean_kernel_size)

src_allpass = torch.exp(1.j * np.pi * ctrls['harmonic_phase'])
src_allpass = torch.cat((src_allpass, src_allpass[:,-1:,:]), 1)
amplitudes_frames = torch.exp(ctrls['amplitudes'])/ 128
noise_param = torch.exp(ctrls['noise_magnitude'] + 1.j * np.pi * ctrls['noise_phase']) / 128

# sinusoids exciter signal
# harmonic additive synthesis
if infer and output_f0_frames is not None:
f0_frames = output_f0_frames
f0 = upsample(f0_frames, self.block_size)
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
x = x - torch.round(x)
x = x.to(f0)
phase = 2 * np.pi * x
amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, self.sampling_rate / 2, level_start = 1)
phase = self.fast_phase_gen(output_f0_frames)
n_harmonic = amplitudes_frames.shape[-1]
level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
level_harmonic = torch.arange(1, n_harmonic + 1, device=phase.device)
mask = (f0_frames * level_harmonic < self.sampling_rate / 2).float() + 1e-7
amplitudes_frames *= mask
sinusoids = 0.
for n in range(( n_harmonic - 1) // max_upsample_dim + 1):
start = n * max_upsample_dim
Expand Down Expand Up @@ -248,6 +252,7 @@ def forward(self,

return signal, sinusoids, (harmonic, noise)


class CombSub(torch.nn.Module):
def __init__(self,
sampling_rate,
Expand Down Expand Up @@ -278,7 +283,20 @@ def __init__(self,
self.mean_kernel_size = win_length // block_size
else:
self.mean_kernel_size = 1


def fast_source_gen(self, f0_frames):
n = torch.arange(self.block_size, device=f0_frames.device)
s0 = f0_frames / self.sampling_rate
ds0 = F.pad(s0[:, 1:, :] - s0[:, :-1, :], (0, 0, 0, 1))
rad = s0 * (n + 1) + 0.5 * ds0 * n * (n + 1) / self.block_size
s0 = s0 + ds0 * n / self.block_size
rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0_frames)
rad += F.pad(rad_acc[:, :-1, :], (0, 0, 1, 0))
rad -= torch.round(rad)
combtooth = torch.sinc(rad / (s0 + 1e-5)).reshape(f0_frames.shape[0], -1)
return combtooth

def forward(self,
mel_frames,
f0_frames,
Expand All @@ -289,17 +307,9 @@ def forward(self,
mel_frames: B x n_frames x n_mels
f0_frames: B x n_frames x 1
'''
# exciter phase
f0 = upsample(f0_frames, self.block_size)
if infer:
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
else:
x = torch.cumsum(f0 / self.sampling_rate, axis=1)
x = x - torch.round(x)
x = x.to(f0)

# combtooth exciter signal
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)).squeeze(-1)
combtooth = self.fast_source_gen(f0_frames)
combtooth_frames = combtooth.unfold(1, self.block_size, self.block_size)

# noise exciter signal
Expand All @@ -309,8 +319,8 @@ def forward(self,
# parameter prediction
ctrls = self.mel2ctrl(mel_frames, combtooth_frames, noise_frames)
if self.mean_kernel_size > 1:
ctrls['harmonic_magnitude'] = meanfilter(ctrls['harmonic_magnitude'], self.mean_kernel_size)
ctrls['harmonic_phase'] = meanfilter(ctrls['harmonic_phase'], self.mean_kernel_size)
ctrls['harmonic_magnitude'] = mean_filter(ctrls['harmonic_magnitude'], self.mean_kernel_size)
ctrls['harmonic_phase'] = mean_filter(ctrls['harmonic_phase'], self.mean_kernel_size)

src_allpass = torch.exp(1.j * np.pi * ctrls['harmonic_phase'])
src_allpass = torch.cat((src_allpass, src_allpass[:,-1:,:]), 1)
Expand All @@ -320,11 +330,7 @@ def forward(self,
# harmonic part filter (using dynamic-windowed LTV-FIR)
if infer and output_f0_frames is not None:
f0_frames = output_f0_frames
f0 = upsample(f0_frames, self.block_size)
x = torch.cumsum(f0.double() / self.sampling_rate, axis=1)
x = x - torch.round(x)
x = x.to(f0)
combtooth = torch.sinc(self.sampling_rate * x / (f0 + 1e-3)).squeeze(-1)
combtooth = self.fast_source_gen(output_f0_frames)
harmonic = frequency_filter(
combtooth,
torch.complex(src_param, torch.zeros_like(src_param)),
Expand Down

0 comments on commit f875fbf

Please sign in to comment.