Skip to content

Commit

Permalink
* Scripts for converting Checkpoints to SavedModels, and SavedModels …
Browse files Browse the repository at this point in the history
…to TFJS and TFLite Models.

* Revert Inference models to inheriting from ddsp.training models, rather than a base class (more reliable).
* More robust search for latest checkpoint and latest operative config. Throw error if not found instead of failing silently, put initial train and eval calls in a try statement as the files might not exist yet.
* Helper function for printing shapes of a nested dictionary of tensors.
* Verbose messages on model loading to let you know if you're not actually loading weights.

Not fully functional yet; relies on TF 2.5, which currently breaks CREPE. Should be able to get it working by running:
`pip install -U tensorflow==2.5.0rc0`
after pip installing ddsp

PiperOrigin-RevId: 367523871
  • Loading branch information
jesseengel authored and Magenta Team committed Apr 8, 2021
1 parent 628eef8 commit d912466
Show file tree
Hide file tree
Showing 10 changed files with 444 additions and 135 deletions.
5 changes: 5 additions & 0 deletions ddsp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def leaf_key(nested_key: Text,
return keys[-1]


def map_shape(x: Dict[Text, tf.Tensor]) -> Dict[Text, Sequence[int]]:
  """Recursively infer tensor shapes for a dictionary of tensors."""
  def _shape_as_list(tensor):
    # Runtime (eager) shape, converted to a plain list of Python ints.
    return list(tf.shape(tensor).numpy())
  return tf.nest.map_structure(_shape_as_list, x)


def pad_axis(x, padding=(0, 0), axis=0, **pad_kwargs):
"""Pads only one axis of a tensor.
Expand Down
195 changes: 195 additions & 0 deletions ddsp/training/ddsp_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Copyright 2021 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Convert checkpoint to SavedModel, and SavedModel to TFJS / TFLite format.
Example Usage (Defaults):
ddsp_export --model_path=/path/to/model
Example Usage (TFJS model):
ddsp_export --model_path=/path/to/model --inference_model=autoencoder \
--tflite=false --tfjs
Example Usage (TFLite model):
ddsp_export --model_path=/path/to/model --inference_model=streaming_f0_pw \
--tflite --tfjs=false
Example Usage (SavedModel Only):
ddsp_export --model_path=/path/to/model --inference_model=[model_type] \
--tflite=false --tfjs=false
"""

import os

from absl import app
from absl import flags

from ddsp.training import inference
from ddsp.training import train_util
import gin
import tensorflow as tf
from tensorflowjs.converters import converter


# Path flags.
flags.DEFINE_string('model_path', '',
                    'Path to checkpoint or SavedModel directory. If no '
                    'SavedModel is found, will search for latest checkpoint '
                    'and use it to create a SavedModel. Can also provide '
                    'direct path to desired checkpoint. '
                    'E.g. `/path/to/ckpt-[iter]`.')
flags.DEFINE_string('save_dir', '',
                    'Optional directory in which to save converted '
                    'checkpoint. If none is provided, it will be '
                    'FLAGS.model_path if it contains a SavedModel, otherwise '
                    'FLAGS.model_path/export.')

# Specify model class.
flags.DEFINE_enum('inference_model', 'streaming_f0_pw',
                  ['autoencoder',
                   'streaming_f0_pw',
                  ],
                  'Specify the ddsp.training.inference model to use for '
                  'converting a checkpoint to a SavedModel. Names are '
                  'snake_case versions of class names.')

# Optional flags.
flags.DEFINE_multi_string('gin_param', [],
                          'Gin parameters for custom inference model kwargs.')
flags.DEFINE_boolean('debug', False, 'DEBUG: Do not save the model')

# Conversion formats.
flags.DEFINE_boolean('tfjs', True,
                     'Convert SavedModel to TFJS for deploying on the web.')
flags.DEFINE_boolean('tflite', True,
                     'Convert SavedModel to TFLite for embedded C++ apps.')


FLAGS = flags.FLAGS


def get_inference_model(ckpt):
  """Restore model from checkpoint using global FLAGS.

  Use --gin_param for any custom kwargs for model constructors.

  Args:
    ckpt: Path to the checkpoint.

  Returns:
    Inference model, built and restored from checkpoint.
  """
  # Apply any user-supplied gin bindings before constructing the model.
  with gin.unlock_config():
    gin.parse_config_files_and_bindings(None, FLAGS.gin_param)

  # Dispatch from snake_case flag value to the inference class.
  model_classes = {
      'autoencoder': inference.AutoencoderInference,
      'streaming_f0_pw': inference.StreamingF0PwInference,
  }
  model_class = model_classes[FLAGS.inference_model]
  return model_class(ckpt)


def ckpt_to_saved_model(ckpt, save_dir):
  """Convert Checkpoint to SavedModel.

  Args:
    ckpt: Path to the checkpoint to restore.
    save_dir: Directory in which to write the SavedModel.
  """
  print(f'\nConverting to SavedModel:\nInput: {ckpt}\nOutput: {save_dir}\n')
  model = get_inference_model(ckpt)
  # Fix user-facing typo: 'Finshed' -> 'Finished'.
  print('Finished Loading Model!')
  # --debug skips the (slow) save to allow quick iteration on model loading.
  if not FLAGS.debug:
    model.save_model(save_dir)
    print('SavedModel Conversion Success!')


def saved_model_to_tfjs(input_dir, save_dir):
  """Convert SavedModel to TFJS model."""
  print(f'\nConverting to TFJS:\nInput:{input_dir}\nOutput:{save_dir}\n')
  # Build the argv-style argument list the tensorflowjs converter expects.
  conversion_args = [
      '--input_format=tf_saved_model',
      '--signature_name=serving_default',
      '--control_flow_v2=True',
      '--skip_op_check',
      '--quantize_float16=True',
      '--experiments=True',
  ]
  converter.convert(conversion_args + [input_dir, save_dir])
  print('TFJS Conversion Success!')


def saved_model_to_tflite(input_dir, save_dir):
  """Convert SavedModel to TFLite model."""
  print(f'\nConverting to TFLite:\nInput:{input_dir}\nOutput:{save_dir}\n')
  # Run the conversion; result is the serialized model as a byte string.
  tflite_converter = tf.lite.TFLiteConverter.from_saved_model(input_dir)
  tflite_converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS,  # Enable extended TensorFlow ops.
  ]
  model_bytes = tflite_converter.convert()
  # Write the serialized model to disk.
  output_path = os.path.join(save_dir, 'model.tflite')
  with tf.io.gfile.GFile(output_path, 'wb') as f:
    f.write(model_bytes)
  print('TFLite Conversion Success!')


