Skip to content

Commit

Permalink
Fix bug in power preprocessor that was actually feeding in resampled …
Browse files Browse the repository at this point in the history
…audio instead of power.

* Allow Dict Layers to handle multiple optional kwarg tensors.
* Extract list of inputs from args, kwargs, dict args, and defaults to DictLayer.
* Ensure that number of args is as expected.
* Add new tests for all cases.
* Add tests for Preprocessor to make sure it handles input correctly either when given audio or power (or both) in the input dictionary.

PiperOrigin-RevId: 363471898
  • Loading branch information
jesseengel authored and Magenta Team committed Mar 17, 2021
1 parent 829e490 commit 4c1344c
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 23 deletions.
16 changes: 16 additions & 0 deletions ddsp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,22 @@ def nested_lookup(nested_key: Text,
return value


def leaf_key(nested_key: Text,
             delimiter: Text = '/') -> Text:
  """Returns the leaf node key name.

  Args:
    nested_key: String of the form "key/key/key...".
    delimiter: String that splits the nested keys.

  Returns:
    Final leaf node key name: the part of `nested_key` after the last
    `delimiter`, or all of `nested_key` if it contains no delimiter.
  """
  # Parse the input string and keep only the last component. The result is a
  # plain Python string, not a tensor.
  keys = nested_key.split(delimiter)
  return keys[-1]


def pad_axis(x, padding=(0, 0), axis=0, **pad_kwargs):
"""Pads only one axis of a tensor.
Expand Down
84 changes: 71 additions & 13 deletions ddsp/training/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,28 @@ def __init__(self, input_keys=None, output_keys=None, **kwargs):
**kwargs: Other keras layer kwargs such as name.
"""
super().__init__(**kwargs)
input_keys = input_keys or self.get_argument_names('call')
if not input_keys:
input_keys = self.get_argument_names('call')
self.default_input_keys = list(self.get_default_argument_names('call'))
self.default_input_values = list(self.get_default_argument_values('call'))
else:
# Manually specifying input keys overwrites default arguments.
self.default_input_keys = []
self.default_input_values = []
output_keys = output_keys or self.get_return_annotations('call')

self.input_keys = list(input_keys)
self.output_keys = list(output_keys)
self.default_input_keys = self.get_default_argument_names('call')

@property
def all_input_keys(self):
  """Full list of input keys: `input_keys` followed by `default_input_keys`.

  Only inputs are included; output keys are not part of this list.
  """
  return self.input_keys + self.default_input_keys

@property
def n_inputs(self):
  """Number of entries in `all_input_keys`.

  Recomputed on every access rather than cached, so the count stays correct
  if a subclass changes `input_keys` in its own `__init__`.
  """
  keys = self.all_input_keys
  return len(keys)

def __call__(self, *inputs, **kwargs):
"""Wrap the layer's __call__() with dictionary inputs and outputs.
Expand Down Expand Up @@ -121,24 +137,58 @@ def call(self, f0_hz, loudness, power=None) -> ['amps', 'frequencies']:
returns a dictionary it will be returned directly, otherwise the output
tensors will be wrapped in a dictionary {output_key: output_tensor}.
"""
# Merge all dictionaries provided in inputs.
# Construct a list of input tensors equal in length and order to the `call`
# input signature.
# -- Start first with any tensor arguments.
# -- Then lookup tensors from input dictionaries.
# -- Use default values if not found.

# Start by merging all dictionaries of tensors from the input.
input_dict = {}
for v in inputs:
if isinstance(v, dict):
input_dict.update(v)

# If any dicts provided, lookup input tensors from those dicts.
# Otherwise, just use inputs list as input tensors.
if input_dict:
inputs = [core.nested_lookup(key, input_dict) for key in self.input_keys]
# Optionally add for default arguments if key is present in input_dict.
for key in self.default_input_keys:
try:
inputs.append(core.nested_lookup(key, input_dict))
except KeyError:
pass
# And then strip all dictionaries from the input.
inputs = [v for v in inputs if not isinstance(v, dict)]

# Add any tensors from kwargs.
for key in self.all_input_keys:
if key in kwargs:
input_dict[key] = kwargs[key]

# And strip from kwargs.
kwargs = {k: v for k, v in kwargs.items() if k not in self.all_input_keys}

# Look up further inputs from the dictionaries.
for key in self.input_keys:
try:
# If key is present use the input_dict value.
inputs.append(core.nested_lookup(key, input_dict))
except KeyError:
# Skip if not present.
pass

# Add default arguments.
for key, value in zip(self.default_input_keys, self.default_input_values):
try:
# If key is present, use the input_dict value.
inputs.append(core.nested_lookup(key, input_dict))
except KeyError:
# Otherwise use the default value if not supplied as non-dict input.
if len(inputs) < self.n_inputs:
inputs.append(value)

# Run input tensors through the model.
if len(inputs) != self.n_inputs:
raise TypeError(f'{len(inputs)} input tensors extracted from inputs'
'(including default args) but the layer expects '
f'{self.n_inputs} tensors.\n'
f'Input keys: {self.input_keys}\n'
f'Default keys: {self.default_input_keys}\n'
f'Default values: {self.default_input_values}\n'
f'Input dictionaries: {input_dict}\n'
f'Input Tensors (Args, Dicts, and Defaults): {inputs}\n')
outputs = super().__call__(*inputs, **kwargs)

# Return dict if call() returns it.
Expand Down Expand Up @@ -170,6 +220,14 @@ def get_default_argument_names(self, method):
else:
return []

def get_default_argument_values(self, method):
  """Get list of default values of the default arguments to method.

  Args:
    method: Name of the method on this object to inspect (e.g. 'call').

  Returns:
    List of the default values declared in the method's signature, in
    signature order. Empty list if no argument has a default.
  """
  spec = inspect.getfullargspec(getattr(self, method))
  # `spec.defaults` is None when no argument has a default and a tuple
  # otherwise; normalize both cases to a list so the return type is consistent.
  return list(spec.defaults or ())

def get_return_annotations(self, method):
"""Get list of strings of return annotations of method."""
spec = inspect.getfullargspec(getattr(self, method))
Expand Down
49 changes: 39 additions & 10 deletions ddsp/training/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ def call(self, x1, x2=None) -> ['y1', 'y2', 'y3']:

self.TestLayerDefaults = TestLayerDefaults # pylint: disable=invalid-name

class TestLayerOptionals(nn.DictLayer):
  """DictLayer with one required input, two optional inputs, and output 'y'."""

  def call(self, x1, x2=None, x3=None) -> ['y']:
    """Return x1 plus 2 * x2 and 3 * x3, for whichever of x2, x3 are given."""
    total = x1
    # Each optional input carries a distinct weight, so the result reveals
    # which argument slot a tensor was routed to.
    for weight, optional in ((2.0, x2), (3.0, x3)):
      if optional is not None:
        total += weight * optional
    return total

self.TestLayerOptionals = TestLayerOptionals # pylint: disable=invalid-name

def assert_output_shapes_are_correct(self, dict_layer, outputs):
"""Check that the output is correct for a input."""
self.assertListEqual(list(dict_layer.output_keys), list(outputs.keys()))
Expand Down Expand Up @@ -104,22 +118,24 @@ def test_output_is_correct(self, layer_class):
outputs = test_layer(x1=self.x, x2=self.x)
self.assert_output_shapes_are_correct(test_layer, outputs)

# Merge multiple dict inputs, ignore other args, ignore extra keys.
# Merge multiple dict inputs, ignore extra keys.
outputs = test_layer({'x1': self.x, 'ignore': 0},
{'x2': self.x, 'ignore2': 0},
0, 0, 0)
{'x2': self.x, 'ignore2': 0})
self.assert_output_shapes_are_correct(test_layer, outputs)

