diff --git a/ddsp/spectral_ops.py b/ddsp/spectral_ops.py
index 14ce1110..4699effa 100644
--- a/ddsp/spectral_ops.py
+++ b/ddsp/spectral_ops.py
@@ -22,6 +22,7 @@
 import librosa
 import numpy as np
 import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
 
 CREPE_SAMPLE_RATE = 16000
 _CREPE_FRAME_SIZE = 1024
@@ -417,12 +418,12 @@ class PretrainedCREPE(tf.keras.Model):
   """A wrapper around a pretrained CREPE model, for pitch prediction.
 
   Enables predicting pitch and confidence entirely in TF for running in batch
-  on accelerators. Constructor requires path to a SavedModel of the base CREPE
-  models. Available on GCS at gs://crepe-models/saved_models/[full,large,small].
+  on accelerators. For [full, large, small, tiny] crepe models, reads h5 models
+  from the installed pip package; other strings are treated as a SavedModel path.
   """
 
   def __init__(self,
-               saved_model_path,
+               model_size_or_path,
                hop_size=160,
                **kwargs):
     super().__init__(**kwargs)
@@ -430,21 +431,27 @@ def __init__(self,
     self.frame_size = 1024
     self.sample_rate = 16000
     # Load the crepe model.
-    self.saved_model_path = saved_model_path
-    self.core_model = tf.keras.models.load_model(self.saved_model_path)
+    if model_size_or_path in ['full', 'large', 'small', 'tiny']:
+      self.core_model = crepe.core.build_and_load_model(model_size_or_path)
+    else:
+      self.core_model = tf.keras.models.load_model(model_size_or_path)
+
+    self.model_size_or_path = model_size_or_path
 
   @classmethod
-  def activations_to_f0_and_confidence(cls, activations):
+  def activations_to_f0_and_confidence(cls, activations, centers=None):
     """Convert network outputs (activations) to f0 predictions."""
     cent_mapping = tf.cast(
         tf.linspace(0, 7180, 360) + 1997.3794084376191, tf.float32)
 
     # The confidence of voicing activity and the argmax bin.
     confidence = tf.reduce_max(activations, axis=-1, keepdims=True)
-    center = tf.cast(tf.math.argmax(activations, axis=-1), tf.int32)
+    if centers is None:
+      centers = tf.math.argmax(activations, axis=-1)
+    centers = tf.cast(centers, tf.int32)
 
     # Slice the local neighborhood around the argmax bin.
-    start = center - 4
+    start = centers - 4
     idx_list = tf.range(0, 10)
     idx_list = start[:, None] + idx_list[None, :]
 
@@ -478,16 +485,65 @@ def normalize_frames(self, frames):
     frames /= std[:, None]
     return frames
 
-  def predict_f0_and_confidence(self, audio):
+  def predict_f0_and_confidence(self, audio, viterbi=False):
     audio = audio[None, :] if len(audio.shape) == 1 else audio
     batch_size = audio.shape[0]
 
     frames = self.batch_frames(audio)
     frames = self.normalize_frames(frames)
     acts = self.core_model(frames, training=False)
-    f0_hz, confidence = self.activations_to_f0_and_confidence(acts)
+
+    if viterbi:
+      acts_viterbi = tf.reshape(acts, [batch_size, -1, 360])
+      centers = self.viterbi_decode(acts_viterbi)
+      centers = tf.reshape(centers, [-1])
+    else:
+      centers = None
+
+    f0_hz, confidence = self.activations_to_f0_and_confidence(acts, centers)
     f0_hz = tf.reshape(f0_hz, [batch_size, -1])
     confidence = tf.reshape(confidence, [batch_size, -1])
     return f0_hz, confidence
+
+  def create_hmm(self, num_steps):
+    """Same as the original CREPE viterbi decoding, but in TF."""
+    # Initial distribution is uniform.
+    initial_distribution = tfp.distributions.Categorical(
+        probs=tf.ones([360]) / 360)
+
+    # Transition probabilities inducing continuous pitch.
+    bins = tf.range(360, dtype=tf.float32)
+    xx, yy = tf.meshgrid(bins, bins)
+    min_transition = 1e-5  # For training stability.
+    transition = tf.maximum(12 - abs(xx - yy), min_transition)
+    transition = transition / tf.reduce_sum(transition, axis=1)[:, None]
+    transition = tf.cast(transition, tf.float32)
+    transition_distribution = tfp.distributions.Categorical(
+        probs=transition)
+
+    # Emission probability = fixed probability for self, evenly distribute the
+    # others.
+    self_emission = 0.1
+    emission = (
+        tf.eye(360) * self_emission + tf.ones(shape=(360, 360)) *
+        ((1 - self_emission) / 360.)
+    )
+    emission = tf.cast(emission, tf.float32)[None, ...]
+    observation_distribution = tfp.distributions.Multinomial(
+        total_count=1, probs=emission)
+
+    return tfp.distributions.HiddenMarkovModel(
+        initial_distribution=initial_distribution,
+        transition_distribution=transition_distribution,
+        observation_distribution=observation_distribution,
+        num_steps=num_steps,
+    )
+
+  def viterbi_decode(self, acts):
+    """Adapted from the original CREPE viterbi decoding, but in TF."""
+    num_steps = acts.shape[1]
+    hmm = self.create_hmm(num_steps)
+    centers = hmm.posterior_mode(acts)
+    return centers
+
diff --git a/ddsp/version.py b/ddsp/version.py
index abf78027..072023d7 100644
--- a/ddsp/version.py
+++ b/ddsp/version.py
@@ -19,4 +19,4 @@
 pulling in all the dependencies in __init__.py.
 """
 
-__version__ = '1.8.0'
+__version__ = '1.9.0'
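For reviewers, a minimal usage sketch of the new `viterbi` flag (not part of the patch). It assumes the `crepe` pip package is installed, picks the `'tiny'` model for speed, and uses random noise purely to exercise shapes and the API, not to produce a meaningful pitch track:

```python
import tensorflow.compat.v2 as tf

from ddsp.spectral_ops import PretrainedCREPE

# Builds the 'tiny' model from the installed crepe pip package
# (any of 'full', 'large', 'small', 'tiny', or a SavedModel path).
crepe_model = PretrainedCREPE('tiny')

# One second of 16 kHz audio, shape [batch, n_samples]. Random noise here,
# just to exercise the API; real audio would yield a usable f0 estimate.
audio = tf.random.normal([1, 16000])

# viterbi=True decodes pitch bins with the TFP HiddenMarkovModel instead of
# the per-frame argmax, smoothing the pitch track.
f0_hz, confidence = crepe_model.predict_f0_and_confidence(audio, viterbi=True)
print(f0_hz.shape, confidence.shape)  # Both [1, n_frames].
```

With `viterbi=False` (the default), decoding falls back to the existing per-frame argmax path, so current callers are unaffected.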