Skip to content

Commit

Permalink
added new datatypes to Manifest.ini, addressed formatting, linting, and mypy errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
bgenchel-avail committed Jul 16, 2024
1 parent 5cfa7fe commit 76d5c9a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 22 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
include *.txt tox.ini *.rst *.md LICENSE
include catalog-info.yaml
include Dockerfile .dockerignore
recursive-include tests *.py *.wav *.npz *.jams *.zip
recursive-include tests *.py *.wav *.npz *.jams *.zip *.mid *.flac *.yaml
recursive-include basic_pitch *.py *.md
recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
13 changes: 6 additions & 7 deletions basic_pitch/data/datasets/slakh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import os
import time

from typing import List, Tuple
from typing import List, Tuple, Any

import apache_beam as beam
import mirdata
Expand All @@ -34,13 +34,13 @@ class SlakhFilterInvalidTracks(beam.DoFn):
def __init__(self, source: str):
self.source = source

def setup(self):
def setup(self) -> None:
import mirdata

self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
self.filesystem = beam.io.filesystems.FileSystems()

def process(self, element: Tuple[str, str]):
def process(self, element: Tuple[str, str]) -> Any:
import tempfile

import apache_beam as beam
Expand Down Expand Up @@ -100,17 +100,16 @@ def __init__(self, source: str, download: bool) -> None:
self.source = source
self.download = download

def setup(self):
def setup(self) -> None:
import apache_beam as beam
import os
import mirdata

self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
self.filesystem = beam.io.filesystems.FileSystems() # TODO: replace with fsspec
if self.download:
self.slakh_remote.download()

def process(self, element: List[str]):
def process(self, element: List[str]) -> List[Any]:
import tempfile

import numpy as np
Expand Down Expand Up @@ -188,7 +187,7 @@ def create_input_data() -> List[Tuple[str, str]]:
return [(track_id, track.data_split) for track_id, track in slakh.load_tracks().items()]


def main(known_args, pipeline_args):
def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, time_created)
input_data = create_input_data()
Expand Down
23 changes: 9 additions & 14 deletions tests/data/test_slakh.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,14 @@ def test_slakh_to_tf_example(tmpdir: str) -> None:

def test_slakh_invalid_tracks(tmpdir: str) -> None:
split_labels = ["train", "validation", "test"]
input_data = [(TRAIN_PIANO_TRACK_ID, "train"),
(VALID_PIANO_TRACK_ID, "validation"),
(TEST_PIANO_TRACK_ID, "test")]
input_data = [(TRAIN_PIANO_TRACK_ID, "train"), (VALID_PIANO_TRACK_ID, "validation"), (TEST_PIANO_TRACK_ID, "test")]

with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it" >> beam.ParDo(
SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
| "Tag it"
>> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
)

for split in split_labels:
Expand All @@ -87,15 +85,14 @@ def test_slakh_invalid_tracks(tmpdir: str) -> None:

def test_slakh_invalid_tracks_omitted(tmpdir: str) -> None:
split_labels = ["train", "omitted"]
input_data = [(TRAIN_PIANO_TRACK_ID, "train"),
(OMITTED_PIANO_TRACK_ID, "omitted")]
input_data = [(TRAIN_PIANO_TRACK_ID, "train"), (OMITTED_PIANO_TRACK_ID, "omitted")]

with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it" >> beam.ParDo(
SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
| "Tag it"
>> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
)

for split in split_labels:
Expand All @@ -114,16 +111,14 @@ def test_slakh_invalid_tracks_omitted(tmpdir: str) -> None:

def test_slakh_invalid_tracks_drums(tmpdir: str) -> None:
split_labels = ["train", "validation", "test"]
input_data = [(TRAIN_DRUMS_TRACK_ID, "train"),
(VALID_DRUMS_TRACK_ID, "validation"),
(TEST_DRUMS_TRACK_ID, "test")]
input_data = [(TRAIN_DRUMS_TRACK_ID, "train"), (VALID_DRUMS_TRACK_ID, "validation"), (TEST_DRUMS_TRACK_ID, "test")]

with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it" >> beam.ParDo(
SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
| "Tag it"
>> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
)

for split in split_labels:
Expand Down

0 comments on commit 76d5c9a

Please sign in to comment.