Switch over arguments and defaults to use PortableRunner in emulation of the
Dataflow runner, using our Docker image. Added Dockerfile and .dockerignore to MANIFEST.in.
bgenchel-avail committed Jun 5, 2024
1 parent 2c3d4fd commit 4225ade
Showing 9 changed files with 47 additions and 23 deletions.
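For context, a hypothetical end-to-end invocation of one of the dataset pipelines with the new flags (the image tag is illustrative, not taken from this commit; --job_endpoint defaults to "embed"):

    python -m basic_pitch.data.datasets.guitarset \
        --runner PortableRunner \
        --sdk_container_image basic-pitch:latest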
8 changes: 8 additions & 0 deletions .dockerignore
@@ -0,0 +1,8 @@
.git
.tox
**/.mypy_cache
**/.pytest_cache
**/__pycache__
**/*.wav
**/saved_models
**/basic-pitch/saved_models
19 changes: 19 additions & 0 deletions Dockerfile
@@ -0,0 +1,19 @@
FROM apache/beam_python3.10_sdk:2.51.0

RUN --mount=type=cache,target=/var/cache/apt \
apt-get update \
&& apt-get install --no-install-recommends -y --fix-missing \
sox \
libsndfile1 \
libsox-fmt-all \
ffmpeg \
libhdf5-dev \
&& rm -rf /var/lib/apt/lists/*

COPY . /basic-pitch
WORKDIR /basic-pitch
RUN --mount=type=cache,target=/root/.cache \
pip3 install --upgrade pip && \
pip3 install --upgrade setuptools wheel && \
pip3 install -e '.[train]'
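The RUN --mount=type=cache directives above require BuildKit, so building this image would look something like the following (the tag is an assumption):

    DOCKER_BUILDKIT=1 docker build -t basic-pitch:latest .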

1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,5 +1,6 @@
include *.txt tox.ini *.rst *.md LICENSE
include catalog-info.yaml
include Dockerfile .dockerignore
recursive-include tests *.py *.wav *.npz *.jams *.zip
recursive-include basic_pitch *.py
recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
13 changes: 7 additions & 6 deletions basic_pitch/data/commandline.py
@@ -23,23 +23,23 @@


def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None:
default_source = Path.home() / "mir_datasets" / dataset_name
default_destination = Path.home() / "data" / "basic_pitch" / dataset_name
default_source = str(Path.home() / "mir_datasets" / dataset_name)
default_destination = str(Path.home() / "data" / "basic_pitch" / dataset_name)
parser.add_argument(
"--source",
default=default_source,
type=Path,
type=str,
help=f"Source directory for mir data. Defaults to {default_source}",
)
parser.add_argument(
"--destination",
default=default_destination,
type=Path,
type=str,
help=f"Output directory to write results to. Defaults to {default_destination}",
)
parser.add_argument(
"--runner",
choices=["DataflowRunner", "DirectRunner"],
choices=["DataflowRunner", "DirectRunner", "PortableRunner"],
default="DirectRunner",
help="Whether to run the download and process locally or on GCP Dataflow",
)
@@ -51,11 +51,12 @@ def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None
)
parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
parser.add_argument(
"--worker-harness-container-image",
"--sdk_container_image",
default="",
help="Container image to run dataset generation job with. \
Required due to non-python dependencies.",
)
parser.add_argument("--job_endpoint", default="embed", help="")


def resolve_destination(namespace: argparse.Namespace, time_created: int) -> str:
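A minimal sketch (not part of the commit) of how the new arguments and defaults surface at parse time; the image tag is hypothetical:

    import argparse

    from basic_pitch.data import commandline

    parser = argparse.ArgumentParser()
    commandline.add_default(parser, "guitarset")
    known_args, pipeline_args = parser.parse_known_args([
        "--runner", "PortableRunner",
        "--sdk_container_image", "basic-pitch:latest",
    ])
    assert known_args.job_endpoint == "embed"  # the new default
    # Flags argparse does not recognize are left in pipeline_args for Beam.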
1 change: 0 additions & 1 deletion basic_pitch/data/datasets/__init__.py
@@ -1 +0,0 @@
DOWNLOAD = True
9 changes: 6 additions & 3 deletions basic_pitch/data/datasets/guitarset.py
@@ -27,7 +27,6 @@
import mirdata

from basic_pitch.data import commandline, pipeline
from basic_pitch.data.datasets import DOWNLOAD


class GuitarSetInvalidTracks(beam.DoFn):
@@ -163,12 +162,15 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
"disk_size_gb": 128,
"experiments": ["use_runner_v2"],
"save_main_session": True,
"worker_harness_container_image": known_args.worker_harness_container_image,
"sdk_container_image": known_args.sdk_container_image,
"job_endpoint": known_args.job_endpoint,
"environment_type": "DOCKER",
"environment_config": known_args.sdk_container_image,
}
pipeline.run(
pipeline_options,
input_data,
GuitarSetToTfExample(known_args.source, DOWNLOAD),
GuitarSetToTfExample(known_args.source, download=True),
GuitarSetInvalidTracks(),
destination,
known_args.batch_size,
@@ -180,5 +182,6 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args()
print(pipeline_args)

main(known_args, pipeline_args)
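For a local PortableRunner run, the options assembled in main above resolve to roughly the following (a sketch; the runner key and image tag come from flags outside this hunk and are assumptions):

    pipeline_options = {
        "runner": "PortableRunner",
        "job_endpoint": "embed",                     # new default from commandline.py
        "environment_type": "DOCKER",
        "environment_config": "basic-pitch:latest",  # hypothetical image tag
        "save_main_session": True,
    }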
8 changes: 2 additions & 6 deletions basic_pitch/data/download.py
@@ -16,17 +16,12 @@

import argparse
import logging
import sys

from basic_pitch.data import commandline
from basic_pitch.data.datasets.guitarset import main as guitarset_main

logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(levelname)s:: %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

DATASET_DICT = {
"guitarset": guitarset_main,
@@ -49,7 +44,8 @@ def main() -> None:
commandline.add_split(cl_parser)
known_args, pipeline_args = cl_parser.parse_known_args(remaining_args)
for arg in vars(known_args):
logger.info(f"{arg} = {getattr(known_args, arg)}")
logger.info(f"known_args:: {arg} = {getattr(known_args, arg)}")
logger.info(f"pipeline_args = {pipeline_args}")
DATASET_DICT[dataset](known_args, pipeline_args)


8 changes: 3 additions & 5 deletions basic_pitch/data/pipeline.py
@@ -68,13 +68,10 @@ def transcription_dataset_writer(
"validation",
)
)

for split in ["train", "test", "validation"]:
(
getattr(valid_track_ids, split)
| f"Combine {split} into giant list" >> beam.transforms.combiners.ToList()
# | f"Batch {split}" >> beam.ParDo(Batch(batch_size))
| f"Batch {split}" >> beam.BatchElements(max_batch_size=batch_size)
| f"Batch {split}" >> beam.BatchElements(min_batch_size=batch_size, max_batch_size=batch_size)
| f"Reshuffle {split}" >> beam.Reshuffle() # To prevent fuses
| f"Create tf.Example {split} batch" >> beam.ParDo(to_tf_example)
| f"Write {split} batch to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(os.path.join(destination, split)))
@@ -95,5 +92,6 @@ def run(
destination: str,
batch_size: int,
) -> None:
with beam.Pipeline(options=PipelineOptions(**pipeline_options)) as p:
logging.info(f"pipeline_options = {pipeline_options}")
with beam.Pipeline(options=PipelineOptions.from_dictionary(pipeline_options)) as p:
transcription_dataset_writer(p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size)
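Two notes on this file. BatchElements with min_batch_size equal to max_batch_size pins batches to exactly batch_size elements (a final partial batch can still be smaller). And the switch to PipelineOptions.from_dictionary matters because it serializes each entry back into a --key=value flag before parsing, so options registered by any options subclass (such as the portable runner's job_endpoint) are validated and picked up. A minimal sketch, assuming Beam 2.51 semantics:

    from apache_beam.options.pipeline_options import PipelineOptions

    opts = {
        "runner": "PortableRunner",
        "job_endpoint": "embed",
        "environment_type": "DOCKER",
        "environment_config": "basic-pitch:latest",  # hypothetical tag
    }
    # from_dictionary converts the dict to ["--runner=PortableRunner", ...]
    # and parses those flags; PipelineOptions(**opts) instead stores the
    # kwargs as overrides without flag-level parsing.
    pipeline_options = PipelineOptions.from_dictionary(opts)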
3 changes: 1 addition & 2 deletions tests/data/test_guitarset.py
@@ -34,14 +34,13 @@


def test_guitar_set_to_tf_example(tmpdir: str) -> None:
DOWNLOAD = False
input_data: List[str] = [TRACK_ID]
with TestPipeline() as p:
(
p
| "Create PCollection of track IDs" >> beam.Create([input_data])
| "Create tf.Example"
>> beam.ParDo(GuitarSetToTfExample(str(RESOURCES_PATH / "data" / "guitarset"), DOWNLOAD))
>> beam.ParDo(GuitarSetToTfExample(str(RESOURCES_PATH / "data" / "guitarset"), download=False))
| "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(tmpdir))
)

