Skip to content

Commit

Permalink
Refactor model. Optionally return dict from __call__(), remove get_co…
Browse files Browse the repository at this point in the history
…ntrols().

PiperOrigin-RevId: 339775023
  • Loading branch information
jesseengel authored and Magenta Team committed Oct 29, 2020
1 parent 090b624 commit edf2b80
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 146 deletions.
2 changes: 1 addition & 1 deletion ddsp/colab/demos/pitch_detection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@
"# DDSP-INV,\n",
"start_time = time.time()\n",
"print('\\nExtracting f0 with DDSP-INV...')\n",
"controls = model.get_controls({'audio': audio}, training=False)\n",
"controls = model({'audio': audio}, training=False)\n",
"print('Prediction took %.1f seconds' % (time.time() - start_time))\n",
"\n",
"# CREPE.\n",
Expand Down
3 changes: 2 additions & 1 deletion ddsp/colab/demos/timbre_transfer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,8 @@
"\n",
"# Run a batch of predictions.\n",
"start_time = time.time()\n",
"audio_gen = model(af, training=False)\n",
"outputs = model(af, training=False)\n",
"audio_gen = model.get_audio_from_outputs(outputs)\n",
"print('Prediction took %.1f seconds' % (time.time() - start_time))\n",
"\n",
"# Plot\n",
Expand Down
3 changes: 2 additions & 1 deletion ddsp/colab/demos/train_autoencoder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@
"model.restore(SAVE_DIR)\n",
"\n",
"# Resynthesize audio.\n",
"audio_gen = model(batch, training=False)\n",
"outputs = model(batch, training=False)\n",
"audio_gen = model.get_audio_from_outputs(outputs)\n",
"audio = batch['audio']\n",
"\n",
"print('Original Audio')\n",
Expand Down
4 changes: 2 additions & 2 deletions ddsp/colab/tutorials/3_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,8 @@
"source": [
"# Run a batch of predictions.\n",
"start_time = time.time()\n",
"controls = model.get_controls(next(dataset_iter))\n",
"audio_gen = controls['processor_group']['signal']\n",
"controls = model(next(dataset_iter))\n",
"audio_gen = model.get_audio_from_outputs(controls)\n",
"print('Prediction took %.1f seconds' % (time.time() - start_time))"
]
},
Expand Down
5 changes: 2 additions & 3 deletions ddsp/training/eval_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ def evaluate_or_sample(data_provider,

# TODO(jesseengel): Find a way to add losses with training=False.
audio = batch['audio']
audio_gen, losses = model(batch, return_losses=True, training=True)

outputs = model.get_controls(batch, training=True)
outputs, losses = model(batch, return_losses=True, training=True)
audio_gen = model.get_audio_from_outputs(outputs)

# Create metrics on first batch.
if mode == 'eval' and batch_idx == 1:
Expand Down
2 changes: 1 addition & 1 deletion ddsp/training/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def build_network(self):
@tf.function
def call(self, input_dict):
"""Convert f0 and loudness to synthesizer parameters."""
controls = super().get_controls(input_dict, training=False)
controls = super().__call__(input_dict, training=False)
amps = controls['additive']['controls']['amplitudes']
hd = controls['additive']['controls']['harmonic_distribution']
return amps, hd
Expand Down
29 changes: 12 additions & 17 deletions ddsp/training/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,14 @@ def __init__(self,
decoder=None,
processor_group=None,
losses=None,
name='autoencoder'):
super().__init__(name=name)
**kwargs):
super().__init__(**kwargs)
self.preprocessor = preprocessor
self.encoder = encoder
self.decoder = decoder
self.processor_group = processor_group
self.loss_objs = ddsp.core.make_iterable(losses)

def controls_to_audio(self, controls):
return controls[self.processor_group.name]['signal']

def encode(self, features, training=True):
"""Get conditioning by preprocessing then encoding."""
if self.preprocessor is not None:
Expand All @@ -52,20 +49,18 @@ def decode(self, conditioning, training=True):
processor_inputs = self.decoder(conditioning, training=training)
return self.processor_group(processor_inputs)

def get_audio_from_outputs(self, outputs):
"""Extract audio output tensor from outputs dict of call()."""
return self.processor_group.get_signal(outputs)