# Raises errors for bad inputs.
# Missing key, wrong key name.
with self.assertRaises(KeyError):
test_layer({'asdf': self.x, 'x2': self.x})
# Missing keys.
with self.assertRaises(KeyError):
test_layer({'x': self.x})
# Wrong number of args.
with self.assertRaises(TypeError):
test_layer(self.x)
# Missing key --> Wrong number of args.
with self.assertRaises(TypeError):
test_layer({'x1': self.x})
# Missing key, wrong key name --> Wrong number of args.
with self.assertRaises(TypeError):
test_layer({'asdf': self.x, 'x2': self.x})
# Duplicate input argument.
with self.assertRaises(TypeError):
test_layer(self.x, {'x1': self.x, 'x2': self.x})

def test_input_output_keys_are_correct(self):
"""Ensure input output keys are the same for different class definitions."""
Expand Down Expand Up @@ -152,7 +168,7 @@ def test_renaming_input_output_keys(self, layer_class):
self.assert_output_shapes_are_correct(test_layer, outputs)

# Make sure original input_keys no longer are correct.
with self.assertRaises(KeyError):
with self.assertRaises(TypeError):
test_layer({'x1': self.x, 'x2': self.x})

def assertOutputsEqual(self, outputs_1, outputs_2):
Expand All @@ -174,6 +190,19 @@ def test_default_args_have_correct_outputs(self):
outputs_2 = test_layer({'x1': self.x})
self.assertOutputsEqual(outputs_1, outputs_2)

