From a38a0b37dfe68b9e406c8a1ad1b22df3063b7431 Mon Sep 17 00:00:00 2001 From: Jesse Engel Date: Wed, 9 Feb 2022 16:14:56 -0800 Subject: [PATCH] Update RecordProvider to allow centered padding datasets. * Small helper function to core for NaNs. PiperOrigin-RevId: 427597538 --- ddsp/core.py | 29 +++++++--------------------- ddsp/training/data.py | 20 +++++++++++++++---- ddsp/training/gin/models/vst/vst.gin | 21 +++++++++----------- ddsp/training/postprocessing.py | 11 ++++++++--- ddsp/training/preprocessing.py | 6 +++--- ddsp/version.py | 2 +- 6 files changed, 44 insertions(+), 45 deletions(-) diff --git a/ddsp/core.py b/ddsp/core.py index ec2b710b..51fcedfb 100644 --- a/ddsp/core.py +++ b/ddsp/core.py @@ -200,6 +200,11 @@ def diff(x, axis=-1): # Math ------------------------------------------------------------------------- +def nan_to_num(x, value=0.0): + """Replace NaNs with value.""" + return tf.where(tf.math.is_nan(x), value * tf.ones_like(x), x) + + def safe_divide(numerator, denominator, eps=1e-7): """Avoid dividing by zero by adding a small epsilon.""" safe_denominator = tf.where(denominator == 0.0, eps, denominator) @@ -710,26 +715,6 @@ def upsample_with_windows(inputs: tf.Tensor, return x[:, hop_size:-hop_size, :] -# TODO(jesseengel): Axis param, don't assume axis=1. -def center_pad(audio, frame_size, mode='CONSTANT'): - """Pad an audio signal such that timestamps align to the center of frames. - - Without centering, timestamps align to the front of frames. - Args: - audio: Input, shape [batch, time, ...]. - frame_size: Size of each frame. - mode: Padding mode for tf.pad. One of "CONSTANT", "REFLECT", or - "SYMMETRIC" (case-insensitive). - - Returns: - audio_padded: Shape [batch, time + (frame_size // 2) * 2, ...]. - """ - pad_amount = int(frame_size // 2) # Symmetric even padding like librosa. - pads = [[0, 0] for _ in range(len(audio.shape))] - pads[1] = [pad_amount, pad_amount] - return tf.pad(audio, pads, mode=mode) - - def center_crop(audio, frame_size): """Remove padding introduced from centering frames. @@ -852,8 +837,8 @@ def angular_cumsum(angular_frequency, chunk_size=1000): # Pad if needed. remainder = n_time % chunk_size if remainder: - pad = chunk_size - remainder - angular_frequency = pad_axis(angular_frequency, [0, pad], axis=1) + pad_amount = chunk_size - remainder + angular_frequency = pad_axis(angular_frequency, [0, pad_amount], axis=1) # Split input into chunks. length = angular_frequency.shape[1] diff --git a/ddsp/training/data.py b/ddsp/training/data.py index 3bb5c626..149470d5 100644 --- a/ddsp/training/data.py +++ b/ddsp/training/data.py @@ -17,6 +17,7 @@ import os from absl import logging +from ddsp.spectral_ops import get_framed_lengths import gin import tensorflow.compat.v2 as tf import tensorflow_datasets as tfds @@ -186,14 +187,24 @@ def __init__(self, example_secs, sample_rate, frame_rate, - data_format_map_fn): + data_format_map_fn, + centered=False): """RecordProvider constructor.""" self._file_pattern = file_pattern or self.default_file_pattern self._audio_length = example_secs * sample_rate - self._feature_length = example_secs * frame_rate super().__init__(sample_rate, frame_rate) + self._feature_length = self.get_feature_length(centered) self._data_format_map_fn = data_format_map_fn + def get_feature_length(self, centered): + """Take into account center padding to get number of frames.""" + # Number of frames is independent of frame size for "center/same" padding. + frame_size = 1024 + hop_size = self.sample_rate / self.frame_rate + padding = 'center' if centered else 'same' + return get_framed_lengths( + self._audio_length, frame_size, hop_size, padding)[0] + @property def default_file_pattern(self): """Used if file_pattern is not provided to constructor.""" @@ -244,10 +255,11 @@ def __init__(self, file_pattern=None, example_secs=4, sample_rate=16000, - frame_rate=250): + frame_rate=250, + centered=False): """TFRecordProvider constructor.""" super().__init__(file_pattern, example_secs, sample_rate, - frame_rate, tf.data.TFRecordDataset) + frame_rate, tf.data.TFRecordDataset, centered=centered) # ------------------------------------------------------------------------------ diff --git a/ddsp/training/gin/models/vst/vst.gin b/ddsp/training/gin/models/vst/vst.gin index 77fc6f70..8c633613 100644 --- a/ddsp/training/gin/models/vst/vst.gin +++ b/ddsp/training/gin/models/vst/vst.gin @@ -17,21 +17,18 @@ n_samples = 64064 # Extra frame for center padding. # Preprocessor -# Constructor requires path to a SavedModel of the base CREPE -# models. Available on GCS at gs://crepe-models/saved_models/[full,large,small]. +# Use same preprocessor for creating dataset and for training / inference. +prepare_tfrecord_lib_vst.prepare_tfrecord.preprocessor = @preprocessing.OnlineF0PowerPreprocessor() Autoencoder.preprocessor = @preprocessing.OnlineF0PowerPreprocessor() OnlineF0PowerPreprocessor: - time_steps = 1001 # Extra frame added for center padding. - sample_rate = %sample_rate + frame_rate = %frame_rate + frame_size = %frame_size + padding = 'center' compute_power = True - center_power = True - power_frame_rate = %frame_rate - power_frame_size = %frame_size - compute_f0 = True - center_f0 = True - f0_frame_rate = %frame_rate - f0_frame_size = %frame_size - crepe_saved_model_path = '' + compute_f0 = False + crepe_saved_model_path = 'full' + viterbi = False + # time_steps = 1001 # Extra frame added for center padding. # Encoder diff --git a/ddsp/training/postprocessing.py b/ddsp/training/postprocessing.py index 00c5499d..2d14fa5f 100644 --- a/ddsp/training/postprocessing.py +++ b/ddsp/training/postprocessing.py @@ -27,7 +27,7 @@ def detect_notes(loudness_db, exponent=2.0, smoothing=40, f0_confidence_threshold=0.7, - min_db=-120.): + min_db=-spectral_ops.DB_RANGE): """Detect note on-off using loudness and smoothed f0_confidence.""" mean_db = np.mean(loudness_db) db = smooth(f0_confidence**exponent, smoothing) * (loudness_db - min_db) @@ -253,13 +253,15 @@ def fit_transform(self, x): def compute_dataset_statistics(data_provider, batch_size=1, - power_frame_size=256): + power_frame_size=1024, + power_frame_rate=50): """Calculate dataset stats. Args: data_provider: A DataProvider from ddsp.training.data. batch_size: Iterate over dataset with this batch size. power_frame_size: Calculate power features on the fly with this frame size. + power_frame_rate: Calculate power features on the fly with this frame rate. Returns: Dictionary of dataset statistics. This is an overcomplete set of statistics, @@ -280,7 +282,9 @@ def compute_dataset_statistics(data_provider, for batch in data_iter: loudness.append(batch['loudness_db']) power.append( - spectral_ops.compute_power(batch['audio'], frame_size=power_frame_size)) + spectral_ops.compute_power(batch['audio'], + frame_size=power_frame_size, + frame_rate=power_frame_rate)) f0.append(batch['f0_hz']) f0_conf.append(batch['f0_confidence']) audio.append(batch['audio']) @@ -304,6 +308,7 @@ def compute_dataset_statistics(data_provider, # Detect notes. mask_on, _ = detect_notes(loudness_trimmed, f0_conf_trimmed) + quantile_transform = fit_quantile_transform(loudness_trimmed, mask_on) # Pitch statistics. diff --git a/ddsp/training/preprocessing.py b/ddsp/training/preprocessing.py index e4caae3a..61ee5d36 100644 --- a/ddsp/training/preprocessing.py +++ b/ddsp/training/preprocessing.py @@ -64,18 +64,18 @@ def __init__(self, time_steps=1000, frame_rate=250, sample_rate=16000, - recompute_loudness=True, + compute_loudness=True, **kwargs): super().__init__(**kwargs) self.time_steps = time_steps self.frame_rate = frame_rate self.sample_rate = sample_rate - self.recompute_loudness = recompute_loudness + self.compute_loudness = compute_loudness def call(self, loudness_db, f0_hz, audio=None) -> [ 'f0_hz', 'loudness_db', 'f0_scaled', 'ld_scaled']: # Compute loudness fresh (it's fast). - if self.recompute_loudness: + if self.compute_loudness: loudness_db = ddsp.spectral_ops.compute_loudness( audio, sample_rate=self.sample_rate, diff --git a/ddsp/version.py b/ddsp/version.py index ae69df64..5ba8a39d 100644 --- a/ddsp/version.py +++ b/ddsp/version.py @@ -19,4 +19,4 @@ pulling in all the dependencies in __init__.py. """ -__version__ = '3.0.0' +__version__ = '3.1.0'