diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..d895d32
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,8 @@
+.git
+.tox
+**/.mypy_cache
+**/.pytest_cache
+**/__pycache__
+**/*.wav
+**/saved_models
+**/basic-pitch/saved_models
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..815864e
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,19 @@
+FROM apache/beam_python3.10_sdk:2.51.0
+
+RUN --mount=type=cache,target=/var/cache/apt \
+    apt-get update \
+    && apt-get install --no-install-recommends -y --fix-missing \
+    sox \
+    libsndfile1 \
+    libsox-fmt-all \
+    ffmpeg \
+    libhdf5-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY . /basic-pitch
+WORKDIR /basic-pitch
+RUN --mount=type=cache,target=/root/.cache \
+    pip3 install --upgrade pip && \
+    pip3 install --upgrade setuptools wheel && \
+    pip3 install -e '.[train]'
+
diff --git a/MANIFEST.in b/MANIFEST.in
index f6503d9..500d4e6 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,6 @@
 include *.txt tox.ini *.rst *.md LICENSE
 include catalog-info.yaml
+include Dockerfile .dockerignore
 recursive-include tests *.py *.wav *.npz *.jams *.zip
 recursive-include basic_pitch *.py
 recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
diff --git a/basic_pitch/data/commandline.py b/basic_pitch/data/commandline.py
index 233ac34..1dd9103 100644
--- a/basic_pitch/data/commandline.py
+++ b/basic_pitch/data/commandline.py
@@ -23,23 +23,23 @@
 def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None:
-    default_source = Path.home() / "mir_datasets" / dataset_name
-    default_destination = Path.home() / "data" / "basic_pitch" / dataset_name
+    default_source = str(Path.home() / "mir_datasets" / dataset_name)
+    default_destination = str(Path.home() / "data" / "basic_pitch" / dataset_name)
     parser.add_argument(
         "--source",
         default=default_source,
-        type=Path,
+        type=str,
         help=f"Source directory for mir data. Defaults to {default_source}",
     )
     parser.add_argument(
         "--destination",
         default=default_destination,
-        type=Path,
+        type=str,
         help=f"Output directory to write results to. Defaults to {default_destination}",
     )
     parser.add_argument(
         "--runner",
-        choices=["DataflowRunner", "DirectRunner"],
+        choices=["DataflowRunner", "DirectRunner", "PortableRunner"],
         default="DirectRunner",
         help="Whether to run the download and process locally or on GCP Dataflow",
     )
@@ -51,11 +51,12 @@ def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None
     )
     parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
     parser.add_argument(
-        "--worker-harness-container-image",
+        "--sdk_container_image",
         default="",
         help="Container image to run dataset generation job with. \
             Required due to non-python dependencies.",
     )
+    parser.add_argument("--job_endpoint", default="embed", help="Beam job service endpoint; 'embed' starts an in-process job server")
 
 
 def resolve_destination(namespace: argparse.Namespace, time_created: int) -> str:
diff --git a/basic_pitch/data/datasets/__init__.py b/basic_pitch/data/datasets/__init__.py
index 7b57781..e69de29 100644
--- a/basic_pitch/data/datasets/__init__.py
+++ b/basic_pitch/data/datasets/__init__.py
@@ -1 +0,0 @@
-DOWNLOAD = True
diff --git a/basic_pitch/data/datasets/guitarset.py b/basic_pitch/data/datasets/guitarset.py
index 6169771..a4142be 100644
--- a/basic_pitch/data/datasets/guitarset.py
+++ b/basic_pitch/data/datasets/guitarset.py
@@ -27,7 +27,6 @@
 import mirdata
 
 from basic_pitch.data import commandline, pipeline
-from basic_pitch.data.datasets import DOWNLOAD
 
 
 class GuitarSetInvalidTracks(beam.DoFn):
@@ -163,12 +162,15 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
         "disk_size_gb": 128,
         "experiments": ["use_runner_v2"],
         "save_main_session": True,
