From d9124662d4e836a7665039f98e7421b9dd9a4178 Mon Sep 17 00:00:00 2001
From: Jesse Engel
Date: Thu, 8 Apr 2021 15:54:01 -0700
Subject: [PATCH] * Scripts for converting Checkpoints to SavedModels, and
 SavedModels to TFJS and TFLite Models.

* Revert Inference models to inheriting from ddsp.training models, rather
  than a base class (more reliable).
* More robust search for the latest checkpoint and latest operative config.
  Throw an error if not found instead of failing silently; put the initial
  train and eval calls in a try statement, as the files might not exist yet.
* Helper function for printing the shapes of a nested dictionary of tensors.
* Verbose messages on model loading to let you know if you're not actually
  loading weights.

Not fully functional yet; relies on TF 2.5, but that breaks CREPE at the
moment. Should be able to get it working by running
`pip install -U tensorflow==2.5.0rc0` after pip installing ddsp.

PiperOrigin-RevId: 367523871
---
 ddsp/core.py                  |   5 +
 ddsp/training/ddsp_export.py  | 195 +++++++++++++++++++++++++++++++
 ddsp/training/ddsp_run.py     |   6 +-
 ddsp/training/eval_util.py    |   6 +-
 ddsp/training/inference.py    | 212 +++++++++++++++++++---------------
 ddsp/training/models/model.py |  27 +++--
 ddsp/training/train_util.py   |  89 ++++++++++++--
 ddsp/training/trainers.py     |  33 +++---
 ddsp/version.py               |   2 +-
 setup.py                      |   4 +
 10 files changed, 444 insertions(+), 135 deletions(-)
 create mode 100644 ddsp/training/ddsp_export.py

diff --git a/ddsp/core.py b/ddsp/core.py
index 2cef68f9..3e4efb89 100644
--- a/ddsp/core.py
+++ b/ddsp/core.py
@@ -146,6 +146,11 @@ def leaf_key(nested_key: Text,
   return keys[-1]
 
 
+def map_shape(x: Dict[Text, tf.Tensor]) -> Dict[Text, Sequence[int]]:
+  """Recursively infer tensor shapes for a dictionary of tensors."""
+  return tf.nest.map_structure(lambda t: list(tf.shape(t).numpy()), x)
+
+
 def pad_axis(x, padding=(0, 0), axis=0, **pad_kwargs):
   """Pads only one axis of a tensor.
 
diff --git a/ddsp/training/ddsp_export.py b/ddsp/training/ddsp_export.py
new file mode 100644
index 00000000..d9ef5ba4
--- /dev/null
+++ b/ddsp/training/ddsp_export.py
@@ -0,0 +1,195 @@
+# Copyright 2021 The DDSP Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Convert checkpoint to SavedModel, and SavedModel to TFJS / TFLite format.
+
+Example Usage (Defaults):
+ddsp_export --model_path=/path/to/model
+
+Example Usage (TFJS model):
+ddsp_export --model_path=/path/to/model --inference_model=autoencoder \
+--tflite=false --tfjs
+
+Example Usage (TFLite model):
+ddsp_export --model_path=/path/to/model --inference_model=streaming_f0_pw \
+--tflite --tfjs=false
+
+Example Usage (SavedModel Only):
+ddsp_export --model_path=/path/to/model --inference_model=[model_type] \
+--tflite=false --tfjs=false
+"""
+
+import os
+
+from absl import app
+from absl import flags
+
+from ddsp.training import inference
+from ddsp.training import train_util
+import gin
+import tensorflow as tf
+from tensorflowjs.converters import converter
+
+
+flags.DEFINE_string('model_path', '',
+                    'Path to checkpoint or SavedModel directory. If no '
+                    'SavedModel is found, will search for the latest '
+                    'checkpoint and use it to create a SavedModel. Can also '
+                    'provide a direct path to the desired checkpoint, '
+                    'e.g. `/path/to/ckpt-[iter]`.')
+flags.DEFINE_string('save_dir', '',
+                    'Optional directory in which to save the converted '
+                    'checkpoint. If none is provided, it will be '
+                    'FLAGS.model_path if it contains a SavedModel, otherwise '
+                    'FLAGS.model_path/export.')
+
+# Specify model class.
+flags.DEFINE_enum('inference_model', 'streaming_f0_pw',
+                  ['autoencoder',
+                   'streaming_f0_pw',
+                  ],
+                  'Specify the ddsp.training.inference model to use for '
+                  'converting a checkpoint to a SavedModel. Names are '
+                  'snake_case versions of class names.')
+
+# Optional flags.
+flags.DEFINE_multi_string('gin_param', [],
+                          'Gin parameters for custom inference model kwargs.')
+flags.DEFINE_boolean('debug', False, 'DEBUG: Do not save the model.')
+
+# Conversion formats.
+flags.DEFINE_boolean('tfjs', True,
+                     'Convert SavedModel to TFJS for deploying on the web.')
+flags.DEFINE_boolean('tflite', True,
+                     'Convert SavedModel to TFLite for embedded C++ apps.')
+
+
+FLAGS = flags.FLAGS
+
+
+def get_inference_model(ckpt):
+  """Restore model from checkpoint using global FLAGS.
+
+  Use --gin_param for any custom kwargs for model constructors.
+
+  Args:
+    ckpt: Path to the checkpoint.
+
+  Returns:
+    Inference model, built and restored from checkpoint.
+  """
+  # Parse model kwargs from --gin_param.
+  with gin.unlock_config():
+    gin.parse_config_files_and_bindings(None, FLAGS.gin_param)
+
+  models = {
+      'autoencoder': inference.AutoencoderInference,
+      'streaming_f0_pw': inference.StreamingF0PwInference,
+  }
+  return models[FLAGS.inference_model](ckpt)
+
+
+def ckpt_to_saved_model(ckpt, save_dir):
+  """Convert Checkpoint to SavedModel."""
+  print(f'\nConverting to SavedModel:\nInput: {ckpt}\nOutput: {save_dir}\n')
+  model = get_inference_model(ckpt)
+  print('Finished Loading Model!')
+  if not FLAGS.debug:
+    model.save_model(save_dir)
+    print('SavedModel Conversion Success!')
+
+
+def saved_model_to_tfjs(input_dir, save_dir):
+  """Convert SavedModel to TFJS model."""
+  print(f'\nConverting to TFJS:\nInput: {input_dir}\nOutput: {save_dir}\n')
+  converter.convert(['--input_format=tf_saved_model',
+                     '--signature_name=serving_default',
+                     '--control_flow_v2=True',
+                     '--skip_op_check',
+                     '--quantize_float16=True',
+                     '--experiments=True',
+                     input_dir,
+                     save_dir])
+  print('TFJS Conversion Success!')
+
+
+def saved_model_to_tflite(input_dir, save_dir):
+  """Convert SavedModel to TFLite model."""
+  print(f'\nConverting to TFLite:\nInput: {input_dir}\nOutput: {save_dir}\n')
+  # Convert the model.
+  tflite_converter = tf.lite.TFLiteConverter.from_saved_model(input_dir)
+  tflite_converter.target_spec.supported_ops = [
+      tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
+      tf.lite.OpsSet.SELECT_TF_OPS,  # Enable extended TensorFlow ops.
+  ]
+  tflite_model = tflite_converter.convert()  # Byte string.
+  # Save the model.
+  save_path = os.path.join(save_dir, 'model.tflite')
+  with tf.io.gfile.GFile(save_path, 'wb') as f:
+    f.write(tflite_model)
+  print('TFLite Conversion Success!')
+
+
+def ensure_exists(dir_path):
+  """Make a directory if none exists."""
+  if not tf.io.gfile.exists(dir_path):
+    tf.io.gfile.makedirs(dir_path)
+
+
+def main(unused_argv):
+  model_path = train_util.expand_path(FLAGS.model_path)
+
+  # Figure out what type the model path is.
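+  # E.g. a directory containing a 'saved_model.pb' is treated as a SavedModel,
+  # a file path such as '/path/to/ckpt-[iter]' as a single checkpoint, and any
+  # other directory as a directory of checkpoints.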
+  is_saved_model = tf.io.gfile.exists(
+      os.path.join(model_path, 'saved_model.pb'))
+  is_ckpt = not tf.io.gfile.isdir(model_path)
+
+  # Infer save directory path.
+  if FLAGS.save_dir:
+    save_dir = FLAGS.save_dir
+  else:
+    if is_saved_model:
+      # If model_path is a SavedModel, use that directory.
+      save_dir = model_path
+    elif is_ckpt:
+      # If model_path is a checkpoint file, use the directory of the file.
+      save_dir = os.path.join(os.path.dirname(model_path), 'export')
+    else:
+      # If model_path is a checkpoint directory, use child export directory.
+      save_dir = os.path.join(model_path, 'export')
+
+  # Make a new save directory.
+  save_dir = train_util.expand_path(save_dir)
+  ensure_exists(save_dir)
+
+  # Create SavedModel if none already exists.
+  if not is_saved_model:
+    ckpt_to_saved_model(model_path, save_dir)
+
+  # Convert SavedModel.
+  if FLAGS.tfjs:
+    tfjs_dir = os.path.join(save_dir, 'tfjs')
+    ensure_exists(tfjs_dir)
+    saved_model_to_tfjs(save_dir, tfjs_dir)
+
+  if FLAGS.tflite:
+    tflite_dir = os.path.join(save_dir, 'tflite')
+    ensure_exists(tflite_dir)
+    saved_model_to_tflite(save_dir, tflite_dir)
+
+
+def console_entry_point():
+  """From pip installed script."""
+  app.run(main)
+
+
+if __name__ == '__main__':
+  console_entry_point()
diff --git a/ddsp/training/ddsp_run.py b/ddsp/training/ddsp_run.py
index ce513ce4..b6ccd290 100644
--- a/ddsp/training/ddsp_run.py
+++ b/ddsp/training/ddsp_run.py
@@ -142,11 +142,13 @@ def parse_gin(restore_dir):
     gin.parse_config_file(eval_default)
 
   # Load operative_config if it exists (model has already trained).
-  operative_config = train_util.get_latest_operative_config(restore_dir)
-  if tf.io.gfile.exists(operative_config):
+  try:
+    operative_config = train_util.get_latest_operative_config(restore_dir)
     logging.info('Using operative config: %s', operative_config)
     operative_config = cloud.make_file_paths_local(operative_config, GIN_PATH)
     gin.parse_config_file(operative_config, skip_unknown=True)
+  except FileNotFoundError:
+    logging.info('Operative config not found in %s', restore_dir)
 
   # User gin config and user hyperparameters from flags.
   gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
diff --git a/ddsp/training/eval_util.py b/ddsp/training/eval_util.py
index ef31f029..d05bc302 100644
--- a/ddsp/training/eval_util.py
+++ b/ddsp/training/eval_util.py
@@ -94,7 +94,11 @@ def evaluate_or_sample(data_provider,
       dataset_iter = iter(dataset)
 
       # Load model.
-      model.restore(checkpoint_path)
+      try:
+        model.restore(checkpoint_path)
+      except FileNotFoundError:
+        logging.warning('No existing checkpoint found in %s, skipping '
+                        'checkpoint loading.', restore_dir)
 
       # Iterate through dataset and make predictions
       checkpoint_start_time = time.time()
diff --git a/ddsp/training/inference.py b/ddsp/training/inference.py
index 28587a37..3ffa40be 100644
--- a/ddsp/training/inference.py
+++ b/ddsp/training/inference.py
@@ -15,11 +15,20 @@
 # Lint as: python3
 """Constructs inference version of the models.
 
-These models can be stored as SavedModels by calling model.save() and used
-just like other SavedModels.
-"""
+N.B. (jesseengel): I tried to make a nice base class, both with multiple
+inheritance and with encapsulation, but restoring model parameters seems a bit
+fragile given that TF implicitly uses the Python object model for checkpoints,
+so I opted for code duplication to make things more robust and to preserve the
+Python object model structure of the original ddsp.training models.
+
+That said, inference models should satisfy the following interface.
+
+Interface:
+  Initialize from checkpoint: `model = InferenceModel(ckpt_path)`
+  Create SavedModel: `model.save_model(save_dir)`
+
+Need to use model.save_model() since keras model.save() can't be overridden.
+"""
 
-import os
 import ddsp
 from ddsp.training import models
@@ -28,67 +37,62 @@
 import tensorflow as tf
 
 
-class InferenceModel(object):
-  """Base class for inference models."""
-
-  def __init__(self, ckpt, model_class, **kwargs):
-    self.parse_gin_config(ckpt)
-    model_class.__init__(self, **kwargs)
-    self.restore(ckpt)
-    self.build_network()
-
-  def parse_gin_config(self, ckpt):
-    with gin.unlock_config():
-      ckpt_dir = os.path.dirname(ckpt)
-      operative_config = train_util.get_latest_operative_config(ckpt_dir)
-      print(f'Parsing from operative_config {operative_config}')
-      gin.parse_config_file(operative_config, skip_unknown=True)
-
-  def build_network(self):
-    """Run a fake batch through the network."""
-    raise NotImplementedError('Need to specify build_network() method.')
-
-  def save_model(self, save_dir):
-    """Save model to save_dir, override for custom function signatures."""
-    self.save(save_dir)
+def parse_operative_config(ckpt_dir):
+  """Parse the most recent operative config in a checkpoint directory."""
+  with gin.unlock_config():
+    operative_config = train_util.get_latest_operative_config(ckpt_dir)
+    print(f'Parsing from operative_config {operative_config}')
+    gin.parse_config_file(operative_config, skip_unknown=True)
 
 
 @gin.configurable
-class AutoencoderInference(models.Autoencoder, InferenceModel):
+class AutoencoderInference(models.Autoencoder):
   """Create an inference-only version of the model."""
 
   def __init__(self,
                ckpt,
                length_seconds=4,
-               sample_rate=16000,
-               frame_rate=250,
+               remove_reverb=True,
+               verbose=True,
                **kwargs):
-    # pylint: disable=super-init-not-called
     self.length_seconds = length_seconds
-    self.sample_rate = sample_rate
-    self.frame_rate = frame_rate
-    self.hop_size = int(sample_rate / frame_rate)
-    self.time_steps = int(length_seconds * sample_rate / self.hop_size)
-    self.n_samples = self.time_steps * self.hop_size
-    self.n_frames = int(frame_rate * length_seconds)
-    InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs)
-
-  @tf.function
-  def call(self, input_dict):
-    """Run the core of the network, get predictions."""
-    input_dict = ddsp.core.copy_if_tf_function(input_dict)
-    return super().call(input_dict, training=False)
+    self.remove_reverb = remove_reverb
+    self.configure_gin(ckpt)
+    super().__init__(**kwargs)
+    self.restore(ckpt, verbose=verbose)
+    self.build_network()
 
-  def parse_gin_config(self, ckpt):
-    """Parse the model operative config with new length parameters."""
-    with gin.unlock_config():
-      ckpt_dir = os.path.dirname(ckpt)
-      operative_config = train_util.get_latest_operative_config(ckpt_dir)
-      print(f'Parsing from operative_config {operative_config}')
-      gin.parse_config_file(operative_config, skip_unknown=True)
-      # Set gin params to new length.
+  def configure_gin(self, ckpt):
+    """Parse the model operative config to infer new length parameters."""
+    parse_operative_config(ckpt)
+
+    # Get the preprocessor type.
+    ref = gin.query_parameter('Autoencoder.preprocessor')
+    self.preprocessor_type = ref.config_key[-1].split('.')[-1]
+
+    # Get hop_size and sample_rate from the gin config.
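+    # E.g. a model trained with n_samples = 64000 and time_steps = 1000
+    # (illustrative values) gives hop_size = 64000 // 1000 = 64.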
+    self.sample_rate = gin.query_parameter('Harmonic.sample_rate')
+    n_samples_train = gin.query_parameter('Harmonic.n_samples')
+    time_steps_train = gin.query_parameter(
+        f'{self.preprocessor_type}.time_steps')
+    self.hop_size = n_samples_train // time_steps_train
+
+    # Get new lengths for inference.
+    self.n_frames = int(self.length_seconds * self.sample_rate / self.hop_size)
+    self.n_samples = self.n_frames * self.hop_size
+    print('N_Samples:', self.n_samples)
+    print('Hop Size:', self.hop_size)
+    print('N_Frames:', self.n_frames)
+
+    # Set gin config to new lengths from model properties.
+    config = [
+        f'Harmonic.n_samples = {self.n_samples}',
+        f'FilteredNoise.n_samples = {self.n_samples}',
+        f'{self.preprocessor_type}.time_steps = {self.n_frames}',
+        'oscillator_bank.use_angular_cumsum = True',
+    ]
+    if self.remove_reverb:
       # Remove reverb processor.
-      pg_string = """ProcessorGroup.dag = [
+      processor_group_string = """ProcessorGroup.dag = [
       (@synths.Harmonic(),
        ['amps', 'harmonic_distribution', 'f0_hz']),
       (@synths.FilteredNoise(),
        ['noise_magnitudes']),
@@ -96,59 +100,75 @@ def parse_gin_config(self, ckpt):
       (@processors.Add(),
        ['filtered_noise/signal', 'harmonic/signal']),
       ]"""
-      gin.parse_config([
-          'Harmonic.n_samples=%d' % self.n_samples,
-          'FilteredNoise.n_samples=%d' % self.n_samples,
-          'F0LoudnessPreprocessor.time_steps=%d' % self.time_steps,
-          'oscillator_bank.use_angular_cumsum=True',
-          pg_string,
-      ])
+      config.append(processor_group_string)
+
+    with gin.unlock_config():
+      gin.parse_config(config)
+
+  def save_model(self, save_dir):
+    """Saves a SavedModel after initialization."""
+    self.save(save_dir)
 
   def build_network(self):
     """Run a fake batch through the network."""
+    db_key = 'power_db' if 'Power' in self.preprocessor_type else 'loudness_db'
     input_dict = {
-        'loudness_db': tf.zeros([self.n_frames]),
+        db_key: tf.zeros([self.n_frames]),
         'f0_hz': tf.zeros([self.n_frames]),
     }
-    print('Inputs to Model:', input_dict)
+    # Recursive print of shape.
+    print('Inputs to Model:', ddsp.core.map_shape(input_dict))
     unused_outputs = self(input_dict)
-    print('Outputs', unused_outputs)
+    print('Outputs from Model:', ddsp.core.map_shape(unused_outputs))
+
+  @tf.function
+  def call(self, inputs, **unused_kwargs):
+    """Run the core of the network, get predictions."""
+    inputs = ddsp.core.copy_if_tf_function(inputs)
+    return super().call(inputs, training=False)
 
 
 @gin.configurable
-class StreamingF0PwInference(models.Autoencoder, InferenceModel):
+class StreamingF0PwInference(models.Autoencoder):
   """Create an inference-only version of the model."""
 
-  def __init__(self, ckpt, **kwargs):
-    # pylint: disable=super-init-not-called
-    InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs)
+  def __init__(self, ckpt, verbose=True, **kwargs):
+    self.configure_gin(ckpt)
+    super().__init__(**kwargs)
+    self.restore(ckpt, verbose=verbose)
+    self.build_network()
 
-  def parse_gin_config(self, ckpt):
+  def configure_gin(self, ckpt):
     """Parse the model operative config with special streaming parameters."""
+    parse_operative_config(ckpt)
+
+    # Set streaming-specific params.
+    time_steps = gin.query_parameter('F0PowerPreprocessor.time_steps')
+    n_samples = gin.query_parameter('Harmonic.n_samples')
+    samples_per_frame = int(n_samples / time_steps)
+    config = [
+        'F0PowerPreprocessor.time_steps = 1',
+        f'Harmonic.n_samples = {samples_per_frame}',
+        f'FilteredNoise.n_samples = {samples_per_frame}',
+    ]
+
+    # Remove reverb processor.
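+    # The reverb's impulse response needs far more context than a single
+    # frame of audio, so the streaming DAG below omits it and stops at the
+    # harmonic + filtered noise mix.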
+    processor_group_string = """ProcessorGroup.dag = [
+    (@synths.Harmonic(),
+     ['amps', 'harmonic_distribution', 'f0_hz']),
+    (@synths.FilteredNoise(),
+     ['noise_magnitudes']),
+    (@processors.Add(),
+     ['filtered_noise/signal', 'harmonic/signal']),
+    ]"""
+    config.append(processor_group_string)
+
     with gin.unlock_config():
-      ckpt_dir = os.path.dirname(ckpt)
-      operative_config = train_util.get_latest_operative_config(ckpt_dir)
-      print(f'Parsing from operative_config {operative_config}')
-      gin.parse_config_file(operative_config, skip_unknown=True)
-      # Set streaming specific params.
-      # Remove reverb processor.
-      pg_string = """ProcessorGroup.dag = [
-      (@synths.Harmonic(),
-        ['amps', 'harmonic_distribution', 'f0_hz']),
-      (@synths.FilteredNoise(),
-        ['noise_magnitudes']),
-      (@processors.Add(),
-        ['filtered_noise/signal', 'harmonic/signal']),
-      ]"""
-      time_steps = gin.query_parameter('F0PowerPreprocessor.time_steps')
-      n_samples = gin.query_parameter('Harmonic.n_samples')
-      samples_per_frame = int(n_samples / time_steps)
-      gin.parse_config([
-          'F0PowerPreprocessor.time_steps=1',
-          f'Harmonic.n_samples={samples_per_frame}',
-          f'FilteredNoise.n_samples={samples_per_frame}',
-          pg_string,
-      ])
+      gin.parse_config(config)
+
+  def save_model(self, save_dir):
+    """Saves a SavedModel after initialization."""
+    self.save(save_dir)
 
   def build_network(self):
     """Run a fake batch through the network."""
@@ -156,15 +176,15 @@ def build_network(self):
       'f0_hz': tf.zeros([1]),
       'power_db': tf.zeros([1]),
     }
-    print('Inputs to Model:', input_dict)
+    print('Inputs to Model:', ddsp.core.map_shape(input_dict))
     unused_outputs = self(input_dict)
-    print('Outputs', unused_outputs)
+    print('Outputs from Model:', ddsp.core.map_shape(unused_outputs))
 
   @tf.function
-  def call(self, input_dict):
+  def call(self, inputs, **unused_kwargs):
     """Convert f0 and loudness to synthesizer parameters."""
-    input_dict = ddsp.core.copy_if_tf_function(input_dict)
-    controls = super().call(input_dict, training=False)
+    inputs = ddsp.core.copy_if_tf_function(inputs)
+    controls = super().call(inputs, training=False)
     amps = controls['harmonic']['controls']['amplitudes']
     hd = controls['harmonic']['controls']['harmonic_distribution']
     noise = controls['filtered_noise']['controls']['magnitudes']
diff --git a/ddsp/training/models/model.py b/ddsp/training/models/model.py
index 74ccd380..3961df26 100644
--- a/ddsp/training/models/model.py
+++ b/ddsp/training/models/model.py
@@ -71,18 +71,25 @@ def _update_losses_dict(self, loss_objs, *args, **kwargs):
       losses_dict = loss_obj.get_losses_dict(*args, **kwargs)
       self._losses_dict.update(losses_dict)
 
-  def restore(self, checkpoint_path):
-    """Restore model and optimizer from a checkpoint."""
+  def restore(self, checkpoint_path, verbose=True):
+    """Restore model from a checkpoint.
+
+    Args:
+      checkpoint_path: Path to checkpoint file or directory.
+      verbose: Warn about variables missing from the checkpoint.
+
+    Raises:
+      FileNotFoundError: If no checkpoint is found.
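+
+    Example (path is illustrative):
+      `model.restore('/path/to/save_dir')` loads the most recent
+      `ckpt-[iter]` checkpoint found in that directory.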
+    """
     start_time = time.time()
-    latest_checkpoint = train_util.get_latest_chekpoint(checkpoint_path)
-    if latest_checkpoint is not None:
-      checkpoint = tf.train.Checkpoint(model=self)
-      checkpoint.restore(latest_checkpoint).expect_partial()
-      logging.info('Loaded checkpoint %s', latest_checkpoint)
-      logging.info('Loading model took %.1f seconds', time.time() - start_time)
+    latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path)
+    checkpoint = tf.train.Checkpoint(model=self)
+    if verbose:
+      checkpoint.restore(latest_checkpoint)
     else:
-      logging.info('Could not find checkpoint to load at %s, skipping.',
-                   checkpoint_path)
+      checkpoint.restore(latest_checkpoint).expect_partial()
+    logging.info('Loaded checkpoint %s', latest_checkpoint)
+    logging.info('Loading model took %.1f seconds', time.time() - start_time)
 
   def get_audio_from_outputs(self, outputs):
     """Extract audio output tensor from outputs dict of call()."""
diff --git a/ddsp/training/train_util.py b/ddsp/training/train_util.py
index af8517b5..32d8ca81 100644
--- a/ddsp/training/train_util.py
+++ b/ddsp/training/train_util.py
@@ -80,31 +80,94 @@ def get_strategy(tpu='', cluster_config=''):
   return strategy
 
 
-def get_latest_chekpoint(checkpoint_path):
+def expand_path(file_path):
+  """Expands the user home directory and any environment variables."""
+  return os.path.expanduser(os.path.expandvars(file_path))
+
+
+def get_latest_file(dir_path, prefix='operative_config-', suffix='.gin'):
+  """Returns latest file with pattern '/dir_path/prefix[iteration]suffix'.
+
+  Args:
+    dir_path: Path to the directory.
+    prefix: Filename prefix, not including directory.
+    suffix: Filename suffix, including extension.
+
+  Returns:
+    Path to the latest file.
+
+  Raises:
+    FileNotFoundError: If no files match the pattern
+      '/dir_path/prefix[int]suffix'.
+  """
+  dir_path = expand_path(dir_path)
+  dir_prefix = os.path.join(dir_path, prefix)
+  search_pattern = dir_prefix + '*' + suffix
+  file_paths = tf.io.gfile.glob(search_pattern)
+  if not file_paths:
+    raise FileNotFoundError(
+        f'No files found matching the pattern \'{search_pattern}\'.')
+  try:
+    # Filter to get the highest iteration, with no negative iterations.
+    get_iter = lambda fp: abs(int(fp.split(dir_prefix)[-1].split(suffix)[0]))
+    latest_file = max(file_paths, key=get_iter)
+    return latest_file
+  except ValueError as verror:
+    raise FileNotFoundError(
+        f'Files found with pattern \'{search_pattern}\' do not match '
+        f'the pattern \'{dir_prefix}[iteration_number]{suffix}\'.\n\n'
+        f'Files found:\n{file_paths}') from verror
+
+
+def get_latest_checkpoint(checkpoint_path):
   """Helper function to get path to latest checkpoint.
 
   Args:
     checkpoint_path: Path to the directory containing model checkpoints, or
-      to a specific checkpoint (e.g. `path/to/model.ckpt-iteration`).
+      to a specific checkpoint (e.g. `/path/to/model.ckpt-iteration`).
 
   Returns:
-    Path to latest checkpoint, or None if none exist.
+    Path to the latest checkpoint.
+
+  Raises:
+    FileNotFoundError: If no checkpoint is found.
   """
-  checkpoint_path = os.path.expanduser(os.path.expandvars(checkpoint_path))
+  checkpoint_path = expand_path(checkpoint_path)
   is_checkpoint = tf.io.gfile.exists(checkpoint_path + '.index')
   if is_checkpoint:
+    # Return the path if it points to a checkpoint.
     return checkpoint_path
  else:
-    # None if no checkpoints, or directory doesn't exist.
-    return tf.train.latest_checkpoint(checkpoint_path)
+    # Search using the 'checkpoint' bookkeeping file.
+    # Returns None if there is no 'checkpoint' file, or the directory
+    # doesn't exist.
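+    # E.g. a 'checkpoint' file listing ckpt-10000 as most recent resolves to
+    # '/dir/ckpt-10000' (paths illustrative).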
+    ckpt = tf.train.latest_checkpoint(checkpoint_path)
+    if ckpt:
+      return ckpt
+    else:
+      # As a last resort, look for '/path/ckpt-[iter].index' files.
+      ckpt_f = get_latest_file(checkpoint_path, prefix='ckpt-', suffix='.index')
+      return ckpt_f.split('.index')[0]
 
 
 # ---------------------------------- Gin ---------------------------------------
 def get_latest_operative_config(restore_dir):
-  """Finds the most recently saved operative_config in a directory."""
-  file_paths = tf.io.gfile.glob(os.path.join(restore_dir, 'operative_config*'))
-  get_iter = lambda file_path: int(file_path.split('-')[-1].split('.gin')[0])
-  return max(file_paths, key=get_iter) if file_paths else ''
+  """Finds the most recently saved operative_config in a directory.
+
+  Args:
+    restore_dir: Path to a directory with gin operative_configs. Will also
+      work if passed a path to a file in that directory, such as a checkpoint.
+
+  Returns:
+    Filepath to the most recent operative config.
+
+  Raises:
+    FileNotFoundError: If no config is found.
+  """
+  try:
+    return get_latest_file(
+        restore_dir, prefix='operative_config-', suffix='.gin')
+  except FileNotFoundError:
+    return get_latest_file(
+        os.path.dirname(restore_dir), prefix='operative_config-', suffix='.gin')
 
 
 def write_gin_config(summary_writer, save_dir, step):
@@ -198,7 +261,11 @@ def train(data_provider,
   trainer.build(next(dataset_iter))
 
   # Load latest checkpoint if one exists in load directory.
-  trainer.restore(restore_dir)
+  try:
+    trainer.restore(restore_dir)
+  except FileNotFoundError:
+    logging.info('No existing checkpoint found in %s, skipping '
+                 'checkpoint loading.', restore_dir)
 
   if save_dir:
     # Set up the summary writer and metrics.
diff --git a/ddsp/training/trainers.py b/ddsp/training/trainers.py
index 376001ee..c59ecfff 100644
--- a/ddsp/training/trainers.py
+++ b/ddsp/training/trainers.py
@@ -83,7 +83,15 @@ def save(self, save_dir):
     logging.info('Saving model took %.1f seconds', time.time() - start_time)
 
   def restore(self, checkpoint_path, restore_keys=None):
-    """Restore model and optimizer from a checkpoint if it exists."""
+    """Restore model and optimizer from a checkpoint.
+
+    Args:
+      checkpoint_path: Path to checkpoint file or directory.
+      restore_keys: Optional list of strings for submodules to restore.
+
+    Raises:
+      FileNotFoundError: If no checkpoint is found.
+    """
     logging.info('Restoring from checkpoint...')
     start_time = time.time()
 
@@ -105,19 +113,16 @@ def restore(self, checkpoint_path, restore_keys=None):
     # Restore from latest checkpoint.
     checkpoint = self.get_checkpoint(model)
-    latest_checkpoint = train_util.get_latest_chekpoint(checkpoint_path)
-    if latest_checkpoint is not None:
-      # checkpoint.restore must be within a strategy.scope() so that optimizer
-      # slot variables are mirrored.
-      with self.strategy.scope():
-        if restore_keys is None:
-          checkpoint.restore(latest_checkpoint)
-        else:
-          checkpoint.restore(latest_checkpoint).expect_partial()
-        logging.info('Loaded checkpoint %s', latest_checkpoint)
-        logging.info('Loading model took %.1f seconds', time.time() - start_time)
-    else:
-      logging.info('No checkpoint, skipping.')
+    latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path)
+    # checkpoint.restore must be within a strategy.scope() so that optimizer
+    # slot variables are mirrored.
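+    # expect_partial() silences warnings about checkpoint values left unused,
+    # which is expected when restoring only a subset of submodules via
+    # restore_keys.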
+ with self.strategy.scope(): + if restore_keys is None: + checkpoint.restore(latest_checkpoint) + else: + checkpoint.restore(latest_checkpoint).expect_partial() + logging.info('Loaded checkpoint %s', latest_checkpoint) + logging.info('Loading model took %.1f seconds', time.time() - start_time) @property def step(self): diff --git a/ddsp/version.py b/ddsp/version.py index 83f02bc2..0741e89a 100644 --- a/ddsp/version.py +++ b/ddsp/version.py @@ -19,4 +19,4 @@ pulling in all the dependencies in __init__.py. """ -__version__ = '1.2.0' +__version__ = '1.3.0' diff --git a/setup.py b/setup.py index ab5bb467..61b3b195 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,10 @@ 'scipy', 'six', 'tensorflow', + # TODO(jesseengel): Update to v2.5 when CREPE is no longer blocking. + # 'tensorflow==2.5.0rc0', 'tensorflow-addons', + 'tensorflowjs', 'tensorflow-probability', # TODO(adarob): Switch to tensorflow_datasets once includes nsynth 2.3. 'tfds-nightly', @@ -70,6 +73,7 @@ }, entry_points={ 'console_scripts': [ + 'ddsp_export = ddsp.training.ddsp_export:console_entry_point', 'ddsp_run = ddsp.training.ddsp_run:console_entry_point', 'ddsp_prepare_tfrecord = ddsp.training.data_preparation.ddsp_prepare_tfrecord:console_entry_point', 'ddsp_generate_synthetic_dataset = ddsp.training.data_preparation.ddsp_generate_synthetic_dataset:console_entry_point',