Merge pull request #124 from spotify/bgenchel/data-infra-for-training

Data Infra for Training

drubinstein authored Jun 14, 2024
2 parents d28cf49 + a6634f3 commit 1127978

Showing 24 changed files with 80,369 additions and 19 deletions.
8 changes: 8 additions & 0 deletions .dockerignore
@@ -0,0 +1,8 @@
.git
.tox
**/.mypy_cache
**/.pytest_cache
**/__pycache__
**/*.wav
**/saved_models
**/basic-pitch/saved_models
12 changes: 6 additions & 6 deletions .github/workflows/tox.yml
@@ -30,14 +30,14 @@ jobs:
           python-version: ${{ matrix.py }}
       - uses: actions/checkout@v3
       - name: Install soundlibs Ubuntu
-        run: sudo apt-get update && sudo apt-get install --no-install-recommends -y --fix-missing pkg-config libsndfile1
-        if: matrix.os == 'Ubuntu'
+        run: sudo apt-get update && sudo apt-get install --no-install-recommends -y --fix-missing pkg-config libsndfile1 sox
+        if: matrix.os == 'ubuntu-latest'
       - name: Install soundlibs MacOs
-        run: brew install libsndfile llvm libomp
-        if: matrix.os == 'MacOs'
+        run: brew install libsndfile llvm libomp sox
+        if: matrix.os == 'macos-latest-xlarge'
       - name: Install soundlibs Windows
-        run: choco install libsndfile
-        if: matrix.os == 'Windows'
+        run: choco install libsndfile sox.portable
+        if: matrix.os == 'windows-latest'
       - name: Upgrade pip
         run: python -m pip install -U pip
       - name: Install tox
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
@@ -13,6 +13,7 @@ We recommend first installing the following non-python dependencies:
 - To install on Windows, run `choco install libsndfile` using [Chocolatey](https://chocolatey.org/)
 - To install on Ubuntu, run `sudo apt-get update && sudo apt-get install --no-install-recommends -y --fix-missing pkg-config libsndfile1`
 - [ffmpeg](https://ffmpeg.org/) is a complete, cross-platform solution to record, convert and stream audio in all `basic-pitch` supported formats
+- [sox](https://sourceforge.net/projects/sox/) is a general-purpose sound processing utility used to process and transform the training data for the `basic-pitch` model.
 
 To compile a debug build of `basic-pitch` that allows using a debugger (like gdb or lldb), use the following command to build the package locally and install a symbolic link for debugging:
 ```shell
@@ -87,4 +88,4 @@ terms of the [LICENSE](https://github.com/spotify/basic-pitch/blob/main/LICENSE)
 
 # Code of Conduct
 
-Read our [Code of Conduct](CODE_OF_CONDUCT.md) for the project.
+Read our [Code of Conduct](CODE_OF_CONDUCT.md) for the project.
19 changes: 19 additions & 0 deletions Dockerfile
@@ -0,0 +1,19 @@
FROM apache/beam_python3.10_sdk:2.51.0

RUN --mount=type=cache,target=/var/cache/apt \
apt-get update \
&& apt-get install --no-install-recommends -y --fix-missing \
sox \
libsndfile1 \
libsox-fmt-all \
ffmpeg \
libhdf5-dev \
&& rm -rf /var/lib/apt/lists/*

COPY . /basic-pitch
WORKDIR basic-pitch
RUN --mount=type=cache,target=/root/.cache \
pip3 install --upgrade pip && \
pip3 install --upgrade setuptools wheel && \
pip3 install -e '.[train]'

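The resulting image can serve as the SDK container for the Beam pipelines under `basic_pitch/data` (see `--sdk_container_image` in the README below). A minimal sketch of building and publishing it; the registry path and tag are illustrative:

```shell
# Build the training-data image from the repository root, then push it
# to a registry your Beam workers can pull from (example path only).
docker build -t gcr.io/my-project/basic-pitch-data:latest .
docker push gcr.io/my-project/basic-pitch-data:latest
```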
5 changes: 3 additions & 2 deletions MANIFEST.in
@@ -1,5 +1,6 @@
 include *.txt tox.ini *.rst *.md LICENSE
 include catalog-info.yaml
-recursive-include tests *.py *.wav *.npz
-recursive-include basic_pitch *.py
+include Dockerfile .dockerignore
+recursive-include tests *.py *.wav *.npz *.jams *.zip
+recursive-include basic_pitch *.py *.md
 recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
11 changes: 11 additions & 0 deletions basic_pitch/data/README.md
@@ -0,0 +1,11 @@
# Data / Training
The code and scripts in this section support training Basic Pitch on your own. Scripts in the `datasets` folder let you download and process a selection of the datasets used to train the original model. Each of these download scripts takes the following keyword arguments:
* **--source**: Source directory to download raw data to. Defaults to `$HOME/mir_datasets/{dataset_name}`.
* **--destination**: Directory to write processed data to. Defaults to `$HOME/data/basic_pitch/{dataset_name}`.
* **--runner**: The method used to run the Beam pipeline that processes the dataset. Options are `DirectRunner`, which runs the pipeline in the current process; `PortableRunner`, which can run the pipeline in a local Docker container; and `DataflowRunner`, which runs the pipeline in Docker containers on GCP Dataflow.
* **--timestamped**: If passed, the dataset will be written to a timestamped directory instead of `splits`.
* **--batch-size**: Number of examples per tfrecord when partitioning the dataset.
* **--sdk_container_image**: The Docker container image used to process the data when using `PortableRunner` or `DataflowRunner`.
* **--job_endpoint**: The endpoint where the job is running. Defaults to `embed`, which works for `PortableRunner`.

Additional arguments supported by Beam in general can be passed as well; they are forwarded to the pipeline. If using `DataflowRunner`, you are required to pass `--temp_location={Path to GCS Bucket}`, `--staging_location={Path to GCS Bucket}`, `--project={Name of GCP Project}` and `--region={GCP region}`.
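For example, a local GuitarSet run with the defaults spelled out might look like this (a sketch; the paths are illustrative):

```shell
python basic_pitch/data/datasets/guitarset.py \
    --source "$HOME/mir_datasets/guitarset" \
    --destination "$HOME/data/basic_pitch/guitarset" \
    --runner DirectRunner \
    --batch-size 5
```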
Empty file added basic_pitch/data/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions basic_pitch/data/commandline.py
@@ -0,0 +1,89 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from pathlib import Path
from typing import Optional


def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None:
default_source = str(Path.home() / "mir_datasets" / dataset_name)
default_destination = str(Path.home() / "data" / "basic_pitch" / dataset_name)
parser.add_argument(
"--source",
default=default_source,
type=str,
help=f"Source directory for mir data. Defaults to {default_source}",
)
parser.add_argument(
"--destination",
default=default_destination,
type=str,
help=f"Output directory to write results to. Defaults to {default_destination}",
)
parser.add_argument(
"--runner",
choices=["DataflowRunner", "DirectRunner", "PortableRunner"],
default="DirectRunner",
help="Whether to run the download and process locally or on GCP Dataflow",
)
parser.add_argument(
"--timestamped",
default=False,
action="store_true",
help="If passed, the dataset will be put into a timestamp directory instead of 'splits'",
)
parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
parser.add_argument(
"--sdk_container_image",
default="",
help="Container image to run dataset generation job with. \
Required due to non-python dependencies.",
)
parser.add_argument("--job_endpoint", default="embed", help="")


def resolve_destination(namespace: argparse.Namespace, time_created: int) -> str:
return os.path.join(namespace.destination, str(time_created) if namespace.timestamped else "splits")


def add_split(
parser: argparse.ArgumentParser,
train_percent: float = 0.8,
validation_percent: float = 0.1,
split_seed: Optional[int] = None,
) -> None:
parser.add_argument(
"--train-percent",
type=float,
default=train_percent,
help="Percentage of tracks to mark as train",
)
parser.add_argument(
"--validation-percent",
type=float,
default=validation_percent,
help="Percentage of tracks to mark as validation",
)
parser.add_argument(
"--split-seed",
type=int,
default=split_seed,
help="Seed for random number generator used in split generation",
)
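A quick sketch of how these helpers compose; the argument list passed to `parse_known_args` is illustrative:

```python
import argparse
import time

from basic_pitch.data import commandline

parser = argparse.ArgumentParser()
commandline.add_default(parser, dataset_name="guitarset")
commandline.add_split(parser)

# Flags the parser does not recognize (e.g. Beam options) end up in pipeline_args.
known_args, pipeline_args = parser.parse_known_args(["--timestamped", "--batch-size", "10"])

# With --timestamped, output lands in a unix-timestamp directory; otherwise under ".../splits".
print(commandline.resolve_destination(known_args, int(time.time())))
```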
Empty file added basic_pitch/data/datasets/__init__.py
Empty file.
187 changes: 187 additions & 0 deletions basic_pitch/data/datasets/guitarset.py
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import random
import time

from typing import Any, List, Dict, Tuple, Optional

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline


class GuitarSetInvalidTracks(beam.DoFn):
def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
track_id, split = element
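# No tracks are filtered here; each track is re-emitted, tagged with its split.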
yield beam.pvalue.TaggedOutput(split, track_id)


class GuitarSetToTfExample(beam.DoFn):
DOWNLOAD_ATTRIBUTES = ["audio_mic_path", "jams_path"]

def __init__(self, source: str, download: bool) -> None:
self.source = source
self.download = download

def setup(self) -> None:
import apache_beam as beam
import mirdata

self.guitarset_remote = mirdata.initialize("guitarset", data_home=self.source)
self.filesystem = beam.io.filesystems.FileSystems() # TODO: replace with fsspec
if self.download:
self.guitarset_remote.download()

def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
import tempfile

import mirdata
import numpy as np
import sox

from basic_pitch.constants import (
AUDIO_N_CHANNELS,
AUDIO_SAMPLE_RATE,
FREQ_BINS_CONTOURS,
FREQ_BINS_NOTES,
ANNOTATION_HOP,
N_FREQ_BINS_NOTES,
N_FREQ_BINS_CONTOURS,
)
from basic_pitch.data import tf_example_serialization

logging.info(f"Processing {element}")
batch = []

for track_id in element:
track_remote = self.guitarset_remote.track(track_id)
with tempfile.TemporaryDirectory() as local_tmp_dir:
guitarset_local = mirdata.initialize("guitarset", local_tmp_dir)
track_local = guitarset_local.track(track_id)

for attribute in self.DOWNLOAD_ATTRIBUTES:
source = getattr(track_remote, attribute)
destination = getattr(track_local, attribute)
os.makedirs(os.path.dirname(destination), exist_ok=True)
with self.filesystem.open(source) as s, open(destination, "wb") as d:
d.write(s.read())

local_wav_path = f"{track_local.audio_mic_path}_tmp.wav"

tfm = sox.Transformer()
tfm.rate(AUDIO_SAMPLE_RATE)
tfm.channels(AUDIO_N_CHANNELS)
tfm.build(track_local.audio_mic_path, local_wav_path)

duration = sox.file_info.duration(local_wav_path)
time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
n_time_frames = len(time_scale)

note_indices, note_values = track_local.notes_all.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz"
)
onset_indices, onset_values = track_local.notes_all.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
)
contour_indices, contour_values = track_local.multif0.to_sparse_index(
time_scale, "s", FREQ_BINS_CONTOURS, "hz"
)

batch.append(
tf_example_serialization.to_transcription_tfexample(
track_local.track_id,
"guitarset",
local_wav_path,
note_indices,
note_values,
onset_indices,
onset_values,
contour_indices,
contour_values,
(n_time_frames, N_FREQ_BINS_NOTES),
(n_time_frames, N_FREQ_BINS_CONTOURS),
)
)
return [batch]


def create_input_data(
train_percent: float, validation_percent: float, seed: Optional[int] = None
) -> List[Tuple[str, str]]:
assert train_percent + validation_percent < 1.0, "Don't over allocate the data!"

# Test percent is 1 - train - validation
validation_bound = train_percent
test_bound = validation_bound + validation_percent

if seed is not None:  # treat seed=0 as a valid seed
random.seed(seed)

def determine_split() -> str:
partition = random.uniform(0, 1)
if partition < validation_bound:
return "train"
if partition < test_bound:
return "validation"
return "test"

guitarset = mirdata.initialize("guitarset")

return [(track_id, determine_split()) for track_id in guitarset.track_ids]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, time_created)
input_data = create_input_data(known_args.train_percent, known_args.validation_percent, known_args.split_seed)

pipeline_options = {
"runner": known_args.runner,
"job_name": f"guitarset-tfrecords-{time_created}",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
"experiments": ["use_runner_v2"],
"save_main_session": True,
"sdk_container_image": known_args.sdk_container_image,
"job_endpoint": known_args.job_endpoint,
"environment_type": "DOCKER",
"environment_config": known_args.sdk_container_image,
}
pipeline.run(
pipeline_options,
pipeline_args,
input_data,
GuitarSetToTfExample(known_args.source, download=True),
GuitarSetInvalidTracks(),
destination,
known_args.batch_size,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args()

main(known_args, pipeline_args)
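As a sanity check on `create_input_data`, the split proportions can be tallied directly (a sketch; `mirdata.initialize` only needs GuitarSet's index, not the downloaded audio):

```python
from collections import Counter

from basic_pitch.data.datasets.guitarset import create_input_data

# Assign each track to a split with the defaults used by add_split (0.8 / 0.1 / 0.1).
pairs = create_input_data(train_percent=0.8, validation_percent=0.1, seed=42)
print(Counter(split for _, split in pairs))
# GuitarSet has 360 tracks, so expect roughly 288 train / 36 validation / 36 test.
```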
1 comment on commit 1127978

@Bituvo commented on 1127978 Jun 15, 2024

one of the commits ever... JAM