diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py index 2544cd9..25b062c 100644 --- a/basic_pitch/inference.py +++ b/basic_pitch/inference.py @@ -129,7 +129,10 @@ def __init__(self, model_path: Union[pathlib.Path, str]): present.append("ONNX") try: self.model_type = Model.MODEL_TYPES.ONNX - self.model = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + providers = ["CPUExecutionProvider"] + if "CUDAExecutionProvider" in ort.get_available_providers(): + providers.insert(0, "CUDAExecutionProvider") + self.model = ort.InferenceSession(str(model_path), providers=providers) return except Exception as e: if str(model_path).endswith(".onnx"):