Skip to content

Commit

Permalink
use cuda with onnx if available
Browse files Browse the repository at this point in the history
  • Loading branch information
nateraw committed Oct 23, 2024
1 parent 9991303 commit b7e349d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion basic_pitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def __init__(self, model_path: Union[pathlib.Path, str]):
present.append("ONNX")
try:
self.model_type = Model.MODEL_TYPES.ONNX
self.model = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers():
providers.insert(0, "CUDAExecutionProvider")
self.model = ort.InferenceSession(str(model_path), providers=providers)
return
except Exception as e:
if str(model_path).endswith(".onnx"):
Expand Down

0 comments on commit b7e349d

Please sign in to comment.