def ensure_exits(dir_path):
  """Make directory if none exists.

  NOTE(review): name appears to be a typo for `ensure_exists`; kept as-is
  since callers in this file reference it by this name.

  Args:
    dir_path: Path of the directory to create if it is missing.
  """
  if not tf.io.gfile.exists(dir_path):
    tf.io.gfile.makedirs(dir_path)


def main(unused_argv):
  """Convert model at FLAGS.model_path per the conversion format flags."""
  model_path = train_util.expand_path(FLAGS.model_path)

  # Classify model_path: SavedModel directory, checkpoint file, or
  # checkpoint directory.
  has_saved_model = tf.io.gfile.exists(
      os.path.join(model_path, 'saved_model.pb'))
  is_ckpt_file = not tf.io.gfile.isdir(model_path)

  # Choose the output directory.
  if FLAGS.save_dir:
    save_dir = FLAGS.save_dir
  elif has_saved_model:
    # model_path already holds a SavedModel; reuse that directory.
    save_dir = model_path
  elif is_ckpt_file:
    # Direct checkpoint path -> sibling 'export' directory.
    save_dir = os.path.join(os.path.dirname(model_path), 'export')
  else:
    # Checkpoint directory -> child 'export' directory.
    save_dir = os.path.join(model_path, 'export')

  # Make a new save directory.
  save_dir = train_util.expand_path(save_dir)
  ensure_exits(save_dir)

  # Create SavedModel if none already exists.
  if not has_saved_model:
    ckpt_to_saved_model(model_path, save_dir)

  # Optional downstream conversions of the SavedModel.
  if FLAGS.tfjs:
    tfjs_dir = os.path.join(save_dir, 'tfjs')
    ensure_exits(tfjs_dir)
    saved_model_to_tfjs(save_dir, tfjs_dir)

  if FLAGS.tflite:
    tflite_dir = os.path.join(save_dir, 'tflite')
    ensure_exits(tflite_dir)
    saved_model_to_tflite(save_dir, tflite_dir)


def console_entry_point():
  """From pip installed script."""
  # absl's app.run handles flag parsing and invokes main().
  app.run(main)


if __name__ == '__main__':
  console_entry_point()
6 changes: 4 additions & 2 deletions ddsp/training/ddsp_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,13 @@ def parse_gin(restore_dir):
gin.parse_config_file(eval_default)

# Load operative_config if it exists (model has already trained).
operative_config = train_util.get_latest_operative_config(restore_dir)
if tf.io.gfile.exists(operative_config):
try:
operative_config = train_util.get_latest_operative_config(restore_dir)
logging.info('Using operative config: %s', operative_config)
operative_config = cloud.make_file_paths_local(operative_config, GIN_PATH)
gin.parse_config_file(operative_config, skip_unknown=True)
except FileNotFoundError:
logging.info('Operative config not found in %s', restore_dir)

# User gin config and user hyperparameters from flags.
gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
Expand Down
6 changes: 5 additions & 1 deletion ddsp/training/eval_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def evaluate_or_sample(data_provider,
dataset_iter = iter(dataset)

# Load model.
model.restore(checkpoint_path)
try:
model.restore(checkpoint_path)
except FileNotFoundError:
logging.warn('No existing checkpoint found in %s, skipping '
'checkpoint loading.', restore_dir)

# Iterate through dataset and make predictions
checkpoint_start_time = time.time()
Expand Down
Loading

0 comments on commit d912466

Please sign in to comment.