def call(self, features, training=True):
"""Run the core of the network, get predictions and loss."""
conditioning = self.encode(features, training=training)
audio_gen = self.decode(conditioning, training=training)
processor_inputs = self.decoder(conditioning, training=training)
outputs = self.processor_group.get_controls(processor_inputs)
outputs['audio_synth'] = self.processor_group.get_signal(outputs)
if training:
self.update_losses_dict(self.loss_objs, features['audio'], audio_gen)
return audio_gen
self._update_losses_dict(
self.loss_objs, features['audio'], outputs['audio_synth'])
return outputs

def get_controls(self, features, keys=None, training=False):
"""Returns specific processor_group controls."""
conditioning = self.encode(features, training=training)
processor_inputs = self.decoder(conditioning)
controls = self.processor_group.get_controls(processor_inputs)
# Also build on get_controls(), instead of just __call__().
self.built = True
# If wrapped in tf.function, only calculates keys of interest.
return controls if keys is None else {k: controls[k] for k in keys}
7 changes: 4 additions & 3 deletions ddsp/training/models/autoencoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def test_build_model(self, gin_file):
gin.parse_config_file(gin_file)

model = models.Autoencoder()
controls = model.get_controls(self.inputs)
self.assertIsInstance(controls, dict)
outputs = model(self.inputs)
self.assertIsInstance(outputs, dict)
# Confirm that model generates correctly sized audio.
audio_gen_shape = controls['processor_group']['signal'].shape.as_list()
audio_gen = model.get_audio_from_outputs(outputs)
audio_gen_shape = audio_gen.shape.as_list()
self.assertEqual(audio_gen_shape, list(self.inputs['audio'].shape))

if __name__ == '__main__':
Expand Down
46 changes: 32 additions & 14 deletions ddsp/training/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,26 @@ def __call__(self, *args, return_losses=False, **kwargs):
**kwargs: Keyword arguments passed on to call().
Returns:
Function results if return_losses=False, else the function results
and a dictionary of losses, {loss_name: loss_value}.
outputs: A dictionary of model outputs generated in call().
{output_name: output_tensor or dict}.
losses: If return_losses=True, also returns a dictionary of losses,
{loss_name: loss_value}.
"""
self._losses_dict = {}
results = super().__call__(*args, **kwargs)
outputs = super().__call__(*args, **kwargs)
if not return_losses:
return results
return outputs
else:
self._losses_dict['total_loss'] = tf.reduce_sum(
list(self._losses_dict.values()))
return results, self._losses_dict
return outputs, self._losses_dict

def _update_losses_dict(self, loss_objs, *args, **kwargs):
"""Helper function to run loss objects on args and add to model losses."""
for loss_obj in ddsp.core.make_iterable(loss_objs):
if hasattr(loss_obj, 'get_losses_dict'):
losses_dict = loss_obj.get_losses_dict(*args, **kwargs)
self._losses_dict.update(losses_dict)

def restore(self, checkpoint_path):
"""Restore model and optimizer from a checkpoint."""
Expand All @@ -65,13 +74,22 @@ def restore(self, checkpoint_path):
logging.info('Could not find checkpoint to load at %s, skipping.',
checkpoint_path)

def get_controls(self, features, keys=None, training=False):
"""Base method for getting controls. Not implemented."""
raise NotImplementedError('`get_controls` not implemented in base class!')
def get_audio_from_outputs(self, outputs):
"""Extract audio output tensor from outputs dict of call()."""
raise NotImplementedError('Must implement `self.get_audio_from_outputs()`.')

def update_losses_dict(self, loss_objs, *args, **kwargs):
"""Run loss objects on inputs and adds to model losses."""
for loss_obj in ddsp.core.make_iterable(loss_objs):
if hasattr(loss_obj, 'get_losses_dict'):
losses_dict = loss_obj.get_losses_dict(*args, **kwargs)
self._losses_dict.update(losses_dict)
def call(self, *args, training=False, **kwargs):
"""Run the forward pass, add losses, and create a dictionary of outputs.
This function must run the forward pass, add losses to self._losses_dict and
return a dictionary of all the relevant output tensors.
Args:
*args: Args for forward pass.
training: Required `training` kwarg passed in by keras.
**kwargs: kwargs for forward pass.
Returns:
Dictionary of all relevant tensors.
"""
raise NotImplementedError('Must implement a `self.call()` method.')
Loading

0 comments on commit edf2b80

Please sign in to comment.