added an entrypoint for downloading datasets, figured out how to pass unknown args / pipeline args to the pipeline along with keyword args, added a README.md in the data folder.
bgenchel-avail committed Jun 5, 2024
1 parent 4225ade commit 1c58725
Showing 4 changed files with 16 additions and 2 deletions.
11 changes: 11 additions & 0 deletions basic_pitch/data/README.md
@@ -0,0 +1,11 @@
# Data / Training
The code and scripts in this section let you train Basic Pitch on your own. Scripts in the `datasets` folder download and process a selection of the datasets used to train the original model. Each download script accepts the following keyword arguments:
* **--source**: Source directory to download raw data to. It defaults to `$HOME/mir_datasets/{dataset_name}`.
* **--destination**: Directory to write processed data to. It defaults to `$HOME/data/basic_pitch/{dataset_name}`.
* **--runner**: The method used to run the Beam pipeline that processes the dataset. Options are `DirectRunner`, which runs the pipeline in the current process; `PortableRunner`, which runs the pipeline in a local Docker container; and `DataflowRunner`, which runs the pipeline in a Docker container on Dataflow.
* **--timestamped**: If passed, the dataset is written to a timestamped directory instead of `splits`.
* **--batch-size**: Number of examples per tfrecord when partitioning the dataset.
* **--sdk_container_image**: The Docker container image used to process the data when using `PortableRunner` or `DataflowRunner`.
* **--job_endpoint**: The endpoint the pipeline job is submitted to. It defaults to `embed`, which works for `PortableRunner`.

Additional arguments that Beam understands can be passed as well; they are forwarded to the pipeline and used there. If you use `DataflowRunner`, you are also required to pass `--temp_location={Path to GCS Bucket}`, `--staging_location={Path to GCS Bucket}`, `--project={Name of GCS Project}`, and `--region={GCS region}`.
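
A rough sketch of the argument-handling pattern these scripts rely on: `argparse.parse_known_args` keeps the documented keyword arguments and leaves every unrecognized flag in a second list that can be handed to Beam. The defaults, bucket name, and project name below are illustrative only and are not taken from the actual basic_pitch scripts.

```python
import argparse

# Illustrative parser mirroring the flags documented above; defaults are made up.
parser = argparse.ArgumentParser(description="Download and process one dataset.")
parser.add_argument("--source", default="~/mir_datasets/guitarset")
parser.add_argument("--destination", default="~/data/basic_pitch/guitarset")
parser.add_argument("--runner", default="DirectRunner")
parser.add_argument("--timestamped", action="store_true")
parser.add_argument("--batch-size", type=int, default=100)

# parse_known_args keeps unrecognized flags instead of raising an error,
# so Beam/Dataflow options can ride along and be forwarded to the pipeline.
known_args, pipeline_args = parser.parse_known_args([
    "--runner", "DataflowRunner",
    "--temp_location", "gs://example-bucket/tmp",
    "--staging_location", "gs://example-bucket/staging",
    "--project", "example-project",
    "--region", "us-east1",
])
print(known_args.runner)  # DataflowRunner
print(pipeline_args)      # ['--temp_location', 'gs://example-bucket/tmp', ...]
```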
2 changes: 1 addition & 1 deletion basic_pitch/data/datasets/guitarset.py
@@ -169,6 +169,7 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
     }
     pipeline.run(
         pipeline_options,
+        pipeline_args,
         input_data,
         GuitarSetToTfExample(known_args.source, download=True),
         GuitarSetInvalidTracks(),
@@ -182,6 +183,5 @@ def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
     commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
     commandline.add_split(parser)
     known_args, pipeline_args = parser.parse_known_args()
-    print(pipeline_args)
 
     main(known_args, pipeline_args)
4 changes: 3 additions & 1 deletion basic_pitch/data/pipeline.py
@@ -86,12 +86,14 @@ def transcription_dataset_writer(
 
 def run(
     pipeline_options: Dict[str, str],
+    pipeline_args: List[str],
     input_data: List[Tuple[str, str]],
     to_tf_example: beam.DoFn,
     filter_invalid_tracks: beam.DoFn,
     destination: str,
     batch_size: int,
 ) -> None:
     logging.info(f"pipeline_options = {pipeline_options}")
-    with beam.Pipeline(options=PipelineOptions.from_dictionary(pipeline_options)) as p:
+    logging.info(f"pipeline_args = {pipeline_args}")
+    with beam.Pipeline(options=PipelineOptions(flags=pipeline_args, **pipeline_options)) as p:
         transcription_dataset_writer(p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size)
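
The substantive change is that `PipelineOptions` is now built from both the keyword dictionary and the leftover command-line flags, instead of from the dictionary alone. A minimal sketch of how those two inputs combine, assuming `apache-beam` is installed; the option values are made up for illustration:

```python
from apache_beam.options.pipeline_options import PipelineOptions

# Dict-style options (previously fed to PipelineOptions.from_dictionary)
pipeline_options = {"runner": "DirectRunner", "job_name": "guitarset-demo"}
# Flag-style args left over from parse_known_args on the command line
pipeline_args = ["--num_workers=2", "--experiments=use_runner_v2"]

# The flags are parsed by PipelineOptions itself; the dict entries are
# applied on top as keyword overrides.
options = PipelineOptions(flags=pipeline_args, **pipeline_options)
print(options.get_all_options(drop_default=True))
```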
1 change: 1 addition & 0 deletions pyproject.toml
@@ -49,6 +49,7 @@ namespaces = false
 
 [project.scripts]
 basic-pitch = "basic_pitch.predict:main"
+bp-download = "basic_pitch.data.download:main"
 
 [project.optional-dependencies]
 train = [