def test_optional_args_read_in_correct_order(self):
  """Optional dict inputs reach the matching argument, present or not."""
  layer = self.TestLayerOptionals()
  x = self.x
  # Dict inputs: required 'x1' with every combination of optional keys.
  input_dicts = [
      {'x1': x},
      {'x1': x, 'x2': x},
      {'x1': x, 'x3': x},
      {'x1': x, 'x2': x, 'x3': x},
  ]
  actual = [layer(inputs)['y'] for inputs in input_dicts]
  # The layer weights x2 by 2 and x3 by 3, so each combination has a distinct
  # expected multiple of x.
  expected = [x, 3.0 * x, 4.0 * x, 6.0 * x]
  for got, want in zip(actual, expected):
    self.assertAllClose(got, want)


class SplitToDictTest(tf.test.TestCase):

Expand Down
70 changes: 70 additions & 0 deletions ddsp/training/preprocessing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2021 The DDSP Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Tests for ddsp.training.preprocessing."""

from absl.testing import parameterized
from ddsp.core import resample
from ddsp.spectral_ops import compute_power
from ddsp.training import preprocessing
import tensorflow as tf

tfkl = tf.keras.layers


class F0PowerPreprocessorTest(parameterized.TestCase, tf.test.TestCase):
  """Checks F0PowerPreprocessor output 'pw_db' for audio and/or power inputs."""

  def setUp(self):
    """Create input dictionary and preprocessor."""
    super().setUp()
    sr = 16000
    frame_rate = 250
    frame_size = 256
    n_samples = 16000
    n_t = 250
    # Replicate preprocessor computations.
    audio = 0.5 * tf.sin(tf.range(0, n_samples, dtype=tf.float32))[None, :]
    power_db = compute_power(audio,
                             sample_rate=sr,
                             frame_rate=frame_rate,
                             frame_size=frame_size)
    power_db = preprocessing.at_least_3d(power_db)
    power_db = resample(power_db, n_t)
    # Superset of inputs; each test case passes only a subset of these keys.
    self.input_dict = {
        'f0_hz': tf.ones([1, n_t]),
        'audio': audio,
        'power_db': power_db,
    }
    self.preprocessor = preprocessing.F0PowerPreprocessor(
        time_steps=n_t,
        frame_rate=frame_rate,
        sample_rate=sr)

  @parameterized.named_parameters(
      ('audio_only', ['audio']),
      ('power_only', ['power_db']),
      ('audio_and_power', ['audio', 'power_db']),
  )
  def test_audio_only(self, input_keys):
    """Output 'pw_db' matches the reference power for each input combination.

    Despite its name, this test is parameterized over audio only, power only,
    and both together; 'f0_hz' is always supplied.

    Args:
      input_keys: Keys of `self.input_dict` to pass to the preprocessor, in
        addition to 'f0_hz'.
    """
    # Build a new list instead of using `+=`, which would mutate in place the
    # list object owned by the `named_parameters` decorator.
    selected_keys = list(input_keys) + ['f0_hz']
    inputs = {k: v for k, v in self.input_dict.items() if k in selected_keys}
    outputs = self.preprocessor(inputs)
    # NOTE(review): rtol=0.5 / atol=30 are very loose tolerances; confirm they
    # are still tight enough to catch audio being fed in place of power.
    self.assertAllClose(self.input_dict['power_db'],
                        outputs['pw_db'],
                        rtol=0.5,
                        atol=30)

if __name__ == '__main__':
tf.test.main()

0 comments on commit 4c1344c

Please sign in to comment.