Skip to content

Commit dab36d2

Browse files
Fix MPS get_device (#2486)
1 parent d5d12fe commit dab36d2

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchtune/utils/_device.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def _setup_device(device: torch.device) -> torch.device:
6868
6969
Raises:
7070
RuntimeError: If device index is not available.
71+
AttributeError: If ``set_device`` is not supported for the device type (e.g. on MPS).
7172
7273
Returns:
7374
device
@@ -86,6 +87,10 @@ def _setup_device(device: torch.device) -> torch.device:
8687
raise RuntimeError(
8788
f"The local rank is larger than the number of available {device_name}s."
8889
)
90+
if not hasattr(torch_device, "set_device"):
91+
raise AttributeError(
92+
f"The device type {device_type} does not support the `set_device` method."
93+
)
8994
torch_device.set_device(device)
9095
return device
9196

@@ -166,7 +171,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
166171
if device is None:
167172
device = _get_device_type_from_env()
168173
device = torch.device(device)
169-
if device.type in ["cuda", "npu", "xpu", "mps"]:
174+
if device.type in ["cuda", "npu", "xpu"]:
170175
device = _setup_device(device)
171176
_validate_device_from_env(device)
172177
return device

0 commit comments

Comments
 (0)