From 628eef84f3de5309c4bb33c88fd59778c29e7fd1 Mon Sep 17 00:00:00 2001 From: Rigel Swavely Date: Tue, 6 Apr 2021 17:41:04 -0700 Subject: [PATCH] Add train and eval splitting logic and coarse chunking to training data pipeline. Fixes #270. PiperOrigin-RevId: 367121062 --- .../data_preparation/ddsp_prepare_tfrecord.py | 20 ++- .../data_preparation/prepare_tfrecord_lib.py | 118 ++++++++++++------ .../prepare_tfrecord_lib_test.py | 36 +++++- 3 files changed, 126 insertions(+), 48 deletions(-) diff --git a/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py b/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py index dff388d6..f7a4c0f3 100644 --- a/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py +++ b/ddsp/training/data_preparation/ddsp_prepare_tfrecord.py @@ -32,9 +32,8 @@ FLAGS = flags.FLAGS -flags.DEFINE_list( - 'input_audio_filepatterns', [], - 'List of filepatterns to glob for input audio files.') +flags.DEFINE_list('input_audio_filepatterns', [], + 'List of filepatterns to glob for input audio files.') flags.DEFINE_string( 'output_tfrecord_path', None, 'The prefix path to the output TFRecord. Shard numbers will be added to ' @@ -43,9 +42,8 @@ 'num_shards', None, 'The number of shards to use for the TFRecord. If None, this number will ' 'be determined automatically.') -flags.DEFINE_integer( - 'sample_rate', 16000, - 'The sample rate to use for the audio.') +flags.DEFINE_integer('sample_rate', 16000, + 'The sample rate to use for the audio.') flags.DEFINE_integer( 'frame_rate', 250, 'The frame rate to use for f0 and loudness features. If set to 0, ' @@ -59,6 +57,14 @@ 'sliding_window_hop_secs', 1, 'The hop size in seconds to use when splitting audio into constant-length ' 'examples.') +flags.DEFINE_float( + 'eval_split_fraction', 0.0, + 'Fraction of the dataset to reserve for eval split. If set to 0, no eval ' + 'split is created.' +) +flags.DEFINE_float( + 'coarse_chunk_secs', 20.0, + 'Chunk size in seconds used to split the input audio files.') flags.DEFINE_list( 'pipeline_options', '--runner=DirectRunner', 'A comma-separated list of command line arguments to be used as options ' @@ -78,6 +84,8 @@ def run(): frame_rate=FLAGS.frame_rate, window_secs=FLAGS.example_secs, hop_secs=FLAGS.sliding_window_hop_secs, + eval_split_fraction=FLAGS.eval_split_fraction, + coarse_chunk_secs=FLAGS.coarse_chunk_secs, pipeline_options=FLAGS.pipeline_options) diff --git a/ddsp/training/data_preparation/prepare_tfrecord_lib.py b/ddsp/training/data_preparation/prepare_tfrecord_lib.py index 23266602..da18b8c6 100644 --- a/ddsp/training/data_preparation/prepare_tfrecord_lib.py +++ b/ddsp/training/data_preparation/prepare_tfrecord_lib.py @@ -24,8 +24,7 @@ -def _load_audio_as_array(audio_path: str, - sample_rate: int) -> np.array: +def _load_audio_as_array(audio_path: str, sample_rate: int) -> np.array: """Load audio file at specified sample rate and return an array. When `sample_rate` > original SR of audio file, Pydub may miss samples when @@ -86,19 +85,18 @@ def _add_f0_estimate(ex, sample_rate, frame_rate): return ex -def split_example( - ex, sample_rate, frame_rate, window_secs, hop_secs): +def split_example(ex, sample_rate, frame_rate, window_secs, hop_secs): """Splits example into windows, padding final window if needed.""" def get_windows(sequence, rate): window_size = int(window_secs * rate) hop_size = int(hop_secs * rate) - n_windows = int(np.ceil((len(sequence) - window_size) / hop_size)) + 1 + n_windows = int(np.ceil((len(sequence) - window_size) / hop_size)) + 1 n_samples_padded = (n_windows - 1) * hop_size + window_size n_padding = n_samples_padded - len(sequence) sequence = np.pad(sequence, (0, n_padding), mode='constant') for window_end in range(window_size, len(sequence) + 1, hop_size): - yield sequence[window_end-window_size:window_end] + yield sequence[window_end - window_size:window_end] for audio, loudness_db, f0_hz, f0_confidence in zip( get_windows(ex['audio'], sample_rate), @@ -121,19 +119,35 @@ def float_dict_to_tfexample(float_dict): feature={ k: tf.train.Feature(float_list=tf.train.FloatList(value=v)) for k, v in float_dict.items() - } - )) - - -def prepare_tfrecord( - input_audio_paths, - output_tfrecord_path, - num_shards=None, - sample_rate=16000, - frame_rate=250, - window_secs=4, - hop_secs=1, - pipeline_options=''): + })) + + +def add_key(example): + """Add a key to this example by taking the hash of the values.""" + return hash(example['audio'].tobytes()), example + + +def eval_split_partition_fn(example, num_partitions, eval_fraction, all_ids): + """Partition function to split into train/eval based on the hash ids.""" + del num_partitions + example_id = example[0] + eval_range = int(len(all_ids) * eval_fraction) + for i in range(eval_range): + if all_ids[i] == example_id: + return 0 + return 1 + + +def prepare_tfrecord(input_audio_paths, + output_tfrecord_path, + num_shards=None, + sample_rate=16000, + frame_rate=250, + window_secs=4, + hop_secs=1, + eval_split_fraction=0.0, + coarse_chunk_secs=20.0, + pipeline_options=''): """Prepares a TFRecord for use in training, evaluation, and prediction. Args: @@ -144,12 +158,17 @@ def prepare_tfrecord( num_shards: The number of shards to use for the TFRecord. If None, this number will be determined automatically. sample_rate: The sample rate to use for the audio. - frame_rate: The frame rate to use for f0 and loudness features. - If set to None, these features will not be computed. - window_secs: The size of the sliding window (in seconds) to use to - split the audio and features. If 0, they will not be split. - hop_secs: The number of seconds to hop when computing the sliding - windows. + frame_rate: The frame rate to use for f0 and loudness features. If set to + None, these features will not be computed. + window_secs: The size of the sliding window (in seconds) to use to split the + audio and features. If 0, they will not be split. + hop_secs: The number of seconds to hop when computing the sliding windows. + eval_split_fraction: Fraction of the dataset to reserve for eval split. If + set to 0, no eval split is created. + coarse_chunk_secs: Chunk size in seconds used to split the input audio + files. This is used to split large audio files into manageable chunks + for better parallelization and to enable non-overlapping train/eval + splits. pipeline_options: An iterable of command line arguments to be used as options for the Beam Pipeline. """ @@ -167,16 +186,39 @@ def prepare_tfrecord( | beam.Map(_add_f0_estimate, sample_rate, frame_rate) | beam.Map(add_loudness, sample_rate, frame_rate)) - if window_secs: - examples |= beam.FlatMap( - split_example, sample_rate, frame_rate, window_secs, hop_secs) - - _ = ( - examples - | beam.Reshuffle() - | beam.Map(float_dict_to_tfexample) - | beam.io.tfrecordio.WriteToTFRecord( - output_tfrecord_path, - num_shards=num_shards, - coder=beam.coders.ProtoCoder(tf.train.Example)) - ) + if coarse_chunk_secs: + examples |= beam.FlatMap(split_example, sample_rate, frame_rate, + coarse_chunk_secs, coarse_chunk_secs) + + def postprocess_pipeline(examples, output_path, stage_name=''): + if stage_name: + stage_name = f'_{stage_name}' + + if window_secs: + examples |= f'create_batches{stage_name}' >> beam.FlatMap( + split_example, sample_rate, frame_rate, window_secs, hop_secs) + _ = ( + examples + | f'reshuffle{stage_name}' >> beam.Reshuffle() + | f'make_tfexample{stage_name}' >> beam.Map(float_dict_to_tfexample) + | f'write{stage_name}' >> beam.io.tfrecordio.WriteToTFRecord( + output_path, + num_shards=num_shards, + coder=beam.coders.ProtoCoder(tf.train.Example))) + + if eval_split_fraction: + examples |= beam.Map(add_key) + keys = examples | beam.Keys() + splits = examples | beam.Partition(eval_split_partition_fn, 2, + eval_split_fraction, + beam.pvalue.AsList(keys)) + + # Remove ids. + eval_split = splits[0] | 'remove_id_eval' >> beam.Map(lambda x: x[1]) + train_split = splits[1] | 'remove_id_train' >> beam.Map(lambda x: x[1]) + + postprocess_pipeline(eval_split, f'{output_tfrecord_path}-eval', 'eval') + postprocess_pipeline(train_split, f'{output_tfrecord_path}-train', + 'train') + else: + postprocess_pipeline(examples, output_tfrecord_path) diff --git a/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py b/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py index 77ab6ab7..bbbb1fbf 100644 --- a/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py +++ b/ddsp/training/data_preparation/prepare_tfrecord_lib_test.py @@ -86,7 +86,8 @@ def test_prepare_tfrecord(self, sample_rate): sample_rate=sample_rate, frame_rate=frame_rate, window_secs=window_secs, - hop_secs=hop_secs) + hop_secs=hop_secs, + coarse_chunk_secs=None) expected_f0_and_loudness_length = int(window_secs * frame_rate) self.validate_outputs( @@ -107,7 +108,8 @@ def test_prepare_tfrecord_no_split(self, sample_rate): num_shards=2, sample_rate=sample_rate, frame_rate=frame_rate, - window_secs=None) + window_secs=None, + coarse_chunk_secs=None) expected_f0_and_loudness_length = int(self.wav_secs * frame_rate) self.validate_outputs( @@ -118,6 +120,30 @@ def test_prepare_tfrecord_no_split(self, sample_rate): 'loudness_db': expected_f0_and_loudness_length, }) + @parameterized.named_parameters(('16k', 16000), ('24k', 24000), + ('48k', 48000)) + def test_prepare_tfrecord_chunk(self, sample_rate): + frame_rate = 250 + chunk_secs = 1.5 + prepare_tfrecord_lib.prepare_tfrecord( + [self.wav_path], + os.path.join(self.test_dir, 'output.tfrecord'), + num_shards=2, + sample_rate=sample_rate, + frame_rate=frame_rate, + window_secs=None, + coarse_chunk_secs=chunk_secs) + + expected_f0_and_loudness_length = int(chunk_secs * frame_rate) + + self.validate_outputs( + 2, { + 'audio': int(chunk_secs * sample_rate), + 'f0_hz': expected_f0_and_loudness_length, + 'f0_confidence': expected_f0_and_loudness_length, + 'loudness_db': expected_f0_and_loudness_length, + }) + @parameterized.named_parameters(('16k', 16000), ('24k', 24000), ('48k', 48000)) def test_prepare_tfrecord_no_f0_and_loudness(self, sample_rate): @@ -127,7 +153,8 @@ def test_prepare_tfrecord_no_f0_and_loudness(self, sample_rate): num_shards=2, sample_rate=sample_rate, frame_rate=None, - window_secs=None) + window_secs=None, + coarse_chunk_secs=None) self.validate_outputs( 1, { @@ -147,7 +174,8 @@ def test_prepare_tfrecord_at_indivisible_sample_rate_throws_error( num_shards=2, sample_rate=sample_rate, frame_rate=frame_rate, - window_secs=None) + window_secs=None, + coarse_chunk_secs=None) if __name__ == '__main__':