add option to provide audio samples for prediction #153

Open · wants to merge 3 commits into base: main (changes shown from 1 commit)
35 changes: 26 additions & 9 deletions basic_pitch/inference.py
@@ -213,7 +213,7 @@ def window_audio_file(


def get_audio_input(
    audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
    audio_path_or_array: Union[pathlib.Path, str, np.ndarray], sample_rate: Optional[int], overlap_len: int, hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
"""
Read wave file (as mono), pad appropriately, and return as
@@ -229,7 +229,17 @@ def get_audio_input(
"""
assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"

audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
if isinstance(audio_path_or_array, np.ndarray):
audio_original = audio_path_or_array
if sample_rate is None:
raise ValueError("Sample rate must be provided when input is an array of audio samples.")
elif sample_rate != AUDIO_SAMPLE_RATE:
audio_original = librosa.resample(audio_original, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
# convert to mono if necessary
if audio_original.ndim != 1:
audio_original = librosa.to_mono(audio_path_or_array)
else:
audio_original, _ = librosa.load(str(audio_path_or_array), sr=AUDIO_SAMPLE_RATE, mono=True)

    original_length = audio_original.shape[0]
    audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
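
The array branch above normalizes channel count and sample rate before any windowing happens. A minimal standalone sketch of that preprocessing, assuming basic_pitch's AUDIO_SAMPLE_RATE constant (22050 Hz in the released package) and a hypothetical 48 kHz stereo buffer:

```python
import librosa
import numpy as np

AUDIO_SAMPLE_RATE = 22050  # assumed value of basic_pitch.constants.AUDIO_SAMPLE_RATE

# hypothetical input: two seconds of stereo noise captured at 48 kHz
stereo_48k = np.random.default_rng(0).standard_normal((2, 2 * 48000)).astype(np.float32)

mono = librosa.to_mono(stereo_48k)  # average the channels -> shape (96000,)
resampled = librosa.resample(mono, orig_sr=48000, target_sr=AUDIO_SAMPLE_RATE)
print(resampled.shape)  # (44100,): two seconds at 22050 Hz
```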
@@ -267,14 +277,16 @@ def unwrap_output(


def run_inference(
    audio_path: Union[pathlib.Path, str],
    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
    sample_rate: Optional[int],
    model_or_model_path: Union[Model, pathlib.Path, str],
    debug_file: Optional[pathlib.Path] = None,
) -> Dict[str, np.array]:
"""Run the model on the input audio path.

    Args:
        audio_path: The audio to run inference on.
        audio_path_or_array: The audio to run inference on. Can be either the path to an audio file or a numpy array of audio samples.
        sample_rate: Sample rate of the audio. Required when audio_path_or_array is a numpy array; ignored when it is a path.
        model_or_model_path: A loaded Model or path to a serialized model to load.
        debug_file: An optional path to output debug data to. Useful for testing/verification.
@@ -292,7 +304,7 @@ def run_inference(
    hop_size = AUDIO_N_SAMPLES - overlap_len

    output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
    for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
    for audio_windowed, _, audio_original_length in get_audio_input(audio_path_or_array, sample_rate, overlap_len, hop_size):
        for k, v in model.predict(audio_windowed).items():
            output[k].append(v)

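With the new signature, both call forms below should be equivalent (a sketch against this PR's parameter order; `song.wav` is a placeholder path):

```python
import librosa
from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import run_inference

# path input: the file is loaded and resampled internally; sample_rate is unused
output_from_path = run_inference("song.wav", None, ICASSP_2022_MODEL_PATH)

# array input: the caller supplies the samples together with their native rate
samples, sr = librosa.load("song.wav", sr=None, mono=False)
output_from_array = run_inference(samples, sr, ICASSP_2022_MODEL_PATH)
```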
@@ -415,7 +427,8 @@ def save_note_events(


def predict(
    audio_path: Union[pathlib.Path, str],
    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
    sample_rate: Optional[int] = None,
    model_or_model_path: Union[Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH,
    onset_threshold: float = 0.5,
    frame_threshold: float = 0.3,
@@ -434,7 +447,8 @@ def predict(
"""Run a single prediction.

Args:
audio_path: File path for the audio to run inference on.
audio_path_or_array: File path for the audio to run inference on or array of audio samples.
sample_rate: Sample rate of the audio file. Only used if audio_path_or_array is a np array.
model_or_model_path: A loaded Model or path to a serialized model to load.
onset_threshold: Minimum energy required for an onset to be considered present.
frame_threshold: Minimum energy requirement for a frame to be considered present.
@@ -449,9 +463,12 @@
"""

with no_tf_warnings():
print(f"Predicting MIDI for {audio_path}...")
if isinstance(audio_path_or_array, np.ndarray):
print("Predicting MIDI ...")
else:
print(f"Predicting MIDI for {audio_path_or_array}...")

        model_output = run_inference(audio_path, model_or_model_path, debug_file)
        model_output = run_inference(audio_path_or_array, sample_rate, model_or_model_path, debug_file)
        min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
        midi_data, note_events = infer.model_output_to_notes(
            model_output,
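
End to end, the new parameter lets callers feed audio that never touches disk. A usage sketch of the updated predict (the soundfile loader and the three-tuple return are assumptions based on basic-pitch's documented API; `vocals.wav` is a placeholder):

```python
import soundfile as sf  # any loader that yields a numpy array works
from basic_pitch.inference import predict

# existing path-based call, unchanged by this PR
model_output, midi_data, note_events = predict("vocals.wav")

# new array-based call: raw samples plus their sample rate
samples, sr = sf.read("vocals.wav", dtype="float32")  # shape: (frames, channels)
model_output, midi_data, note_events = predict(samples.T, sample_rate=sr)
```

Transposing to (channels, frames) matches what librosa.to_mono expects in the branch added above.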