Skip to content

Commit

Permalink
Add Viterbi decoding to PretrainedCREPE, allows hmm decoding during t…
Browse files Browse the repository at this point in the history
…raining.

PiperOrigin-RevId: 417506797
  • Loading branch information
jesseengel authored and Magenta Team committed Dec 21, 2021
1 parent 0e38f29 commit 8536a36
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 11 deletions.
76 changes: 66 additions & 10 deletions ddsp/spectral_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import librosa
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

CREPE_SAMPLE_RATE = 16000
_CREPE_FRAME_SIZE = 1024
Expand Down Expand Up @@ -417,34 +418,40 @@ class PretrainedCREPE(tf.keras.Model):
"""A wrapper around a pretrained CREPE model, for pitch prediction.
Enables predicting pitch and confidence entirely in TF for running in batch
on accelerators. Constructor requires path to a SavedModel of the base CREPE
models. Available on GCS at gs://crepe-models/saved_models/[full,large,small].
on accelerators. For [full,large,small,tiny] crepe models, reads h5 models
from installed pip package. Any other value is treated as a path to a saved
model and loaded with `tf.keras.models.load_model`.
"""

def __init__(self,
saved_model_path,
model_size_or_path,
hop_size=160,
**kwargs):
super().__init__(**kwargs)
self.hop_size = hop_size
self.frame_size = 1024
self.sample_rate = 16000
# Load the crepe model.
self.saved_model_path = saved_model_path
self.core_model = tf.keras.models.load_model(self.saved_model_path)
if model_size_or_path in ['full', 'large', 'small', 'tiny']:
self.core_model = crepe.core.build_and_load_model(model_size_or_path)
else:
self.core_model = tf.keras.models.load_model(model_size_or_path)

self.model_size_or_path = model_size_or_path

@classmethod
def activations_to_f0_and_confidence(cls, activations):
def activations_to_f0_and_confidence(cls, activations, centers=None):
"""Convert network outputs (activations) to f0 predictions."""
cent_mapping = tf.cast(
tf.linspace(0, 7180, 360) + 1997.3794084376191, tf.float32)

# The confidence of voicing activity and the argmax bin.
confidence = tf.reduce_max(activations, axis=-1, keepdims=True)
center = tf.cast(tf.math.argmax(activations, axis=-1), tf.int32)
if centers is None:
centers = tf.math.argmax(activations, axis=-1)
centers = tf.cast(centers, tf.int32)

# Slice the local neighborhood around the argmax bin.
start = center - 4
start = centers - 4
idx_list = tf.range(0, 10)
idx_list = start[:, None] + idx_list[None, :]

Expand Down Expand Up @@ -478,16 +485,65 @@ def normalize_frames(self, frames):
frames /= std[:, None]
return frames

def predict_f0_and_confidence(self, audio):
def predict_f0_and_confidence(self, audio, viterbi=False):
audio = audio[None, :] if len(audio.shape) == 1 else audio
batch_size = audio.shape[0]

frames = self.batch_frames(audio)
frames = self.normalize_frames(frames)
acts = self.core_model(frames, training=False)
f0_hz, confidence = self.activations_to_f0_and_confidence(acts)

if viterbi:
acts_viterbi = tf.reshape(acts, [batch_size, -1, 360])
centers = self.viterbi_decode(acts_viterbi)
centers = tf.reshape(centers, [-1])
else:
centers = None

f0_hz, confidence = self.activations_to_f0_and_confidence(acts, centers)
f0_hz = tf.reshape(f0_hz, [batch_size, -1])
confidence = tf.reshape(confidence, [batch_size, -1])
return f0_hz, confidence

def create_hmm(self, num_steps):
  """Build the hidden Markov model used for Viterbi pitch decoding.

  TF port of the HMM from the original CREPE Viterbi decoding. The model has
  360 hidden states with a uniform prior, a banded transition matrix that
  favors moves between nearby states, and an emission matrix that is nearly
  uniform with extra mass on the diagonal.

  Args:
    num_steps: Number of time steps in the sequences to decode; forwarded
      unchanged to `tfp.distributions.HiddenMarkovModel`.

  Returns:
    A `tfp.distributions.HiddenMarkovModel` over 360 states.
  """
  tfd = tfp.distributions
  n_states = 360

  # Prior: every state is equally likely at the first step.
  prior = tfd.Categorical(probs=tf.ones([n_states]) / n_states)

  # Transitions: the weight falls off linearly with the distance between
  # states (12 - |i - j|), which encourages a continuous pitch track. Instead
  # of clipping at zero, weights are floored at a small positive value (for
  # training stability), then each row is normalized to sum to one.
  state_ids = tf.range(n_states, dtype=tf.float32)
  cols, rows = tf.meshgrid(state_ids, state_ids)
  transition_floor = 1e-5
  weights = tf.maximum(12 - tf.abs(cols - rows), transition_floor)
  weights = weights / tf.reduce_sum(weights, axis=1, keepdims=True)
  weights = tf.cast(weights, tf.float32)
  transitions = tfd.Categorical(probs=weights)

  # Emissions: a fixed probability of emitting the state's own index, with
  # the remaining mass spread evenly over all indices.
  self_emission = 0.1
  background = (1 - self_emission) / float(n_states)
  emission_probs = (
      tf.eye(n_states) * self_emission +
      tf.ones(shape=(n_states, n_states)) * background)
  # Add a leading axis of size one in front of the [state, observation] axes.
  emission_probs = tf.cast(emission_probs, tf.float32)[None, ...]
  observations = tfd.Multinomial(total_count=1, probs=emission_probs)

  return tfd.HiddenMarkovModel(
      initial_distribution=prior,
      transition_distribution=transitions,
      observation_distribution=observations,
      num_steps=num_steps,
  )

def viterbi_decode(self, acts):
  """Decode a state sequence from activations with the CREPE-style HMM.

  TF adaptation of the original CREPE Viterbi decoding.

  Args:
    acts: Activations whose second dimension is the number of time steps.
      `predict_f0_and_confidence` passes a tensor reshaped to
      [batch_size, -1, 360].

  Returns:
    The output of `HiddenMarkovModel.posterior_mode(acts)`, used by the
    caller as the per-step center bin indices.
  """
  # Size the HMM from the static time dimension of the activations.
  sequence_length = acts.shape[1]
  decoder = self.create_hmm(sequence_length)
  return decoder.posterior_mode(acts)


2 changes: 1 addition & 1 deletion ddsp/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
pulling in all the dependencies in __init__.py.
"""

__version__ = '1.8.0'
__version__ = '1.9.0'

0 comments on commit 8536a36

Please sign in to comment.