Skip to content

Commit

Permalink
redo medley_db, ikala, and guitarset create_input_data functions to have more stable dataset division.
Browse files Browse the repository at this point in the history
  • Loading branch information
bgenchel-avail committed Jul 26, 2024
1 parent 7d78385 commit 1403b5f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 28 deletions.
14 changes: 8 additions & 6 deletions basic_pitch/data/datasets/guitarset.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,19 @@ def create_input_data(
if seed:
random.seed(seed)

def determine_split() -> str:
partition = random.uniform(0, 1)
if partition < validation_bound:
def determine_split(index: int) -> str:
if index < len(track_ids) * validation_bound:
return "train"
if partition < test_bound:
elif index < len(track_ids) * test_bound:
return "validation"
return "test"
else:
return "test"

guitarset = mirdata.initialize("guitarset")
track_ids = guitarset.track_ids
random.shuffle(track_ids)

return [(track_id, determine_split()) for track_id in guitarset.track_ids]
return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
Expand Down
16 changes: 6 additions & 10 deletions basic_pitch/data/datasets/ikala.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,17 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
assert train_percent < 1.0, "Don't over allocate the data!"

# Test percent is 1 - train - validation
validation_bound = train_percent

if seed:
random.seed(seed)

def determine_split() -> str:
partition = random.uniform(0, 1)
if partition < validation_bound:
return "train"
return "validation"

ikala = mirdata.initialize("ikala")
track_ids = ikala.track_ids
random.shuffle(track_ids)

def determine_split(index: int) -> str:
return "train" if index < len(track_ids) * train_percent else "validation"

return [(track_id, determine_split()) for track_id in ikala.track_ids]
return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
Expand Down
17 changes: 6 additions & 11 deletions basic_pitch/data/datasets/medleydb_pitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,22 +136,17 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
assert train_percent < 1.0, "Don't over allocate the data!"

# Test percent is 1 - train - validation
validation_bound = train_percent

if seed:
random.seed(seed)

def determine_split() -> str:
partition = random.uniform(0, 1)
if partition < validation_bound:
return "train"
return "validation"

medleydb_pitch = mirdata.initialize("medleydb_pitch")
medleydb_pitch.download()
track_ids = medleydb_pitch.track_ids
random.shuffle(track_ids)

def determine_split(index: int) -> str:
return "train" if index < len(track_ids) * train_percent else "validation"

return [(track_id, determine_split()) for track_id in medleydb_pitch.track_ids]
return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_medleydb_pitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_medleydb_pitch_invalid_tracks(tmpdir: str) -> None:
def test_medleydb_create_input_data() -> None:
data = create_input_data(train_percent=0.5)
data.sort(key=lambda el: el[1]) # sort by split
tolerance = 0.05
tolerance = 0.01
for _, group in itertools.groupby(data, lambda el: el[1]):
assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)

Expand Down

0 comments on commit 1403b5f

Please sign in to comment.