Skip to content

Commit

Permalink
Merge pull request #7 from SCAuFish/main
Browse files Browse the repository at this point in the history
Support Runtime on Apple Silicon
  • Loading branch information
luke-carlson authored Oct 28, 2024
2 parents f9e87fc + ada9e77 commit 32d5b01
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion ml_mdm/clis/generate_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from ml_mdm.config import get_arguments, get_model, get_pipeline
from ml_mdm.language_models import factory

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)

# Note that it is called add_arguments, not add_argument.
logging.basicConfig(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"imageio[ffmpeg]",
"matplotlib",
"mlx-data",
"numpy",
"numpy<2",
"pytorch-model-summary",
"rotary-embedding-torch",
"simple-parsing==0.1.5",
Expand Down

0 comments on commit 32d5b01

Please sign in to comment.