Skip to content

Commit

Permalink
add docstrings for methods in train files that were missing them. Add…
Browse files Browse the repository at this point in the history
… a no-sonify argument to main training, and remove seemingly outdated / mismatching original_validation_ds arg from visualization callback and method.
  • Loading branch information
bgenchel committed Aug 15, 2024
1 parent e2f68fc commit 06c4181
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 8 deletions.
17 changes: 14 additions & 3 deletions basic_pitch/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,31 @@

class VisualizeCallback(tf.keras.callbacks.Callback):
# TODO RACHEL make this WAY faster
"""
Callback to run during training to create tensorboard visualizations per epoch.
Attributes:
train_ds: training dataset to use for prediction / visualization / sonification / summarization
validation_ds: validation dataset to use for prediction / visualization / sonification / summarization
tensorboard_dir: directory in which the tensorboard summaries are written
sonify: whether to include sonifications in tensorboard
contours: whether to plot note contours in tensorboard
"""

def __init__(
self,
train_ds: tf.data.Dataset,
validation_ds: tf.data.Dataset,
tensorboard_dir: str,
original_validation_ds: tf.data.Dataset,
sonify: bool,
contours: bool,
):
super().__init__()
self.train_iter = iter(train_ds)
self.validation_iter = iter(validation_ds)
self.validation_ds = original_validation_ds
self.tensorboard_dir = os.path.join(tensorboard_dir, "tensorboard_logs")
self.file_writer = tf.summary.create_file_writer(tensorboard_dir)
self.sonify = sonify
self.contours = contours

def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
Expand All @@ -59,6 +70,6 @@ def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
outputs,
loss,
epoch,
self.validation_ds,
sonify=self.sonify,
contours=self.contours,
)
31 changes: 29 additions & 2 deletions basic_pitch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,31 @@ def main(
size_evaluation_callback_datasets: int,
datasets_to_use: List[str],
dataset_sampling_frequency: np.ndarray,
no_sonify: bool,
no_contours: bool,
weighted_onset_loss: bool,
positive_onset_weight: float,
) -> None:
"""Parse config and run training or evaluation."""
"""Parse config and run training or evaluation.
Args:
source: source directory for data
output: output directory for trained model / checkpoints / tensorboard
batch_size: batch size for data.
shuffle_size: size of shuffle buffer (only for training set) for the data shuffling mechanism
learning_rate: learning_rate for training
epochs: number of epochs to train for
steps_per_epoch: the number of batches to process per epoch during training
validation_steps: the number of validation batches to evaluate on per epoch
size_evaluation_callback_datasets: the batch size to use for visualization / logging
datasets_to_use: which datasets to train / evaluate on e.g. guitarset, medleydb_pitch, slakh
dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
are sampled from during training / validation dataset creation.
no_sonify: if True, exclude sonifications from the tensorboard visualizations.
no_contours: if True, exclude note contours from the tensorboard visualizations.
weighted_onset_loss: whether or not to use a weighted cross entropy loss.
positive_onset_weight: weighting factor for the positive labels.
"""
# configuration.add_externals()
logging.info(f"source directory: {source}")
logging.info(f"output directory: {output}")
Expand Down Expand Up @@ -115,7 +135,7 @@ def main(
train_visualization_ds,
validation_visualization_ds,
tensorboard_log_dir,
validation_ds.take(validation_steps),
not no_sonify,
not no_contours,
),
]
Expand Down Expand Up @@ -202,6 +222,12 @@ def console_entry_point() -> None:
default=False,
help=f"Use {dataset} dataset in training",
)
parser.add_argument(
"--no-sonify",
action="store_true",
default=False,
help="if given, exclude sonifications from the tensorboard / data visualization",
)
parser.add_argument(
"--no-contours",
action="store_true",
Expand Down Expand Up @@ -251,6 +277,7 @@ def console_entry_point() -> None:
args.size_evaluation_callback_datasets,
datasets_to_use,
dataset_sampling_frequency,
args.dont_sonify,
args.no_contours,
args.weighted_onset_loss,
args.positive_onset_weight,
Expand Down
20 changes: 19 additions & 1 deletion basic_pitch/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@


def get_input_model() -> tf.keras.Model:
"""define a model that generates the CQT (Constant-Q Transform) of input audio"""
inputs = tf.keras.Input(shape=(AUDIO_N_SAMPLES, 1)) # (batch, time, ch)
x = models.get_cqt(inputs, 1, False)
model = tf.keras.Model(inputs=inputs, outputs=x)
Expand Down Expand Up @@ -83,7 +84,8 @@ def visualize_transcription(
outputs: batch of output data (dictionary)
loss: loss value for epoch
step: which epoch this is
sonify: sonify outputs
contours: plot note contours
"""
with file_writer.as_default():
# create audio player
Expand Down Expand Up @@ -169,11 +171,19 @@ def visualize_transcription(
# plot max
if contours:
tf.summary.scalar(f"{stage}/contour-max", np.max(outputs["contour"]), step=step)

tf.summary.scalar(f"{stage}/note-max", np.max(outputs["note"]), step=step)
tf.summary.scalar(f"{stage}/onset-max", np.max(outputs["onset"]), step=step)


def _array_to_sonification(array: tf.Tensor, max_outputs: int, clip: float = 0.3) -> tf.Tensor:
"""sonify time frequency representation of audio
Args:
array: time-frequency representation of audio
max_outputs: the number of grams / batches to process / append to the resulting output
clip: value below which signal is 0'd out.
"""
gram_batch = tf.transpose(array, perm=[0, 2, 1]).numpy()
audio_list = []

Expand All @@ -194,6 +204,14 @@ def _array_to_sonification(array: tf.Tensor, max_outputs: int, clip: float = 0.3


def _audio_input(audio: tf.Tensor) -> tf.Tensor:
    """Run audio through the module-level INPUT_MODEL and reorder the result's axes.

    Args:
        audio: the audio signal to process

    Returns:
        The INPUT_MODEL output with its second and third axes swapped. The original
        author describes this as the constant-Q transform of the audio (3 bins per
        semitone, ~11ms hop size) -- TODO confirm against the INPUT_MODEL definition.
    """
    # perm=[0, 2, 1, 3] keeps the first and last axes in place and swaps the middle two.
    return tf.transpose(INPUT_MODEL(audio), perm=[0, 2, 1, 3])

Expand Down
3 changes: 1 addition & 2 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# limitations under the License.

import numpy as np
import os
import tensorflow as tf

from typing import Dict
Expand Down Expand Up @@ -53,7 +52,7 @@ def test_visualize_callback_on_epoch_end(tmpdir: str) -> None:
train_ds=create_mock_dataset(),
validation_ds=create_mock_dataset(),
tensorboard_dir=str(tmpdir),
original_validation_ds=create_mock_dataset(),
sonify=True,
contours=True,
)

Expand Down

0 comments on commit 06c4181

Please sign in to comment.