Merge pull request #397 from lukewys:remove-global-injection-cumsum
PiperOrigin-RevId: 408452225
Magenta Team committed Nov 8, 2021
2 parents 551b38a + 41a031e commit 401ede8
Showing 5 changed files with 138 additions and 20 deletions.
28 changes: 23 additions & 5 deletions ddsp/core.py
@@ -208,10 +208,24 @@ def gradient_reversal(x):


# Unit Conversions -------------------------------------------------------------
def midi_to_hz(notes: Number) -> Number:
"""TF-compatible midi_to_hz function."""
def midi_to_hz(notes: Number, midi_zero_silence: bool = False) -> Number:
"""TF-compatible midi_to_hz function.
Args:
notes: Tensor containing encoded pitch in MIDI scale.
midi_zero_silence: Whether to output 0 Hz for MIDI 0, which is convenient
when MIDI 0 represents silence. By default (False), MIDI 0.0
corresponds to 8.18 Hz.
Returns:
hz: Frequency in Hz, same shape as input.
"""
notes = tf_float32(notes)
return 440.0 * (2.0**((notes - 69.0) / 12.0))
hz = 440.0 * (2.0 ** ((notes - 69.0) / 12.0))
# Map MIDI 0 to 0 Hz when MIDI 0 represents silence.
if midi_zero_silence:
hz = tf.where(tf.equal(notes, 0.0), 0.0, hz)
return hz


def hz_to_midi(frequencies: Number) -> Number:
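For reference, a minimal sketch of how the new flag behaves (not part of the diff; assumes `ddsp` is installed, with values taken from the docstring above):

```python
import ddsp.core as core

# MIDI 69 (A4) maps to 440 Hz regardless of the flag.
a4 = core.midi_to_hz(69.0)                              # ~440.0

# By default, MIDI 0.0 maps to ~8.18 Hz.
low = core.midi_to_hz(0.0)                              # ~8.18

# With the new flag, MIDI 0 is treated as silence and maps to 0 Hz.
silent = core.midi_to_hz(0.0, midi_zero_silence=True)   # 0.0
```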
@@ -909,7 +923,8 @@ def harmonic_synthesis(frequencies: tf.Tensor,
harmonic_distribution: Optional[tf.Tensor] = None,
n_samples: int = 64000,
sample_rate: int = 16000,
amp_resample_method: Text = 'window') -> tf.Tensor:
amp_resample_method: Text = 'window',
use_angular_cumsum: bool = False) -> tf.Tensor:
"""Generate audio from frame-wise monophonic harmonic oscillator bank.
Args:
@@ -926,6 +941,8 @@ def harmonic_synthesis(frequencies: tf.Tensor,
n_samples: Total length of output audio. Interpolates and crops to this.
sample_rate: Sample rate.
amp_resample_method: Mode with which to resample amplitude envelopes.
use_angular_cumsum: Use angular cumulative sum on accumulating phase
instead of tf.cumsum. More accurate for inference on long sequences,
but slower on accelerators.
Returns:
audio: Output audio. Shape [batch_size, n_samples, 1]
@@ -961,7 +978,8 @@ def harmonic_synthesis(frequencies: tf.Tensor,
# Synthesize from harmonics [batch_size, n_samples].
audio = oscillator_bank(frequency_envelopes,
amplitude_envelopes,
sample_rate=sample_rate)
sample_rate=sample_rate,
use_angular_cumsum=use_angular_cumsum)
return audio
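A usage sketch for the new argument (the signature matches the diff above; shapes and control values are illustrative, not from the commit):

```python
import tensorflow as tf
import ddsp.core as core

# Hypothetical frame-wise controls: batch 1, 250 frames, 20 harmonics.
f0_hz = 440.0 * tf.ones([1, 250, 1])
amplitudes = 0.5 * tf.ones([1, 250, 1])
harmonic_distribution = tf.ones([1, 250, 20]) / 20.0

# Angular cumsum trades some speed on accelerators for phase accuracy
# on long outputs.
audio = core.harmonic_synthesis(
    frequencies=f0_hz,
    amplitudes=amplitudes,
    harmonic_distribution=harmonic_distribution,
    n_samples=64000,
    sample_rate=16000,
    use_angular_cumsum=True)
```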


10 changes: 9 additions & 1 deletion ddsp/synths.py
@@ -63,6 +63,7 @@ def __init__(self,
scale_fn=core.exp_sigmoid,
normalize_below_nyquist=True,
amp_resample_method='window',
use_angular_cumsum=False,
name='harmonic'):
"""Constructor.
@@ -76,6 +77,11 @@ def __init__(self,
Must be in ['nearest', 'linear', 'cubic', 'window']. 'window' uses
overlapping windows (only for upsampling) which is smoother
for amplitude envelopes with large frame sizes.
use_angular_cumsum: Use angular cumulative sum on accumulating phase
instead of tf.cumsum. If synthesized examples are longer than ~100k
audio samples, consider setting use_angular_cumsum=True to avoid
accumulating noticeable phase errors due to the limited precision of
tf.cumsum. However, angular cumulative sum is slower on accelerators.
name: Synth name.
"""
super().__init__(name=name)
@@ -84,6 +90,7 @@ def __init__(self,
self.scale_fn = scale_fn
self.normalize_below_nyquist = normalize_below_nyquist
self.amp_resample_method = amp_resample_method
self.use_angular_cumsum = use_angular_cumsum

def get_controls(self,
amplitudes,
@@ -145,7 +152,8 @@ def get_signal(self, amplitudes, harmonic_distribution, f0_hz):
harmonic_distribution=harmonic_distribution,
n_samples=self.n_samples,
sample_rate=self.sample_rate,
amp_resample_method=self.amp_resample_method)
amp_resample_method=self.amp_resample_method,
use_angular_cumsum=self.use_angular_cumsum)
return signal
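As a sketch of how a caller might opt in (constructor arguments per the diff; the control shapes are invented placeholders):

```python
import tensorflow as tf
from ddsp import synths

# Enable angular cumsum for long renders (> ~100k samples).
synth = synths.Harmonic(
    n_samples=16000 * 30,   # 30 s at 16 kHz, long enough to matter
    sample_rate=16000,
    use_angular_cumsum=True)

# Hypothetical network outputs, shapes [batch, n_frames, ...].
amplitudes = tf.random.normal([1, 1000, 1])
harmonic_distribution = tf.random.normal([1, 1000, 60])
f0_hz = 440.0 * tf.ones([1, 1000, 1])
audio = synth(amplitudes, harmonic_distribution, f0_hz)
```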


35 changes: 31 additions & 4 deletions ddsp/training/data.py
@@ -53,20 +53,25 @@ def get_dataset(self, shuffle):
"""A method that returns a tf.data.Dataset."""
raise NotImplementedError

def get_batch(self, batch_size, shuffle=True, repeats=-1):
def get_batch(self,
batch_size,
shuffle=True,
repeats=-1,
drop_remainder=True):
"""Read dataset.
Args:
batch_size: Size of batch.
shuffle: Whether to shuffle the examples.
repeats: Number of times to repeat dataset. -1 for endless repeats.
drop_remainder: Whether the last batch should be dropped.
Returns:
A batched tf.data.Dataset.
"""
dataset = self.get_dataset(shuffle)
dataset = dataset.repeat(repeats)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.prefetch(buffer_size=_AUTOTUNE)
return dataset
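A sketch of the new knob from the caller's side (the provider class and file pattern here are hypothetical stand-ins):

```python
from ddsp.training import data

# Hypothetical file pattern.
provider = data.TFRecordProvider('gs://bucket/train.tfrecord*')

# Evaluation often wants every example, so keep the final partial batch.
ds = provider.get_batch(batch_size=32, shuffle=False, repeats=1,
                        drop_remainder=False)
```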

@@ -306,13 +311,18 @@ def get_dataset(self, shuffle=True):
datasets = tuple(dp.get_dataset(shuffle) for dp in self._data_providers)
return tf.data.Dataset.zip(datasets)

def get_batch(self, batch_size, shuffle=True, repeats=-1):
def get_batch(self,
batch_size,
shuffle=True,
repeats=-1,
drop_remainder=False):
"""Read dataset.
Args:
batch_size: Size of batches, can be a list to have varying batch_sizes.
shuffle: Whether to shuffle the examples.
repeats: Number of times to repeat dataset. -1 for endless repeats.
drop_remainder: Whether the last batch should be dropped.
Returns:
A batched tf.data.Dataset.
@@ -325,7 +335,7 @@ def get_batch(self, batch_size, shuffle=True, repeats=-1):
# Varying batch sizes (Integer batch shape for each).
batch_sizes = [int(batch_size * bsr) for bsr in self._batch_size_ratios]
datasets = tuple(
dp.get_dataset(shuffle).batch(bs, drop_remainder=True)
dp.get_dataset(shuffle).batch(bs, drop_remainder=drop_remainder)
for bs, dp in zip(batch_sizes, self._data_providers))
dataset = tf.data.Dataset.zip(datasets)
dataset = dataset.repeat(repeats)
@@ -443,8 +453,15 @@ def features_dict(self):
'note_active_frame_indices':
tf.io.FixedLenFeature([self._feature_length * 128], tf.float32),
'instrument_id': tf.io.FixedLenFeature([], tf.string),
'recording_id': tf.io.FixedLenFeature([], tf.string),
'power_db':
tf.io.FixedLenFeature([self._feature_length], dtype=tf.float32),
'note_onsets':
tf.io.FixedLenFeature([self._feature_length * 128],
dtype=tf.float32),
'note_offsets':
tf.io.FixedLenFeature([self._feature_length * 128],
dtype=tf.float32),
})
return base_features

@@ -461,6 +478,16 @@ def _reshape_tensors(data):
data['note_active_velocities'] = tf.reshape(
data['note_active_velocities'], (-1, 128))
data['instrument_id'] = inst_vocab.lookup(data['instrument_id'])
data['midi'] = tf.argmax(data['note_active_frame_indices'], axis=-1)
data['f0_hz'] = data['f0_hz'][..., tf.newaxis]
data['loudness_db'] = data['loudness_db'][..., tf.newaxis]
onsets = tf.reduce_sum(
tf.reshape(data['note_onsets'], (-1, 128)), axis=-1)
data['onsets'] = tf.cast(onsets > 0, tf.int64)
offsets = tf.reduce_sum(
tf.reshape(data['note_offsets'], (-1, 128)), axis=-1)
data['offsets'] = tf.cast(offsets > 0, tf.int64)

return data

ds = super().get_dataset(shuffle)
9 changes: 6 additions & 3 deletions ddsp/training/decoders.py
@@ -76,7 +76,7 @@ def __init__(self,
self.dense_out = tfkl.Dense(2)
self.norm = nn.Normalize('layer') if norm else None

def call(self, z_pitch, z_vel, z=None) -> ['f0_midi', 'loudness']:
def call(self, z_pitch, z_vel=None, z=None) -> ['f0_midi', 'loudness']:
"""Forward pass for the MIDI decoder.
Args:
@@ -121,6 +121,7 @@ def __init__(self,
('amplitudes', 1),
('harmonic_distribution', 60),
('magnitudes', 65)),
midi_zero_silence=True,
**kwargs):
"""Constructor."""
self.output_splits = output_splits
@@ -133,8 +134,9 @@ def __init__(self,
self.f0_residual = f0_residual
self.dense_out = tfkl.Dense(self.n_out)
self.norm = nn.Normalize('layer') if norm else None
self.midi_zero_silence = midi_zero_silence

def call(self, z_pitch, z_vel, z=None):
def call(self, z_pitch, z_vel=None, z=None):
"""Forward pass for the MIDI decoder.
Args:
@@ -160,7 +162,8 @@ def call(self, z_pitch, z_vel, z=None):
if self.f0_residual:
outputs['f0_midi'] += z_pitch

outputs['f0_hz'] = core.midi_to_hz(outputs['f0_midi'])
outputs['f0_hz'] = core.midi_to_hz(outputs['f0_midi'],
midi_zero_silence=self.midi_zero_silence)
return outputs


76 changes: 69 additions & 7 deletions ddsp/training/nn.py
@@ -363,8 +363,8 @@ def straight_through_int_quantization(x):
def get_note_mask(q_pitch, max_regions=100, note_on_only=True):
"""Get a binary mask for each note from a monophonic instrument.
Each transition of the value creates a new region. Returns the mask of each
region.
Each transition of the q_pitch value creates a new region. Returns the mask of
each region.
Args:
q_pitch: A quantized value, such as pitch or velocity. Shape
[batch, n_timesteps] or [batch, n_timesteps, 1].
@@ -413,6 +413,57 @@ def get_note_mask(q_pitch, max_regions=100, note_on_only=True):
return note_mask


def get_note_mask_from_onset(q_pitch, onset, max_regions=100,
note_on_only=True):
"""Get a binary mask for each note from a monophonic instrument.
Each onset creates a new region. Returns the mask of each region.
Args:
q_pitch: A quantized value, such as pitch or velocity. Shape
[batch, n_timesteps] or [batch, n_timesteps, 1].
onset: Binary onset in shape [batch, n_timesteps] or
[batch, n_timesteps, 1]. 1 represents onset.
max_regions: Maximum number of note regions to consider in the sequence.
Also, the channel dimension of the output mask. Each onset starts a
new region.
note_on_only: Return a mask that is true only for regions where the pitch
is greater than 0.
Returns:
A binary mask of each region [batch, n_timesteps, max_regions].
"""
# Only batch and time dimensions.
if len(q_pitch.shape) == 3:
q_pitch = q_pitch[:, :, 0]
if len(onset.shape) == 3:
onset = onset[:, :, 0]

edges = onset
# Count endpoints as starts/ends of regions.
edges = edges[:, 1:, ...]
edges = tf.pad(edges,
[[0, 0], [1, 0]], mode='constant', constant_values=True)
edges = tf.cast(edges, tf.int32)

# Count up onset and offsets for each timestep.
# Assumes each onset has a corresponding offset.
# The -1 ensures that the 0th index is the first note.
edge_idx = tf.cumsum(edges, axis=1) - 1

# Create masks of shape [batch, n_timesteps, max_regions].
note_mask = edge_idx[..., None] == tf.range(max_regions)[None, None, :]
note_mask = tf.cast(note_mask, tf.float32)

if note_on_only:
# [batch, time, notes]
note_on = tf.cast(q_pitch > 0.0, tf.float32)[:, :, None]
# [batch, time, notes]
note_mask *= note_on

return note_mask
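A toy example of the onset-based mask (all values invented for illustration):

```python
import tensorflow as tf
from ddsp.training import nn

# Two notes then a rest: [batch=1, time=6].
q_pitch = tf.constant([[60., 60., 60., 62., 62., 0.]])
onset = tf.constant([[1., 0., 0., 1., 0., 0.]])

mask = nn.get_note_mask_from_onset(q_pitch, onset, max_regions=4)
# mask.shape == [1, 6, 4]; region 0 spans the first three frames,
# region 1 the next two, and the final rest frame is zeroed out
# because note_on_only=True and q_pitch == 0 there.
```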


def get_note_lengths(note_mask):
"""Count the lengths of each note [batch, time, notes] -> [batch, notes]."""
return tf.reduce_sum(note_mask, axis=1)
@@ -457,20 +508,31 @@ def get_note_moments(x, note_mask, return_std=True):
return x_mean


def pool_over_notes(x, note_mask):
def pool_over_notes(x, note_mask, return_std=True):
"""Return the time-distributed average value of x pooled over the note.
Args:
x: Value to be pooled, [batch, time, dims].
note_mask: Binary mask of notes [batch, time, notes].
return_std: Also return the standard deviation for each note.
Returns:
Values pooled over each note region, [batch, time, dims].
Returns only mean if return_std=False, else mean and std.
"""
x_notes = get_note_moments(x, note_mask, return_std=False) # [b, n, d]
x_time_notes = (x_notes[:, tf.newaxis, ...] *
note_mask[..., tf.newaxis]) # [b, t, n, d]
return tf.reduce_sum(x_time_notes, axis=2) # [b, t, d]
x_notes, x_notes_std = get_note_moments(x, note_mask,
return_std=True) # [b, n, d]
x_time_notes_mean = (x_notes[:, tf.newaxis, ...] *
note_mask[..., tf.newaxis]) # [b, t, n, d]
pooled_mean = tf.reduce_sum(x_time_notes_mean, axis=2) # [b, t, d]

if return_std:
x_time_notes_std = (x_notes_std[:, tf.newaxis, ...] *
note_mask[..., tf.newaxis]) # [b, t, n, d]
pooled_std = tf.reduce_sum(x_time_notes_std, axis=2) # [b, t, d]
return pooled_mean, pooled_std
else:
return pooled_mean
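A sketch of the new return signature (shapes follow the docstring above; inputs are random placeholders):

```python
import tensorflow as tf
from ddsp.training import nn

x = tf.random.normal([1, 6, 8])  # [batch, time, dims]
q_pitch = tf.constant([[60., 60., 60., 62., 62., 0.]])
note_mask = nn.get_note_mask(q_pitch, max_regions=8)

# New default: per-note mean and std, broadcast back over time.
pooled_mean, pooled_std = nn.pool_over_notes(x, note_mask)

# The old single-output behavior remains available.
pooled_mean_only = nn.pool_over_notes(x, note_mask, return_std=False)
```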


def get_short_note_loss_mask(note_mask, note_lengths,
