Skip to content

Commit

Permalink
convert tf.debugging.assert_none_equal to standard assert for single …
Browse files Browse the repository at this point in the history
…value checks, remove model train callback test bc basically duplicate with more work.
  • Loading branch information
bgenchel committed Aug 15, 2024
1 parent f00dde0 commit e2f68fc
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 34 deletions.
10 changes: 2 additions & 8 deletions basic_pitch/data/tf_example_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@ def prepare_datasets(

# check that the base dataset returned by ds_function is FINITE
for ds in [ds_train, ds_validation]:
tf.debugging.assert_none_equal(
tf.cast(tf.data.experimental.cardinality(ds), tf.int32),
tf.data.experimental.INFINITE_CARDINALITY,
)
assert tf.cast(tf.data.experimental.cardinality(ds), tf.int32) != tf.data.experimental.INFINITE_CARDINALITY

# training dataset
if training_shuffle_buffer_size > 0:
Expand Down Expand Up @@ -137,10 +134,7 @@ def prepare_visualization_datasets(

# check that the base dataset returned by ds_function is FINITE
for ds in [ds_train, ds_validation]:
tf.debugging.assert_none_equal(
tf.cast(tf.data.experimental.cardinality(ds), tf.int32),
tf.data.experimental.INFINITE_CARDINALITY,
)
assert tf.cast(tf.data.experimental.cardinality(ds), tf.int32) != tf.data.experimental.INFINITE_CARDINALITY

# training dataset
ds_train = ds_train.repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)
Expand Down
26 changes: 0 additions & 26 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from basic_pitch.callbacks import VisualizeCallback
from basic_pitch.constants import AUDIO_N_SAMPLES, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


class MockModel(tf.keras.Model):
def __init__(self) -> None:
Expand Down Expand Up @@ -62,27 +60,3 @@ def test_visualize_callback_on_epoch_end(tmpdir: str) -> None:
vc.model = MockModel()

vc.on_epoch_end(1, {"loss": np.random.random(), "val_loss": np.random.random()})


def test_visualize_callback_on_epoch_end_with_model(tmpdir: str) -> None:
model = MockModel()
model.compile(optimizer="adam", loss="mse")

batch_size = 2 # needs to be at least 2 bc validation_split required

x_train = np.random.random((batch_size, AUDIO_N_SAMPLES, 1))
y_train = {
key: np.random.random((batch_size, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES))
for key in ["onset", "contour", "note"]
}

vc = VisualizeCallback(
train_ds=create_mock_dataset(),
validation_ds=create_mock_dataset(),
tensorboard_dir=str(tmpdir),
original_validation_ds=create_mock_dataset(),
contours=True,
)

history = model.fit(x_train, y_train, epochs=1, validation_split=0.5, callbacks=[vc], verbose=0)
assert history

0 comments on commit e2f68fc

Please sign in to comment.