Skip to content

Commit

Permalink
pitch preprocessing update
Browse files Browse the repository at this point in the history
  • Loading branch information
carankt committed Dec 2, 2020
1 parent 752e516 commit a9a87e9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,11 @@ idea/*
dataset/audio/__pycache__/__init__.cpython-36.pyc
*.pyc
Untitled.ipynb
mel.npy
*.png
*.npy
Testing/2log_v2/no_exp_before_bins_fs2v2_2_31k_test_tts.wav
Testing/exp_log/test_tts.wav
Testing/exp_log_v2/exp_before_bins_fs2v2_2_31k_test_tts.wav
mel.png
mel.npy
4 changes: 3 additions & 1 deletion core/variance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@ def inference(self, xs: torch.Tensor, olens = None, alpha: float = 1.0):
"""
f0_spec, f0_mean, f0_std = self.forward(xs, olens, x_masks=None) # (B, Tmax, 10)
f0_reconstructed = self.inverse(f0_spec, f0_mean, f0_std)

#print(f0_reconstructed)
f0_reconstructed = torch.exp(f0_reconstructed)
#print(f0_reconstructed)
return self.to_one_hot(f0_reconstructed)

def to_one_hot(self, x: torch.Tensor):
Expand Down
17 changes: 12 additions & 5 deletions dataset/audio/pitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,20 @@ def forward(
# F0 extraction

# input shape = [T,]
pitch = self._calculate_f0(input)
pitch, pitch_log = self._calculate_f0(input)
# (Optional): Adjust length to match with the mel-spectrogram
if feats_lengths is not None:
pitch = [
self._adjust_num_frames(p, fl).view(-1)
for p, fl in zip(pitch, feats_lengths)
]
pitch, mean, std = self._normalize(pitch, durations)
coefs = self._cwt(pitch.numpy())
pitch_log = [
self._adjust_num_frames(p, fl).view(-1)
for p, fl in zip(pitch_log, feats_lengths)
]

pitch_log_norm, mean, std = self._normalize(pitch_log, durations)
coefs = self._cwt(pitch_log_norm.numpy())
# (Optional): Average by duration to calculate token-wise f0
if self.use_token_averaged_f0:
pitch = self._average_by_duration(pitch, durations)
Expand All @@ -112,10 +117,12 @@ def _calculate_f0(self, input: torch.Tensor) -> torch.Tensor:
f0 = pyworld.stonemask(x, f0, timeaxis, self.fs)
if self.use_continuous_f0:
f0 = self._convert_to_continuous_f0(f0)

if self.use_log_f0:
nonzero_idxs = np.where(f0 != 0)[0]
f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
return input.new_tensor(f0.reshape(-1), dtype=torch.float)
f0_log[nonzero_idxs] = np.log(f0[nonzero_idxs])

This comment has been minimized.

Copy link
@rishikksh20

rishikksh20 Dec 2, 2020

Owner

@karan-deepsync @carankt f0_log variable never defined, this code won't work.

f0_log = f0
f0_log[nonzero_idxs] = np.log(f0_log [nonzero_idxs])

or even better way is to create numpy zero array f0_log of shape f0 then

f0_log = np.zeros_like(f0)
f0_log[nonzero_idxs] = np.log(f0[nonzero_idxs])

return input.new_tensor(f0.reshape(-1), dtype=torch.float), input.new_tensor(f0_log.reshape(-1), dtype=torch.float)


@staticmethod
Expand Down

0 comments on commit a9a87e9

Please sign in to comment.