diff --git a/ddsp/colab/demos/pitch_detection.ipynb b/ddsp/colab/demos/pitch_detection.ipynb
index ef915335..9994f409 100644
--- a/ddsp/colab/demos/pitch_detection.ipynb
+++ b/ddsp/colab/demos/pitch_detection.ipynb
@@ -240,7 +240,7 @@
         "# DDSP-INV,\n",
         "start_time = time.time()\n",
         "print('\\nExtracting f0 with DDSP-INV...')\n",
-        "controls = model.get_controls({'audio': audio}, training=False)\n",
+        "controls = model({'audio': audio}, training=False)\n",
         "print('Prediction took %.1f seconds' % (time.time() - start_time))\n",
         "\n",
         "# CREPE.\n",
diff --git a/ddsp/colab/demos/timbre_transfer.ipynb b/ddsp/colab/demos/timbre_transfer.ipynb
index a91f523b..9297bb94 100644
--- a/ddsp/colab/demos/timbre_transfer.ipynb
+++ b/ddsp/colab/demos/timbre_transfer.ipynb
@@ -471,7 +471,8 @@
         "\n",
         "# Run a batch of predictions.\n",
         "start_time = time.time()\n",
-        "audio_gen = model(af, training=False)\n",
+        "outputs = model(af, training=False)\n",
+        "audio_gen = model.get_audio_from_outputs(outputs)\n",
         "print('Prediction took %.1f seconds' % (time.time() - start_time))\n",
         "\n",
         "# Plot\n",
diff --git a/ddsp/colab/demos/train_autoencoder.ipynb b/ddsp/colab/demos/train_autoencoder.ipynb
index d6814ad4..28e61f0d 100644
--- a/ddsp/colab/demos/train_autoencoder.ipynb
+++ b/ddsp/colab/demos/train_autoencoder.ipynb
@@ -528,7 +528,8 @@
         "model.restore(SAVE_DIR)\n",
         "\n",
         "# Resynthesize audio.\n",
-        "audio_gen = model(batch, training=False)\n",
+        "outputs = model(batch, training=False)\n",
+        "audio_gen = model.get_audio_from_outputs(outputs)\n",
         "audio = batch['audio']\n",
         "\n",
         "print('Original Audio')\n",
diff --git a/ddsp/colab/tutorials/3_training.ipynb b/ddsp/colab/tutorials/3_training.ipynb
index 94c6936b..a5e113e1 100644
--- a/ddsp/colab/tutorials/3_training.ipynb
+++ b/ddsp/colab/tutorials/3_training.ipynb
@@ -408,8 +408,8 @@
       "source": [
         "# Run a batch of predictions.\n",
         "start_time = time.time()\n",
-        "controls = model.get_controls(next(dataset_iter))\n",
-        "audio_gen = controls['processor_group']['signal']\n",
+        "controls = model(next(dataset_iter))\n",
+        "audio_gen = model.get_audio_from_outputs(controls)\n",
         "print('Prediction took %.1f seconds' % (time.time() - start_time))"
       ]
     },
diff --git a/ddsp/training/eval_util.py b/ddsp/training/eval_util.py
index fc31c981..480026f6 100644
--- a/ddsp/training/eval_util.py
+++ b/ddsp/training/eval_util.py
@@ -110,9 +110,8 @@ def evaluate_or_sample(data_provider,

      # TODO(jesseengel): Find a way to add losses with training=False.
      audio = batch['audio']
-      audio_gen, losses = model(batch, return_losses=True, training=True)
-
-      outputs = model.get_controls(batch, training=True)
+      outputs, losses = model(batch, return_losses=True, training=True)
+      audio_gen = model.get_audio_from_outputs(outputs)

      # Create metrics on first batch.
      if mode == 'eval' and batch_idx == 1:
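All four call sites above migrate to the same pattern: `model(...)` now returns a dictionary of output tensors rather than an audio tensor (and `get_controls()` is gone), with the synthesized audio extracted by `model.get_audio_from_outputs(...)`. A minimal sketch of the new inference pattern, assuming `model` is a restored `ddsp.training.models.Autoencoder` and `batch` is a feature dict with an 'audio' key (both placeholders, as in the demos above):

    import time

    start_time = time.time()
    outputs = model(batch, training=False)             # dict of all output tensors
    audio_gen = model.get_audio_from_outputs(outputs)  # synthesized audio tensor
    print('Prediction took %.1f seconds' % (time.time() - start_time))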
diff --git a/ddsp/training/inference.py b/ddsp/training/inference.py
index 88b38a3e..15b1d821 100644
--- a/ddsp/training/inference.py
+++ b/ddsp/training/inference.py
@@ -83,7 +83,7 @@ def build_network(self):
   @tf.function
   def call(self, input_dict):
     """Convert f0 and loudness to synthesizer parameters."""
-    controls = super().get_controls(input_dict, training=False)
+    controls = super().__call__(input_dict, training=False)
     amps = controls['additive']['controls']['amplitudes']
     hd = controls['additive']['controls']['harmonic_distribution']
     return amps, hd
diff --git a/ddsp/training/models/autoencoder.py b/ddsp/training/models/autoencoder.py
index c4066396..dadba02a 100644
--- a/ddsp/training/models/autoencoder.py
+++ b/ddsp/training/models/autoencoder.py
@@ -28,17 +28,14 @@ def __init__(self,
                decoder=None,
                processor_group=None,
                losses=None,
-               name='autoencoder'):
-    super().__init__(name=name)
+               **kwargs):
+    super().__init__(**kwargs)
     self.preprocessor = preprocessor
     self.encoder = encoder
     self.decoder = decoder
     self.processor_group = processor_group
     self.loss_objs = ddsp.core.make_iterable(losses)

-  def controls_to_audio(self, controls):
-    return controls[self.processor_group.name]['signal']
-
   def encode(self, features, training=True):
     """Get conditioning by preprocessing then encoding."""
     if self.preprocessor is not None:
@@ -52,20 +49,18 @@ def decode(self, conditioning, training=True):
     processor_inputs = self.decoder(conditioning, training=training)
     return self.processor_group(processor_inputs)

+  def get_audio_from_outputs(self, outputs):
+    """Extract audio output tensor from outputs dict of call()."""
+    return self.processor_group.get_signal(outputs)
+
   def call(self, features, training=True):
     """Run the core of the network, get predictions and loss."""
     conditioning = self.encode(features, training=training)
-    audio_gen = self.decode(conditioning, training=training)
+    processor_inputs = self.decoder(conditioning, training=training)
+    outputs = self.processor_group.get_controls(processor_inputs)
+    outputs['audio_synth'] = self.processor_group.get_signal(outputs)
     if training:
-      self.update_losses_dict(self.loss_objs, features['audio'], audio_gen)
-    return audio_gen
+      self._update_losses_dict(
+          self.loss_objs, features['audio'], outputs['audio_synth'])
+    return outputs
-
-  def get_controls(self, features, keys=None, training=False):
-    """Returns specific processor_group controls."""
-    conditioning = self.encode(features, training=training)
-    processor_inputs = self.decoder(conditioning)
-    controls = self.processor_group.get_controls(processor_inputs)
-    # Also build on get_controls(), instead of just __call__().
-    self.built = True
-    # If wrapped in tf.function, only calculates keys of interest.
-    return controls if keys is None else {k: controls[k] for k in keys}
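With this refactor, `Autoencoder.call()` returns the `ProcessorGroup` controls dictionary with the final signal added under 'audio_synth', and `get_audio_from_outputs()` is simply `processor_group.get_signal(outputs)`. One forward pass therefore serves both audio generation and control inspection; a short sketch, with `model` and `features` assumed as before and a DAG that contains an 'additive' synth (the key used by inference.py above):

    # A single forward pass yields the synthesized audio and every
    # intermediate synthesizer control.
    outputs = model(features, training=False)

    audio_gen = model.get_audio_from_outputs(outputs)     # == outputs['audio_synth']
    amps = outputs['additive']['controls']['amplitudes']  # per-processor controls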
diff --git a/ddsp/training/models/autoencoder_test.py b/ddsp/training/models/autoencoder_test.py
index 8629fafc..1be297ee 100644
--- a/ddsp/training/models/autoencoder_test.py
+++ b/ddsp/training/models/autoencoder_test.py
@@ -59,10 +59,11 @@ def test_build_model(self, gin_file):
     gin.parse_config_file(gin_file)
     model = models.Autoencoder()

-    controls = model.get_controls(self.inputs)
-    self.assertIsInstance(controls, dict)
+    outputs = model(self.inputs)
+    self.assertIsInstance(outputs, dict)

     # Confirm that model generates correctly sized audio.
-    audio_gen_shape = controls['processor_group']['signal'].shape.as_list()
+    audio_gen = model.get_audio_from_outputs(outputs)
+    audio_gen_shape = audio_gen.shape.as_list()
     self.assertEqual(audio_gen_shape, list(self.inputs['audio'].shape))

 if __name__ == '__main__':
diff --git a/ddsp/training/models/model.py b/ddsp/training/models/model.py
index 57d99378..31ca175e 100644
--- a/ddsp/training/models/model.py
+++ b/ddsp/training/models/model.py
@@ -40,17 +40,26 @@ def __call__(self, *args, return_losses=False, **kwargs):
       **kwargs: Keyword arguments passed on to call().

     Returns:
-      Function results if return_losses=False, else the function results
-      and a dictionary of losses, {loss_name: loss_value}.
+      outputs: A dictionary of model outputs generated in call().
+        {output_name: output_tensor or dict}.
+      losses: If return_losses=True, also returns a dictionary of losses,
+        {loss_name: loss_value}.
     """
     self._losses_dict = {}
-    results = super().__call__(*args, **kwargs)
+    outputs = super().__call__(*args, **kwargs)
     if not return_losses:
-      return results
+      return outputs
     else:
       self._losses_dict['total_loss'] = tf.reduce_sum(
           list(self._losses_dict.values()))
-      return results, self._losses_dict
+      return outputs, self._losses_dict
+
+  def _update_losses_dict(self, loss_objs, *args, **kwargs):
+    """Helper function to run loss objects on args and add to model losses."""
+    for loss_obj in ddsp.core.make_iterable(loss_objs):
+      if hasattr(loss_obj, 'get_losses_dict'):
+        losses_dict = loss_obj.get_losses_dict(*args, **kwargs)
+        self._losses_dict.update(losses_dict)

   def restore(self, checkpoint_path):
     """Restore model and optimizer from a checkpoint."""
@@ -65,13 +74,22 @@ def restore(self, checkpoint_path):
       logging.info('Could not find checkpoint to load at %s, skipping.',
                    checkpoint_path)

-  def get_controls(self, features, keys=None, training=False):
-    """Base method for getting controls. Not implemented."""
-    raise NotImplementedError('`get_controls` not implemented in base class!')
+  def get_audio_from_outputs(self, outputs):
+    """Extract audio output tensor from outputs dict of call()."""
+    raise NotImplementedError('Must implement `self.get_audio_from_outputs()`.')

-  def update_losses_dict(self, loss_objs, *args, **kwargs):
-    """Run loss objects on inputs and adds to model losses."""
-    for loss_obj in ddsp.core.make_iterable(loss_objs):
-      if hasattr(loss_obj, 'get_losses_dict'):
-        losses_dict = loss_obj.get_losses_dict(*args, **kwargs)
-        self._losses_dict.update(losses_dict)
+  def call(self, *args, training=False, **kwargs):
+    """Run the forward pass, add losses, and create a dictionary of outputs.
+
+    This function must run the forward pass, add losses to self._losses_dict
+    and return a dictionary of all the relevant output tensors.
+
+    Args:
+      *args: Args for forward pass.
+      training: Required `training` kwarg passed in by keras.
+      **kwargs: kwargs for forward pass.
+
+    Returns:
+      Dictionary of all relevant tensors.
+    """
+    raise NotImplementedError('Must implement a `self.call()` method.')
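The base class now spells out the contract: `call()` runs the forward pass, adds losses to `self._losses_dict`, and returns a dictionary of outputs; `get_audio_from_outputs()` maps that dictionary back to an audio tensor. A minimal conforming subclass, purely illustrative (`PassthroughModel` is hypothetical, and this assumes `Model` is exported from `ddsp.training.models` like the other classes here):

    import tensorflow as tf

    from ddsp.training import models


    class PassthroughModel(models.Model):
      """Hypothetical model that 'synthesizes' by echoing its input audio."""

      def call(self, features, training=False):
        # Forward pass: return a dict of all relevant output tensors.
        return {'audio_synth': tf.identity(features['audio'])}

      def get_audio_from_outputs(self, outputs):
        return outputs['audio_synth']

Because `Model.__call__` sums `self._losses_dict` into 'total_loss', even this loss-free subclass supports `outputs, losses = model(batch, return_losses=True)`.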
diff --git a/ddsp/training/models/transcribing_autoencoder.py b/ddsp/training/models/transcribing_autoencoder.py
index 53bb73c7..b07e0c15 100644
--- a/ddsp/training/models/transcribing_autoencoder.py
+++ b/ddsp/training/models/transcribing_autoencoder.py
@@ -117,29 +117,6 @@ def __init__(self,

     self.processor_group = ddsp.processors.ProcessorGroup(dag=dag)

-  def get_controls(self, features, keys=None, training=False):
-    """Returns specific processor_group controls."""
-    # For now just use the real data.
-    if isinstance(features, (list, tuple)):
-      features, unused_ss_features = self.parse_zipped_features(features)
-
-    # Encode the data from audio to sinusoids.
-    pg_in = self.sinusoidal_encoder(features, training=training)
-
-    # Manually apply the scaling nonlinearities.
-    pg_in['frequencies'] = self.freq_scale_fn(pg_in['frequencies'])
-    pg_in['amplitudes'] = self.amps_scale_fn(pg_in['amplitudes'])
-    pg_in['noise_magnitudes'] = self.amps_scale_fn(pg_in['noise_magnitudes'])
-    controls = self.processor_group.get_controls(pg_in)
-
-    # Append normal training procedure outputs.
-    outputs = self.forward(features, training)
-    controls.update(outputs)
-
-    self.built = True
-    # If wrapped in tf.function, only calculates keys of interest.
-    return controls if keys is None else {k: controls[k] for k in keys}
-
   def generate_synthetic_audio(self, features):
     """Convert synthetic controls into audio."""
     return self.processor_group({
@@ -163,6 +140,12 @@ def parse_zipped_features(self, features):
     s_idx = int(not ss_idx)
     return features[s_idx], features[ss_idx]

+  def get_audio_from_outputs(self, outputs):
+    """Extract audio output tensor from outputs dict of call()."""
+    audio_out = (outputs['sin_audio'] if self.harmonic_encoder is None else
+                 outputs['harm_audio'])
+    return audio_out
+
   def call(self, features, training=True):
     """Run the core of the network, get predictions and loss."""
     if isinstance(features, (list, tuple)):
@@ -182,8 +165,10 @@ def call(self, features, training=True):
       all_outputs = self.forward(inputs, training)

       # Split outputs.
-      outputs = {k: v[:batch_size] for k, v in all_outputs.items()}
-      ss_outputs = {k: v[batch_size:] for k, v in all_outputs.items()}
+      outputs = {k: v[:batch_size] for k, v in all_outputs.items()
+                 if not isinstance(v, dict)}
+      ss_outputs = {k: v[batch_size:] for k, v in all_outputs.items()
+                    if not isinstance(v, dict)}

       # Compute losses.
       self.append_losses(outputs)
@@ -202,73 +187,6 @@ def call(self, features, training=True):
       outputs = self.forward(features, training)
       self.append_losses(outputs)

-    if self.harmonic_encoder is not None:
-      return outputs['harm_audio']
-    else:
-      return outputs['sin_audio']
-
-  def forward(self, features, training=True):
-    """Run forward pass of model (no losses) on a dictionary of features."""
-    # Audio -> Sinusoids -------------------------------------------------------
-    audio = features['audio']
-
-    # Encode the data from audio to sinusoids.
-    pg_in = self.sinusoidal_encoder(features, training=training)
-
-    # Manually apply the scaling nonlinearities.
-    sin_freqs = self.freq_scale_fn(pg_in['frequencies'])
-    sin_amps = self.amps_scale_fn(pg_in['amplitudes'])
-    noise_magnitudes = self.amps_scale_fn(pg_in['noise_magnitudes'])
-    pg_in['frequencies'] = sin_freqs
-    pg_in['amplitudes'] = sin_amps
-    pg_in['noise_magnitudes'] = noise_magnitudes
-
-    # Reconstruct sinusoidal audio.
-    sin_audio = self.processor_group(pg_in)
-
-    outputs = {
-        # Input signal.
-        'audio': audio,
-        # Filtered noise signal.
-        'noise_magnitudes': noise_magnitudes,
-        # Sinusoidal signal.
-        'sin_audio': sin_audio,
-        'sin_amps': sin_amps,
-        'sin_freqs': sin_freqs,
-    }
-
-    # Sinusoids -> Harmonics ---------------------------------------------------
-    # Encode the sinusoids into a harmonics.
-    if self.stop_gradient:
-      sin_freqs = tf.stop_gradient(sin_freqs)
-      sin_amps = tf.stop_gradient(sin_amps)
-      noise_magnitudes = tf.stop_gradient(noise_magnitudes)
-
-    if self.harmonic_encoder is not None:
-      harm_amp, harm_dist, f0_hz = self.harmonic_encoder(sin_freqs, sin_amps)
-
-      # Decode harmonics back to sinusoids.
-      n_harmonics = int(harm_dist.shape[-1])
-      harm_freqs = ddsp.core.get_harmonic_frequencies(f0_hz, n_harmonics)
-      harm_amps = harm_amp * harm_dist
-
-      # Reconstruct harmonic audio.
-      pg_in['frequencies'] = harm_freqs
-      pg_in['amplitudes'] = harm_amps
-      pg_in['noise_magnitudes'] = noise_magnitudes
-      harm_audio = self.processor_group(pg_in)
-
-      outputs.update({
-          # Harmonic signal.
-          'harm_audio': harm_audio,
-          'harm_amp': harm_amp,
-          'harm_dist': harm_dist,
-          'f0_hz': f0_hz,
-          # Harmonic Sinusoids.
-          'harm_freqs': harm_freqs,
-          'harm_amps': harm_amps,
-      })
-    return outputs

   def append_losses(self, outputs, self_supervised_features=None):
@@ -286,9 +204,8 @@

     if self.harmonic_encoder is not None:
       # Add prior regularization on harmonic distribution.
-      hdp = self.harmonic_distribution_prior
-      if hdp is not None:
-        self._losses_dict.update({hdp.name: hdp(o['harm_dist'])})
+      self._update_losses_dict(
+          self.harmonic_distribution_prior, o['harm_dist'])

       # Harmonic autoencoder loss.
       for loss_obj in self.audio_loss_objs:
@@ -303,17 +220,15 @@

        # Don't propagate harmonic errors to sinusoidal predictions.
        sin_amps = tf.stop_gradient(sin_amps)
        sin_freqs = tf.stop_gradient(sin_freqs)
-        for loss_obj in self.sinusoidal_consistency_losses:
-          self._losses_dict[loss_obj.name] = loss_obj(
-              sin_amps, sin_freqs, o['harm_amps'], o['harm_freqs'])
+        self._update_losses_dict(
+            self.sinusoidal_consistency_losses,
+            sin_amps, sin_freqs, o['harm_amps'], o['harm_freqs'])

      # Two-way mismatch loss between sinusoids and harmonics.
      if self.twm_loss is not None:
-        if self.harmonic_encoder is not None:
-          loss = self.twm_loss(o['f0_hz'], o['sin_freqs'], o['sin_amps'])
-        else:
-          loss = self.twm_loss(o['sin_freqs'], o['sin_freqs'], o['sin_amps'])
-        self._losses_dict[self.twm_loss.name] = loss
+        f0_c = o['sin_freqs'] if self.harmonic_encoder is None else o['f0_hz']
+        self._update_losses_dict(self.twm_loss,
+                                 f0_c, o['sin_freqs'], o['sin_amps'])

    # Self-supervised Losses.
    else:
@@ -347,4 +262,72 @@

        self._losses_dict[name] = loss_obj(
            o['harm_amp'], o['f0_hz'], f['harm_amp'], f['f0_hz'])

+  def forward(self, features, training=True):
+    """Run forward pass of model (no losses) on a dictionary of features."""
+    # Audio -> Sinusoids -------------------------------------------------------
+    audio = features['audio']
+
+    # Encode the data from audio to sinusoids.
+    pg_in = self.sinusoidal_encoder(features, training=training)
+
+    # Manually apply the scaling nonlinearities.
+    sin_freqs = self.freq_scale_fn(pg_in['frequencies'])
+    sin_amps = self.amps_scale_fn(pg_in['amplitudes'])
+    noise_magnitudes = self.amps_scale_fn(pg_in['noise_magnitudes'])
+    pg_in['frequencies'] = sin_freqs
+    pg_in['amplitudes'] = sin_amps
+    pg_in['noise_magnitudes'] = noise_magnitudes
+
+    # Reconstruct sinusoidal audio.
+    controls = self.processor_group.get_controls(pg_in)
+    sin_audio = self.processor_group.get_signal(controls)
+
+    outputs = {
+        # Input signal.
+        'audio': audio,
+        # Filtered noise signal.
+        'noise_magnitudes': noise_magnitudes,
+        # Sinusoidal signal.
+        'sin_audio': sin_audio,
+        'sin_amps': sin_amps,
+        'sin_freqs': sin_freqs,
+    }
+    outputs.update(controls)
+
+    # Sinusoids -> Harmonics ---------------------------------------------------
+    # Encode the sinusoids into harmonics.
+    if self.stop_gradient:
+      sin_freqs = tf.stop_gradient(sin_freqs)
+      sin_amps = tf.stop_gradient(sin_amps)
+      noise_magnitudes = tf.stop_gradient(noise_magnitudes)
+
+    if self.harmonic_encoder is not None:
+      harm_amp, harm_dist, f0_hz = self.harmonic_encoder(sin_freqs, sin_amps)
+
+      # Decode harmonics back to sinusoids.
+      n_harmonics = int(harm_dist.shape[-1])
+      harm_freqs = ddsp.core.get_harmonic_frequencies(f0_hz, n_harmonics)
+      harm_amps = harm_amp * harm_dist
+
+      # Reconstruct harmonic audio.
+      pg_in['frequencies'] = harm_freqs
+      pg_in['amplitudes'] = harm_amps
+      pg_in['noise_magnitudes'] = noise_magnitudes
+      harm_audio = self.processor_group(pg_in)
+
+      outputs.update({
+          # Harmonic signal.
+          'harm_audio': harm_audio,
+          'harm_amp': harm_amp,
+          'harm_dist': harm_dist,
+          'f0_hz': f0_hz,
+          # Harmonic Sinusoids.
+          'harm_freqs': harm_freqs,
+          'harm_amps': harm_amps,
+      })
+
+    return outputs
+
+
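Taken together, every model now exposes the same three-call surface: `outputs = model(batch)`, `outputs, losses = model(batch, return_losses=True)`, and `audio = model.get_audio_from_outputs(outputs)`. A sketch of a training step against that surface, mirroring the eval_util.py change above (the gradient and optimizer wiring is standard Keras, not part of this diff):

    import tensorflow as tf

    def train_step(model, batch, optimizer):
      with tf.GradientTape() as tape:
        outputs, losses = model(batch, return_losses=True, training=True)
      # 'total_loss' is the sum of self._losses_dict, added in Model.__call__.
      grads = tape.gradient(losses['total_loss'], model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      return model.get_audio_from_outputs(outputs), losses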