diff --git a/stable_whisper/whisper_word_level/hf_whisper.py b/stable_whisper/whisper_word_level/hf_whisper.py index 6784881..77f768e 100644 --- a/stable_whisper/whisper_word_level/hf_whisper.py +++ b/stable_whisper/whisper_word_level/hf_whisper.py @@ -29,8 +29,8 @@ def get_device(device: str = None) -> str: return device if torch.cuda.is_available(): return 'cuda:0' - if (mps := getattr(torch.backends, 'mps', None)) is not None: - return mps.is_available() + if (mps := getattr(torch.backends, 'mps', None)) is not None and mps.is_available(): + return 'mps' return 'cpu'