-        "worker_harness_container_image": known_args.worker_harness_container_image,
+        "sdk_container_image": known_args.sdk_container_image,
+        "job_endpoint": known_args.job_endpoint,
+        "environment_type": "DOCKER",
+        "environment_config": known_args.sdk_container_image,
     }
     pipeline.run(
         pipeline_options,
         input_data,
-        GuitarSetToTfExample(known_args.source, DOWNLOAD),
+        GuitarSetToTfExample(known_args.source, download=True),
         GuitarSetInvalidTracks(),
         destination,
         known_args.batch_size,
@@ -180,5 +182,6 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
     commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
     commandline.add_split(parser)
     known_args, pipeline_args = parser.parse_known_args()
+    print(pipeline_args)
     main(known_args, pipeline_args)
diff --git a/basic_pitch/data/download.py b/basic_pitch/data/download.py
index 9066d8c..9018685 100644
--- a/basic_pitch/data/download.py
+++ b/basic_pitch/data/download.py
@@ -16,17 +16,12 @@
 
 import argparse
 import logging
-import sys
 
 from basic_pitch.data import commandline
 from basic_pitch.data.datasets.guitarset import main as guitarset_main
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
-handler = logging.StreamHandler(sys.stdout)
-formatter = logging.Formatter("%(levelname)s:: %(message)s")
-handler.setFormatter(formatter)
-logger.addHandler(handler)
 
 DATASET_DICT = {
     "guitarset": guitarset_main,
@@ -49,7 +44,8 @@ def main() -> None:
     commandline.add_split(cl_parser)
     known_args, pipeline_args = cl_parser.parse_known_args(remaining_args)
     for arg in vars(known_args):
-        logger.info(f"{arg} = {getattr(known_args, arg)}")
+        logger.info(f"known_args:: {arg} = {getattr(known_args, arg)}")
+    logger.info(f"pipeline_args = {pipeline_args}")
 
     DATASET_DICT[dataset](known_args, pipeline_args)
diff --git a/basic_pitch/data/pipeline.py b/basic_pitch/data/pipeline.py
index 8520075..11385e6 100644
--- a/basic_pitch/data/pipeline.py
+++ b/basic_pitch/data/pipeline.py
@@ -68,13 +68,10 @@ def transcription_dataset_writer(
             "validation",
         )
     )
-
     for split in ["train", "test", "validation"]:
         (
             getattr(valid_track_ids, split)
-            | f"Combine {split} into giant list" >> beam.transforms.combiners.ToList()
-            # | f"Batch {split}" >> beam.ParDo(Batch(batch_size))
-            | f"Batch {split}" >> beam.BatchElements(max_batch_size=batch_size)
+            | f"Batch {split}" >> beam.BatchElements(min_batch_size=batch_size, max_batch_size=batch_size)
             | f"Reshuffle {split}" >> beam.Reshuffle()  # To prevent fusion
             | f"Create tf.Example {split} batch" >> beam.ParDo(to_tf_example)
             | f"Write {split} batch to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(os.path.join(destination, split)))
@@ -95,5 +92,6 @@ def run(
     destination: str,
     batch_size: int,
 ) -> None:
-    with beam.Pipeline(options=PipelineOptions(**pipeline_options)) as p:
+    logging.info(f"pipeline_options = {pipeline_options}")
+    with beam.Pipeline(options=PipelineOptions.from_dictionary(pipeline_options)) as p:
         transcription_dataset_writer(p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size)
diff --git a/tests/data/test_guitarset.py b/tests/data/test_guitarset.py
index 9212cb5..2ff8ebd 100644
--- a/tests/data/test_guitarset.py
+++ b/tests/data/test_guitarset.py
@@ -34,14 +34,13 @@
 def test_guitar_set_to_tf_example(tmpdir: str) -> None:
-    DOWNLOAD = False
     input_data: List[str] = [TRACK_ID]
     with TestPipeline() as p:
         (
             p
             | "Create PCollection of track IDs" >> beam.Create([input_data])
             | "Create tf.Example"
-            >> beam.ParDo(GuitarSetToTfExample(str(RESOURCES_PATH / "data" / "guitarset"), DOWNLOAD))
+            >> beam.ParDo(GuitarSetToTfExample(str(RESOURCES_PATH / "data" / "guitarset"), download=False))
             | "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(tmpdir))
         )