diff --git a/ddsp/core.py b/ddsp/core.py
index 2cef68f9..3e4efb89 100644
--- a/ddsp/core.py
+++ b/ddsp/core.py
@@ -146,6 +146,11 @@ def leaf_key(nested_key: Text,
   return keys[-1]
 
 
+def map_shape(x: Dict[Text, tf.Tensor]) -> Dict[Text, Sequence[int]]:
+  """Recursively infer tensor shapes for a dictionary of tensors."""
+  return tf.nest.map_structure(lambda t: list(tf.shape(t).numpy()), x)
+
+
 def pad_axis(x, padding=(0, 0), axis=0, **pad_kwargs):
   """Pads only one axis of a tensor.
 
diff --git a/ddsp/training/ddsp_export.py b/ddsp/training/ddsp_export.py
new file mode 100644
index 00000000..d9ef5ba4
--- /dev/null
+++ b/ddsp/training/ddsp_export.py
@@ -0,0 +1,195 @@
+# Copyright 2021 The DDSP Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Convert checkpoint to SavedModel, and SavedModel to TFJS / TFLite format.
+
+Example Usage (Defaults):
+ddsp_export --model_path=/path/to/model
+
+Example Usage (TFJS model):
+ddsp_export --model_path=/path/to/model --inference_model=autoencoder \
+--tflite=false --tfjs
+
+Example Usage (TFLite model):
+ddsp_export --model_path=/path/to/model --inference_model=streaming_f0_pw \
+--tflite --tfjs=false
+
+Example Usage (SavedModel Only):
+ddsp_export --model_path=/path/to/model --inference_model=[model_type] \
+--tflite=false --tfjs=false
+"""
+
+import os
+
+from absl import app
+from absl import flags
+
+from ddsp.training import inference
+from ddsp.training import train_util
+import gin
+import tensorflow as tf
+from tensorflowjs.converters import converter
+
+
+flags.DEFINE_string('model_path', '',
+                    'Path to checkpoint or SavedModel directory. If no '
+                    'SavedModel is found, will search for the latest '
+                    'checkpoint and use it to create a SavedModel. Can also '
+                    'provide a direct path to the desired checkpoint, '
+                    'e.g. `/path/to/ckpt-[iter]`.')
+flags.DEFINE_string('save_dir', '',
+                    'Optional directory in which to save the converted '
+                    'checkpoint. If none is provided, it will be '
+                    'FLAGS.model_path if it contains a SavedModel, otherwise '
+                    'FLAGS.model_path/export.')
+
+# Specify model class.
+flags.DEFINE_enum('inference_model', 'streaming_f0_pw',
+                  ['autoencoder',
+                   'streaming_f0_pw',
+                  ],
+                  'Specify the ddsp.training.inference model to use for '
+                  'converting a checkpoint to a SavedModel. Names are '
+                  'snake_case versions of class names.')
+
+# Optional flags.
+flags.DEFINE_multi_string('gin_param', [],
+                          'Gin parameters for custom inference model kwargs.')
+flags.DEFINE_boolean('debug', False, 'DEBUG: Do not save the model.')
+
+# Conversion formats.
+flags.DEFINE_boolean('tfjs', True,
+                     'Convert SavedModel to TFJS for deploying on the web.')
+flags.DEFINE_boolean('tflite', True,
+                     'Convert SavedModel to TFLite for embedded C++ apps.')
+
+
+FLAGS = flags.FLAGS
+
+
+def get_inference_model(ckpt):
+  """Restore model from checkpoint using global FLAGS.
+
+  Use --gin_param for any custom kwargs for model constructors.
+
+  Args:
+    ckpt: Path to the checkpoint.
+
+  Returns:
+    Inference model, built and restored from checkpoint.
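+
+  Example (illustrative placeholder paths; assumes flags are already parsed):
+    model = get_inference_model('/path/to/ckpt-10000')
+    model.save_model('/path/to/export')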
+  """
+  # Parse model kwargs from --gin_param.
+  with gin.unlock_config():
+    gin.parse_config_files_and_bindings(None, FLAGS.gin_param)
+
+  models = {
+      'autoencoder': inference.AutoencoderInference,
+      'streaming_f0_pw': inference.StreamingF0PwInference,
+  }
+  return models[FLAGS.inference_model](ckpt)
+
+
+def ckpt_to_saved_model(ckpt, save_dir):
+  """Convert Checkpoint to SavedModel."""
+  print(f'\nConverting to SavedModel:\nInput: {ckpt}\nOutput: {save_dir}\n')
+  model = get_inference_model(ckpt)
+  print('Finished Loading Model!')
+  if not FLAGS.debug:
+    model.save_model(save_dir)
+    print('SavedModel Conversion Success!')
+
+
+def saved_model_to_tfjs(input_dir, save_dir):
+  """Convert SavedModel to TFJS model."""
+  print(f'\nConverting to TFJS:\nInput: {input_dir}\nOutput: {save_dir}\n')
+  converter.convert(['--input_format=tf_saved_model',
+                     '--signature_name=serving_default',
+                     '--control_flow_v2=True',
+                     '--skip_op_check',
+                     '--quantize_float16=True',
+                     '--experiments=True',
+                     input_dir,
+                     save_dir])
+  print('TFJS Conversion Success!')
+
+
+def saved_model_to_tflite(input_dir, save_dir):
+  """Convert SavedModel to TFLite model."""
+  print(f'\nConverting to TFLite:\nInput: {input_dir}\nOutput: {save_dir}\n')
+  # Convert the model.
+  tflite_converter = tf.lite.TFLiteConverter.from_saved_model(input_dir)
+  tflite_converter.target_spec.supported_ops = [
+      tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
+      tf.lite.OpsSet.SELECT_TF_OPS,  # Enable extended TensorFlow ops.
+  ]
+  tflite_model = tflite_converter.convert()  # Byte string.
+  # Save the model.
+  save_path = os.path.join(save_dir, 'model.tflite')
+  with tf.io.gfile.GFile(save_path, 'wb') as f:
+    f.write(tflite_model)
+  print('TFLite Conversion Success!')
+
+
+def ensure_exists(dir_path):
+  """Make the directory if it does not already exist."""
+  if not tf.io.gfile.exists(dir_path):
+    tf.io.gfile.makedirs(dir_path)
+
+
+def main(unused_argv):
+  model_path = train_util.expand_path(FLAGS.model_path)
+
+  # Figure out what type the model path is.
+  is_saved_model = tf.io.gfile.exists(
+      os.path.join(model_path, 'saved_model.pb'))
+  is_ckpt = not tf.io.gfile.isdir(model_path)
+
+  # Infer save directory path.
+  if FLAGS.save_dir:
+    save_dir = FLAGS.save_dir
+  else:
+    if is_saved_model:
+      # If model_path is a SavedModel, use that directory.
+      save_dir = model_path
+    elif is_ckpt:
+      # If model_path is a checkpoint file, use the directory of the file.
+      save_dir = os.path.join(os.path.dirname(model_path), 'export')
+    else:
+      # If model_path is a checkpoint directory, use child export directory.
+      save_dir = os.path.join(model_path, 'export')
+
+  # Make a new save directory.
+  save_dir = train_util.expand_path(save_dir)
+  ensure_exists(save_dir)
+
+  # Create SavedModel if none already exists.
+  if not is_saved_model:
+    ckpt_to_saved_model(model_path, save_dir)
+
+  # Convert SavedModel.
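+  # Both converters below read the SavedModel in save_dir and write their
+  # output to format-specific subdirectories.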
+  if FLAGS.tfjs:
+    tfjs_dir = os.path.join(save_dir, 'tfjs')
+    ensure_exists(tfjs_dir)
+    saved_model_to_tfjs(save_dir, tfjs_dir)
+
+  if FLAGS.tflite:
+    tflite_dir = os.path.join(save_dir, 'tflite')
+    ensure_exists(tflite_dir)
+    saved_model_to_tflite(save_dir, tflite_dir)
+
+
+def console_entry_point():
+  """From pip installed script."""
+  app.run(main)
+
+
+if __name__ == '__main__':
+  console_entry_point()
diff --git a/ddsp/training/ddsp_run.py b/ddsp/training/ddsp_run.py
index ce513ce4..b6ccd290 100644
--- a/ddsp/training/ddsp_run.py
+++ b/ddsp/training/ddsp_run.py
@@ -142,11 +142,13 @@ def parse_gin(restore_dir):
     gin.parse_config_file(eval_default)
 
   # Load operative_config if it exists (model has already trained).
-  operative_config = train_util.get_latest_operative_config(restore_dir)
-  if tf.io.gfile.exists(operative_config):
+  try:
+    operative_config = train_util.get_latest_operative_config(restore_dir)
     logging.info('Using operative config: %s', operative_config)
     operative_config = cloud.make_file_paths_local(operative_config, GIN_PATH)
     gin.parse_config_file(operative_config, skip_unknown=True)
+  except FileNotFoundError:
+    logging.info('Operative config not found in %s', restore_dir)
 
   # User gin config and user hyperparameters from flags.
   gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
diff --git a/ddsp/training/eval_util.py b/ddsp/training/eval_util.py
index ef31f029..d05bc302 100644
--- a/ddsp/training/eval_util.py
+++ b/ddsp/training/eval_util.py
@@ -94,7 +94,11 @@ def evaluate_or_sample(data_provider,
   dataset_iter = iter(dataset)
 
   # Load model.
-  model.restore(checkpoint_path)
+  try:
+    model.restore(checkpoint_path)
+  except FileNotFoundError:
+    logging.warning('No existing checkpoint found in %s, skipping '
+                    'checkpoint loading.', restore_dir)
 
   # Iterate through dataset and make predictions
   checkpoint_start_time = time.time()
diff --git a/ddsp/training/inference.py b/ddsp/training/inference.py
index 28587a37..3ffa40be 100644
--- a/ddsp/training/inference.py
+++ b/ddsp/training/inference.py
@@ -15,11 +15,20 @@
 # Lint as: python3
 """Constructs inference version of the models.
 
-These models can be stored as SavedModels by calling model.save() and used
-just like other SavedModels.
-"""
+N.B. (jesseengel): I tried to make a nice base class, both with multiple
+inheritance and with encapsulation, but restoring model parameters seems a bit
+fragile given that TF implicitly uses the Python object model for checkpoints,
+so I opted for code duplication to make things more robust and to preserve the
+Python object model structure of the original ddsp.training models.
+
+That said, inference models should satisfy the following interface.
 
-import os
+Interface:
+  Initialize from checkpoint: `model = InferenceModel(ckpt_path)`
+  Create SavedModel: `model.save_model(save_dir)`
+
+Need to use model.save_model(), since we can't override keras model.save().
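+
+For example (placeholder checkpoint path):
+  model = AutoencoderInference('/path/to/ckpt-10000')
+  model.save_model('/path/to/export')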
+"""
 
 import ddsp
 from ddsp.training import models
@@ -28,67 +37,62 @@
 import tensorflow as tf
 
 
-class InferenceModel(object):
-  """Base class for inference models."""
-
-  def __init__(self, ckpt, model_class, **kwargs):
-    self.parse_gin_config(ckpt)
-    model_class.__init__(self, **kwargs)
-    self.restore(ckpt)
-    self.build_network()
-
-  def parse_gin_config(self, ckpt):
-    with gin.unlock_config():
-      ckpt_dir = os.path.dirname(ckpt)
-      operative_config = train_util.get_latest_operative_config(ckpt_dir)
-      print(f'Parsing from operative_config {operative_config}')
-      gin.parse_config_file(operative_config, skip_unknown=True)
-
-  def build_network(self):
-    """Run a fake batch through the network."""
-    raise NotImplementedError('Need to specify build_network() method.')
-
-  def save_model(self, save_dir):
-    """Save model to save_dir, override for custom function signatures."""
-    self.save(save_dir)
+def parse_operative_config(ckpt_dir):
+  """Parse the most recently saved operative_config in a given directory."""
+  with gin.unlock_config():
+    operative_config = train_util.get_latest_operative_config(ckpt_dir)
+    print(f'Parsing from operative_config {operative_config}')
+    gin.parse_config_file(operative_config, skip_unknown=True)
 
 
 @gin.configurable
-class AutoencoderInference(models.Autoencoder, InferenceModel):
+class AutoencoderInference(models.Autoencoder):
   """Create an inference-only version of the model."""
 
   def __init__(self,
                ckpt,
               length_seconds=4,
-               sample_rate=16000,
-               frame_rate=250,
+               remove_reverb=True,
+               verbose=True,
                **kwargs):
-    # pylint: disable=super-init-not-called
     self.length_seconds = length_seconds
-    self.sample_rate = sample_rate
-    self.frame_rate = frame_rate
-    self.hop_size = int(sample_rate / frame_rate)
-    self.time_steps = int(length_seconds * sample_rate / self.hop_size)
-    self.n_samples = self.time_steps * self.hop_size
-    self.n_frames = int(frame_rate * length_seconds)
-    InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs)
-
-  @tf.function
-  def call(self, input_dict):
-    """Run the core of the network, get predictions."""
-    input_dict = ddsp.core.copy_if_tf_function(input_dict)
-    return super().call(input_dict, training=False)
+    self.remove_reverb = remove_reverb
+    self.configure_gin(ckpt)
+    super().__init__(**kwargs)
+    self.restore(ckpt, verbose=verbose)
+    self.build_network()
 
-  def parse_gin_config(self, ckpt):
-    """Parse the model operative config with new length parameters."""
-    with gin.unlock_config():
-      ckpt_dir = os.path.dirname(ckpt)
-      operative_config = train_util.get_latest_operative_config(ckpt_dir)
-      print(f'Parsing from operative_config {operative_config}')
-      gin.parse_config_file(operative_config, skip_unknown=True)
-      # Set gin params to new length.
+  def configure_gin(self, ckpt):
+    """Parse the model operative config to infer new length parameters."""
+    parse_operative_config(ckpt)
+
+    # Get preprocessor_type.
+    ref = gin.query_parameter('Autoencoder.preprocessor')
+    self.preprocessor_type = ref.config_key[-1].split('.')[-1]
+
+    # Get hop_size and sample_rate from the gin config.
+    self.sample_rate = gin.query_parameter('Harmonic.sample_rate')
+    n_samples_train = gin.query_parameter('Harmonic.n_samples')
+    time_steps_train = gin.query_parameter(
+        f'{self.preprocessor_type}.time_steps')
+    self.hop_size = n_samples_train // time_steps_train
+
+    # Get new lengths for inference.
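+    # E.g. with length_seconds=4, sample_rate=16000, and hop_size=64,
+    # this gives n_frames=1000 and n_samples=64000.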
+    self.n_frames = int(self.length_seconds * self.sample_rate / self.hop_size)
+    self.n_samples = self.n_frames * self.hop_size
+    print('N_Samples:', self.n_samples)
+    print('Hop Size:', self.hop_size)
+    print('N_Frames:', self.n_frames)
+
+    # Set gin config to new lengths from model properties.
+    config = [
+        f'Harmonic.n_samples = {self.n_samples}',
+        f'FilteredNoise.n_samples = {self.n_samples}',
+        f'{self.preprocessor_type}.time_steps = {self.n_frames}',
+        'oscillator_bank.use_angular_cumsum = True',
+    ]
+
+    if self.remove_reverb:
       # Remove reverb processor.
-      pg_string = """ProcessorGroup.dag = [
+      processor_group_string = """ProcessorGroup.dag = [
       (@synths.Harmonic(),
        ['amps', 'harmonic_distribution', 'f0_hz']),
       (@synths.FilteredNoise(),
@@ -96,59 +100,75 @@ def parse_gin_config(self, ckpt):
        ['noise_magnitudes']),
       (@processors.Add(),
        ['filtered_noise/signal', 'harmonic/signal']),
       ]"""
-      gin.parse_config([
-          'Harmonic.n_samples=%d' % self.n_samples,
-          'FilteredNoise.n_samples=%d' % self.n_samples,
-          'F0LoudnessPreprocessor.time_steps=%d' % self.time_steps,
-          'oscillator_bank.use_angular_cumsum=True',
-          pg_string,
-      ])
+      config.append(processor_group_string)
+
+    with gin.unlock_config():
+      gin.parse_config(config)
+
+  def save_model(self, save_dir):
+    """Saves a SavedModel after initialization."""
+    self.save(save_dir)
 
   def build_network(self):
     """Run a fake batch through the network."""
+    db_key = 'power_db' if 'Power' in self.preprocessor_type else 'loudness_db'
     input_dict = {
-        'loudness_db': tf.zeros([self.n_frames]),
+        db_key: tf.zeros([self.n_frames]),
         'f0_hz': tf.zeros([self.n_frames]),
     }
-    print('Inputs to Model:', input_dict)
+    # Recursive print of shape.
+    print('Inputs to Model:', ddsp.core.map_shape(input_dict))
     unused_outputs = self(input_dict)
-    print('Outputs', unused_outputs)
+    print('Outputs from Model:', ddsp.core.map_shape(unused_outputs))
+
+  @tf.function
+  def call(self, inputs, **unused_kwargs):
+    """Run the core of the network, get predictions."""
+    inputs = ddsp.core.copy_if_tf_function(inputs)
+    return super().call(inputs, training=False)
 
 
 @gin.configurable
-class StreamingF0PwInference(models.Autoencoder, InferenceModel):
+class StreamingF0PwInference(models.Autoencoder):
   """Create an inference-only version of the model."""
 
-  def __init__(self, ckpt, **kwargs):
-    # pylint: disable=super-init-not-called
-    InferenceModel.__init__(self, ckpt, models.Autoencoder, **kwargs)
+  def __init__(self, ckpt, verbose=True, **kwargs):
+    self.configure_gin(ckpt)
+    super().__init__(**kwargs)
+    self.restore(ckpt, verbose=verbose)
+    self.build_network()
 
-  def parse_gin_config(self, ckpt):
+  def configure_gin(self, ckpt):
     """Parse the model operative config with special streaming parameters."""
+    parse_operative_config(ckpt)
+
+    # Set streaming specific params.
+    time_steps = gin.query_parameter('F0PowerPreprocessor.time_steps')
+    n_samples = gin.query_parameter('Harmonic.n_samples')
+    samples_per_frame = int(n_samples / time_steps)
+    config = [
+        'F0PowerPreprocessor.time_steps = 1',
+        f'Harmonic.n_samples = {samples_per_frame}',
+        f'FilteredNoise.n_samples = {samples_per_frame}',
+    ]
+
+    # Remove reverb processor.
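+    # As in AutoencoderInference above, the reverb-free DAG keeps only the
+    # harmonic and filtered noise synths.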
+    processor_group_string = """ProcessorGroup.dag = [
+    (@synths.Harmonic(),
+     ['amps', 'harmonic_distribution', 'f0_hz']),
+    (@synths.FilteredNoise(),
+     ['noise_magnitudes']),
+    (@processors.Add(),
+     ['filtered_noise/signal', 'harmonic/signal']),
+    ]"""
+    config.append(processor_group_string)
+
     with gin.unlock_config():
-      ckpt_dir = os.path.dirname(ckpt)
-      operative_config = train_util.get_latest_operative_config(ckpt_dir)
-      print(f'Parsing from operative_config {operative_config}')
-      gin.parse_config_file(operative_config, skip_unknown=True)
-      # Set streaming specific params.
-      # Remove reverb processor.
-      pg_string = """ProcessorGroup.dag = [
-      (@synths.Harmonic(),
-       ['amps', 'harmonic_distribution', 'f0_hz']),
-      (@synths.FilteredNoise(),
-       ['noise_magnitudes']),
-      (@processors.Add(),
-       ['filtered_noise/signal', 'harmonic/signal']),
-      ]"""
-      time_steps = gin.query_parameter('F0PowerPreprocessor.time_steps')
-      n_samples = gin.query_parameter('Harmonic.n_samples')
-      samples_per_frame = int(n_samples / time_steps)
-      gin.parse_config([
-          'F0PowerPreprocessor.time_steps=1',
-          f'Harmonic.n_samples={samples_per_frame}',
-          f'FilteredNoise.n_samples={samples_per_frame}',
-          pg_string,
-      ])
+      gin.parse_config(config)
+
+  def save_model(self, save_dir):
+    """Saves a SavedModel after initialization."""
+    self.save(save_dir)
 
   def build_network(self):
     """Run a fake batch through the network."""
@@ -156,15 +176,15 @@ def build_network(self):
         'f0_hz': tf.zeros([1]),
         'power_db': tf.zeros([1]),
     }
-    print('Inputs to Model:', input_dict)
+    print('Inputs to Model:', ddsp.core.map_shape(input_dict))
     unused_outputs = self(input_dict)
-    print('Outputs', unused_outputs)
+    print('Outputs from Model:', ddsp.core.map_shape(unused_outputs))
 
   @tf.function
-  def call(self, input_dict):
+  def call(self, inputs, **unused_kwargs):
     """Convert f0 and loudness to synthesizer parameters."""
-    input_dict = ddsp.core.copy_if_tf_function(input_dict)
-    controls = super().call(input_dict, training=False)
+    inputs = ddsp.core.copy_if_tf_function(inputs)
+    controls = super().call(inputs, training=False)
     amps = controls['harmonic']['controls']['amplitudes']
     hd = controls['harmonic']['controls']['harmonic_distribution']
     noise = controls['filtered_noise']['controls']['magnitudes']
diff --git a/ddsp/training/models/model.py b/ddsp/training/models/model.py
index 74ccd380..3961df26 100644
--- a/ddsp/training/models/model.py
+++ b/ddsp/training/models/model.py
@@ -71,18 +71,25 @@ def _update_losses_dict(self, loss_objs, *args, **kwargs):
       losses_dict = loss_obj.get_losses_dict(*args, **kwargs)
       self._losses_dict.update(losses_dict)
 
-  def restore(self, checkpoint_path):
-    """Restore model and optimizer from a checkpoint."""
+  def restore(self, checkpoint_path, verbose=True):
+    """Restore model and optimizer from a checkpoint.
+
+    Args:
+      checkpoint_path: Path to checkpoint file or directory.
+      verbose: Warn about missing variables.
+
+    Raises:
+      FileNotFoundError: If no checkpoint is found.
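+
+    Example (placeholder path):
+      model.restore('/path/to/save_dir')  # Or a direct '/path/to/ckpt-[iter]'.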
+    """
     start_time = time.time()
-    latest_checkpoint = train_util.get_latest_chekpoint(checkpoint_path)
-    if latest_checkpoint is not None:
-      checkpoint = tf.train.Checkpoint(model=self)
-      checkpoint.restore(latest_checkpoint).expect_partial()
-      logging.info('Loaded checkpoint %s', latest_checkpoint)
-      logging.info('Loading model took %.1f seconds', time.time() - start_time)
+    latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path)
+    checkpoint = tf.train.Checkpoint(model=self)
+    if verbose:
+      checkpoint.restore(latest_checkpoint)
     else:
-      logging.info('Could not find checkpoint to load at %s, skipping.',
-                   checkpoint_path)
+      checkpoint.restore(latest_checkpoint).expect_partial()
+    logging.info('Loaded checkpoint %s', latest_checkpoint)
+    logging.info('Loading model took %.1f seconds', time.time() - start_time)
 
   def get_audio_from_outputs(self, outputs):
     """Extract audio output tensor from outputs dict of call()."""
diff --git a/ddsp/training/train_util.py b/ddsp/training/train_util.py
index af8517b5..32d8ca81 100644
--- a/ddsp/training/train_util.py
+++ b/ddsp/training/train_util.py
@@ -80,31 +80,94 @@ def get_strategy(tpu='', cluster_config=''):
   return strategy
 
 
-def get_latest_chekpoint(checkpoint_path):
+def expand_path(file_path):
+  return os.path.expanduser(os.path.expandvars(file_path))
+
+
+def get_latest_file(dir_path, prefix='operative_config-', suffix='.gin'):
+  """Returns latest file with pattern '/dir_path/prefix[iteration]suffix'.
+
+  Args:
+    dir_path: Path to the directory.
+    prefix: Filename prefix, not including directory.
+    suffix: Filename suffix, including extension.
+
+  Returns:
+    Path to the latest file.
+
+  Raises:
+    FileNotFoundError: If no files match the pattern
+      '/dir_path/prefix[int]suffix'.
+  """
+  dir_path = expand_path(dir_path)
+  dir_prefix = os.path.join(dir_path, prefix)
+  search_pattern = dir_prefix + '*' + suffix
+  file_paths = tf.io.gfile.glob(search_pattern)
+  if not file_paths:
+    raise FileNotFoundError(
+        f'No files found matching the pattern \'{search_pattern}\'.')
+  try:
+    # Filter to get highest iteration, no negative iterations.
+    get_iter = lambda fp: abs(int(fp.split(dir_prefix)[-1].split(suffix)[0]))
+    latest_file = max(file_paths, key=get_iter)
+    return latest_file
+  except ValueError:
+    raise FileNotFoundError(
+        f'Files found with pattern \'{search_pattern}\' do not match '
+        f'the pattern \'{dir_prefix}[iteration_number]{suffix}\'.\n\n'
+        f'Files found:\n{file_paths}')
+
+
+def get_latest_checkpoint(checkpoint_path):
   """Helper function to get path to latest checkpoint.
 
   Args:
     checkpoint_path: Path to the directory containing model checkpoints, or
-      to a specific checkpoint (e.g. `path/to/model.ckpt-iteration`).
+      to a specific checkpoint (e.g. `/path/to/model.ckpt-iteration`).
 
   Returns:
-    Path to latest checkpoint, or None if none exist.
+    Path to latest checkpoint.
+
+  Raises:
+    FileNotFoundError: If no checkpoint is found.
   """
-  checkpoint_path = os.path.expanduser(os.path.expandvars(checkpoint_path))
+  checkpoint_path = expand_path(checkpoint_path)
   is_checkpoint = tf.io.gfile.exists(checkpoint_path + '.index')
   if is_checkpoint:
+    # Return the path if it points to a checkpoint.
     return checkpoint_path
   else:
-    # None if no checkpoints, or directory doesn't exist.
-    return tf.train.latest_checkpoint(checkpoint_path)
+    # Search the directory using the 'checkpoint' metadata file.
+    # Returns None if there is no 'checkpoint' file, or the directory
+    # doesn't exist.
+    ckpt = tf.train.latest_checkpoint(checkpoint_path)
+    if ckpt:
+      return ckpt
+    else:
+      # Last resort, look for '/path/ckpt-[iter].index' files.
+      ckpt_f = get_latest_file(checkpoint_path, prefix='ckpt-', suffix='.index')
+      return ckpt_f.split('.index')[0]
 
 
 # ---------------------------------- Gin ---------------------------------------
 
 
 def get_latest_operative_config(restore_dir):
-  """Finds the most recently saved operative_config in a directory."""
-  file_paths = tf.io.gfile.glob(os.path.join(restore_dir, 'operative_config*'))
-  get_iter = lambda file_path: int(file_path.split('-')[-1].split('.gin')[0])
-  return max(file_paths, key=get_iter) if file_paths else ''
+  """Finds the most recently saved operative_config in a directory.
+
+  Args:
+    restore_dir: Path to directory with gin operative_configs. Will also work
+      if passing a path to a file in that directory such as a checkpoint.
+
+  Returns:
+    Filepath to most recent operative config.
+
+  Raises:
+    FileNotFoundError: If no config is found.
+  """
+  try:
+    return get_latest_file(
+        restore_dir, prefix='operative_config-', suffix='.gin')
+  except FileNotFoundError:
+    return get_latest_file(
+        os.path.dirname(restore_dir), prefix='operative_config-', suffix='.gin')
 
 
 def write_gin_config(summary_writer, save_dir, step):
@@ -198,7 +261,11 @@ def train(data_provider,
     trainer.build(next(dataset_iter))
 
   # Load latest checkpoint if one exists in load directory.
-  trainer.restore(restore_dir)
+  try:
+    trainer.restore(restore_dir)
+  except FileNotFoundError:
+    logging.info('No existing checkpoint found in %s, skipping '
+                 'checkpoint loading.', restore_dir)
 
   if save_dir:
     # Set up the summary writer and metrics.
diff --git a/ddsp/training/trainers.py b/ddsp/training/trainers.py
index 376001ee..c59ecfff 100644
--- a/ddsp/training/trainers.py
+++ b/ddsp/training/trainers.py
@@ -83,7 +83,15 @@ def save(self, save_dir):
     logging.info('Saving model took %.1f seconds', time.time() - start_time)
 
   def restore(self, checkpoint_path, restore_keys=None):
-    """Restore model and optimizer from a checkpoint if it exists."""
+    """Restore model and optimizer from a checkpoint if it exists.
+
+    Args:
+      checkpoint_path: Path to checkpoint file or directory.
+      restore_keys: Optional list of strings for submodules to restore.
+
+    Raises:
+      FileNotFoundError: If no checkpoint is found.
+    """
     logging.info('Restoring from checkpoint...')
     start_time = time.time()
 
@@ -105,19 +113,16 @@ def restore(self, checkpoint_path, restore_keys=None):
 
     # Restore from latest checkpoint.
     checkpoint = self.get_checkpoint(model)
-    latest_checkpoint = train_util.get_latest_chekpoint(checkpoint_path)
-    if latest_checkpoint is not None:
-      # checkpoint.restore must be within a strategy.scope() so that optimizer
-      # slot variables are mirrored.
-      with self.strategy.scope():
-        if restore_keys is None:
-          checkpoint.restore(latest_checkpoint)
-        else:
-          checkpoint.restore(latest_checkpoint).expect_partial()
-        logging.info('Loaded checkpoint %s', latest_checkpoint)
-      logging.info('Loading model took %.1f seconds', time.time() - start_time)
-    else:
-      logging.info('No checkpoint, skipping.')
+    latest_checkpoint = train_util.get_latest_checkpoint(checkpoint_path)
+    # checkpoint.restore must be within a strategy.scope() so that optimizer
+    # slot variables are mirrored.
+    with self.strategy.scope():
+      if restore_keys is None:
+        checkpoint.restore(latest_checkpoint)
+      else:
+        checkpoint.restore(latest_checkpoint).expect_partial()
+      logging.info('Loaded checkpoint %s', latest_checkpoint)
+    logging.info('Loading model took %.1f seconds', time.time() - start_time)
 
   @property
   def step(self):
diff --git a/ddsp/version.py b/ddsp/version.py
index 83f02bc2..0741e89a 100644
--- a/ddsp/version.py
+++ b/ddsp/version.py
@@ -19,4 +19,4 @@
 pulling in all the dependencies in __init__.py.
 """
 
-__version__ = '1.2.0'
+__version__ = '1.3.0'
diff --git a/setup.py b/setup.py
index ab5bb467..61b3b195 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,10 @@
     'scipy',
     'six',
     'tensorflow',
+    # TODO(jesseengel): Update to v2.5 when CREPE is no longer blocking.
+    # 'tensorflow==2.5.0rc0',
     'tensorflow-addons',
+    'tensorflowjs',
     'tensorflow-probability',
     # TODO(adarob): Switch to tensorflow_datasets once includes nsynth 2.3.
     'tfds-nightly',
@@ -70,6 +73,7 @@
   },
   entry_points={
       'console_scripts': [
+          'ddsp_export = ddsp.training.ddsp_export:console_entry_point',
          'ddsp_run = ddsp.training.ddsp_run:console_entry_point',
          'ddsp_prepare_tfrecord = ddsp.training.data_preparation.ddsp_prepare_tfrecord:console_entry_point',
          'ddsp_generate_synthetic_dataset = ddsp.training.data_preparation.ddsp_generate_synthetic_dataset:console_entry_point',