diff --git a/ddsp/training/eval_util.py b/ddsp/training/eval_util.py index 5782a0b7..bd060740 100644 --- a/ddsp/training/eval_util.py +++ b/ddsp/training/eval_util.py @@ -379,14 +379,9 @@ def evaluate_or_sample(data_provider, # Load model. model.restore(checkpoint_path) - # Create metrics. - if mode == 'eval': - f0_loudness_metrics = F0LoudnessMetrics() - avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32) - for name in model.loss_names} - # Iterate through dataset and make predictions checkpoint_start_time = time.time() + for batch_idx in range(1, num_batches + 1): try: start_time = time.time() @@ -396,9 +391,16 @@ def evaluate_or_sample(data_provider, batch = next(dataset_iter) audio = batch['audio'] # TODO(jesseengel): Find a way to add losses with training=False. - audio_gen = model(batch, training=True) # Adds losses. + audio_gen, losses = model(batch, return_losses=True, training=True) outputs = model.get_controls(batch, training=True) + # Create metrics on first batch. + if mode == 'eval' and batch_idx == 1: + f0_loudness_metrics = F0LoudnessMetrics() + avg_losses = { + name: tf.keras.metrics.Mean(name=name, dtype=tf.float32) + for name in list(losses.keys())} + # Resample f0_hz outputs to match batch if they don't already. has_f0 = ('f0_hz' in outputs and 'f0_hz' in batch) @@ -439,7 +441,6 @@ def evaluate_or_sample(data_provider, outputs['f0_hz']) # Loss. - losses = model.losses_dict for k, v in losses.items(): avg_losses[k].update_state(v) diff --git a/ddsp/training/models.py b/ddsp/training/models.py index ff50b072..5eb9d0f6 100644 --- a/ddsp/training/models.py +++ b/ddsp/training/models.py @@ -43,18 +43,31 @@ def get_model(model=gin.REQUIRED): class Model(tf.keras.Model): """Wrap the model function for dependency injection with gin.""" - def __init__(self, losses=None, name='model'): + def __init__(self, name='model'): super().__init__(name=name) - self.loss_objs = ddsp.core.make_iterable(losses) - self.loss_names = [loss_obj.name - for loss_obj in self.loss_objs] + ['total_loss'] - - @property - def losses_dict(self): - """For metrics, returns dict {loss_name: loss_value}.""" - losses_dict = dict(zip(self.loss_names, self.losses)) - losses_dict['total_loss'] = tf.reduce_sum(self.losses) - return losses_dict + self._losses_dict = {} + + def __call__(self, *args, return_losses=False, **kwargs): + """Reset the losses dict on each call. + + Args: + *args: Arguments passed on to call(). + return_losses: Return a dictionary of losses in addition to the call() + function returns. + **kwargs: Keyword arguments passed on to call(). + + Returns: + Function results if return_losses=False, else the function results + and a dictionary of losses, {loss_name: loss_value}. + """ + self._losses_dict = {} + results = super().__call__(*args, **kwargs) + if not return_losses: + return results + else: + self._losses_dict['total_loss'] = tf.reduce_sum( + list(self._losses_dict.values())) + return results, self._losses_dict def restore(self, checkpoint_path): """Restore model and optimizer from a checkpoint.""" @@ -81,11 +94,12 @@ def __init__(self, processor_group=None, losses=None, name='autoencoder'): - super().__init__(name=name, losses=losses) + super().__init__(name=name) self.preprocessor = preprocessor self.encoder = encoder self.decoder = decoder self.processor_group = processor_group + self.loss_objs = ddsp.core.make_iterable(losses) def controls_to_audio(self, controls): return controls[self.processor_group.name]['signal'] @@ -106,7 +120,8 @@ def call(self, features, training=True): audio_gen = self.decode(conditioning, training=training) if training: for loss_obj in self.loss_objs: - self.add_loss(loss_obj(features['audio'], audio_gen)) + loss = loss_obj(features['audio'], audio_gen) + self._losses_dict[loss_obj.name] = loss return audio_gen def get_controls(self, features, keys=None, training=False): diff --git a/ddsp/training/train_util.py b/ddsp/training/train_util.py index 8f20b826..5d7fff89 100644 --- a/ddsp/training/train_util.py +++ b/ddsp/training/train_util.py @@ -197,7 +197,7 @@ def run(self, fn, *args, **kwargs): return self.strategy.experimental_run_v2(fn, args=args, kwargs=kwargs) def build(self, batch): - """Build the model by running a batch through it.""" + """Build the model by running a distributed batch through it.""" logging.info('Building the model...') _ = self.run(tf.function(self.model.__call__), batch) self.model.summary() @@ -223,13 +223,12 @@ def train_step(self, dataset_iter): def step_fn(self, batch): """Per-Replica training step.""" with tf.GradientTape() as tape: - _ = self.model(batch, training=True) - total_loss = tf.reduce_sum(self.model.losses) + _, losses = self.model(batch, return_losses=True, training=True) # Clip and apply gradients. - grads = tape.gradient(total_loss, self.model.trainable_variables) + grads = tape.gradient(losses['total_loss'], self.model.trainable_variables) grads, _ = tf.clip_by_global_norm(grads, self.grad_clip_norm) self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) - return self.model.losses_dict + return losses @gin.configurable @@ -252,11 +251,6 @@ def train(data_provider, # Load latest checkpoint if one exists in model_dir. trainer.restore(model_dir) - # Create training loss metrics. - logging.info('Creating metrics for %s', list(trainer.model.loss_names)) - avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32) - for name in trainer.model.loss_names} - # Set up the summary writer and metrics. summary_dir = os.path.join(model_dir, 'summaries', 'train') summary_writer = tf.summary.create_file_writer(summary_dir) @@ -268,12 +262,19 @@ def train(data_provider, with summary_writer.as_default(): tick = time.time() - for _ in range(num_steps): - step = trainer.step + for iteration in range(num_steps): + step = trainer.step # Step is not iteration if restarting a model. # Take a step. losses = trainer.train_step(dataset_iter) + # Create training loss metrics when starting/restarting training. + if iteration == 0: + loss_names = list(losses.keys()) + logging.info('Creating metrics for %s', loss_names) + avg_losses = {name: tf.keras.metrics.Mean(name=name, dtype=tf.float32) + for name in loss_names} + # Update metrics. for k, v in losses.items(): avg_losses[k].update_state(v) diff --git a/ddsp/version.py b/ddsp/version.py index cc015c7a..77088e9a 100644 --- a/ddsp/version.py +++ b/ddsp/version.py @@ -19,4 +19,4 @@ pulling in all the dependencies in __init__.py. """ -__version__ = '0.0.10' +__version__ = '0.1.0'