File tree 1 file changed +6
-1
lines changed
1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -68,6 +68,7 @@ def _setup_device(device: torch.device) -> torch.device:
68
68
69
69
Raises:
70
70
RuntimeError: If device index is not available.
71
+ AttributeError: If ``set_device`` is not supported for the device type (e.g. on MPS).
71
72
72
73
Returns:
73
74
device
@@ -86,6 +87,10 @@ def _setup_device(device: torch.device) -> torch.device:
86
87
raise RuntimeError (
87
88
f"The local rank is larger than the number of available { device_name } s."
88
89
)
90
+ if not hasattr (torch_device , "set_device" ):
91
+ raise AttributeError (
92
+ f"The device type { device_type } does not support the `set_device` method."
93
+ )
89
94
torch_device .set_device (device )
90
95
return device
91
96
@@ -166,7 +171,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
166
171
if device is None :
167
172
device = _get_device_type_from_env ()
168
173
device = torch .device (device )
169
- if device .type in ["cuda" , "npu" , "xpu" , "mps" ]:
174
+ if device .type in ["cuda" , "npu" , "xpu" ]:
170
175
device = _setup_device (device )
171
176
_validate_device_from_env (device )
172
177
return device
You can’t perform that action at this time.
0 commit comments