Skip to content

Commit ce680a2

Browse files
committed
enhance 3d-party devices in mix-precision
1 parent 8ad3e29 commit ce680a2

File tree

16 files changed

+90
-17
lines changed

16 files changed

+90
-17
lines changed

docs/source-pytorch/extensions/accelerator.rst

+37-2
Original file line numberDiff line numberDiff line change
@@ -36,45 +36,79 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc
3636

3737
.. code-block:: python
3838
39+
import torch
3940
import xpulib
4041
42+
from functools import lru_cache
43+
from typing import Any, Dict, Union
44+
from lightning.pytorch.accelerators.accelerator import Accelerator
45+
46+
from typing_extensions import override
47+
4148
4249
class XPUAccelerator(Accelerator):
4350
"""Support for a hypothetical XPU, optimized for large-scale machine learning."""
4451
52+
@override
53+
def setup_device(self, device: torch.device) -> None:
54+
"""
55+
Raises:
56+
ValueError:
57+
If the selected device is not of type hypothetical XPU.
58+
"""
59+
if device.type != "xpu":
60+
raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.")
61+
if device.index is None:
62+
device = torch.device("xpu", 0)
63+
xpulib.set_device(device.index)
64+
65+
@override
66+
def teardown(self) -> None:
67+
xpulib.empty_cache()
68+
4569
@staticmethod
70+
@override
4671
def parse_devices(devices: Any) -> Any:
4772
# Put parsing logic here how devices can be passed into the Trainer
4873
# via the `devices` argument
4974
return devices
5075
5176
@staticmethod
77+
@override
5278
def get_parallel_devices(devices: Any) -> Any:
5379
# Here, convert the device indices to actual device objects
5480
return [torch.device("xpu", idx) for idx in devices]
5581
5682
@staticmethod
83+
@override
5784
def auto_device_count() -> int:
5885
# Return a value for auto-device selection when `Trainer(devices="auto")`
5986
return xpulib.available_devices()
6087
6188
@staticmethod
89+
@override
6290
def is_available() -> bool:
6391
return xpulib.is_available()
6492
6593
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
6694
# Return optional device statistics for loggers
6795
return {}
6896
97+
@staticmethod
98+
@override
99+
def get_device() -> str:
100+
return "xpu"
101+
69102
70103
Finally, add the XPUAccelerator to the Trainer:
71104

72105
.. code-block:: python
73106
74107
from lightning.pytorch import Trainer
75-
108+
from lightning.pytorch.strategies import DDPStrategy
76109
accelerator = XPUAccelerator()
77-
trainer = Trainer(accelerator=accelerator, devices=2)
110+
strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2))
111+
trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2)
78112
79113
80114
:doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator.
@@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
93127
...
94128
95129
@classmethod
130+
@override
96131
def register_accelerators(cls, accelerator_registry):
97132
accelerator_registry.register(
98133
"xpu",

src/lightning/fabric/connector.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def __init__(
141141
self._accelerator_flag = self._choose_auto_accelerator()
142142
elif self._accelerator_flag == "gpu":
143143
self._accelerator_flag = self._choose_gpu_accelerator_backend()
144+
elif isinstance(self._accelerator_flag, Accelerator):
145+
pass # for 3rd party accelerator, just do nothing
144146

145147
self._set_parallel_devices_and_init_accelerator()
146148

@@ -461,7 +463,7 @@ def _check_and_init_precision(self) -> Precision:
461463
if isinstance(self.strategy, DeepSpeedStrategy):
462464
return DeepSpeedPrecision(self._precision_input) # type: ignore
463465
if isinstance(self.strategy, FSDPStrategy):
464-
return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type]
466+
return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None) # type: ignore[arg-type]
465467
mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
466468
if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
467469
raise ValueError(
@@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision:
493495
else "Using bfloat16 Automatic Mixed Precision (AMP)"
494496
)
495497
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
498+
if isinstance(self._accelerator_flag, Accelerator):
499+
device = self._accelerator_flag.get_device()
496500
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]
497501

498502
raise RuntimeError("No precision set")

src/lightning/fabric/plugins/precision/amp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050

5151
self.precision = precision
5252
if scaler is None and self.precision == "16-mixed":
53-
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
53+
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
5454
if scaler is not None and self.precision == "bf16-mixed":
5555
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
5656
self.device = device

src/lightning/fabric/plugins/precision/fsdp.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ class FSDPPrecision(Precision):
4848
4949
"""
5050

51-
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
51+
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None) -> None:
5252
supported_precision = get_args(_PRECISION_INPUT)
5353
if precision not in supported_precision:
5454
raise ValueError(
5555
f"`precision={precision!r})` is not supported in FSDP."
5656
f" `precision` must be one of: {supported_precision}."
5757
)
58+
self.device = device if device is not None else "cuda"
5859

5960
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
6061

@@ -110,7 +111,7 @@ def module_init_context(self) -> ContextManager:
110111
@override
111112
def forward_context(self) -> ContextManager:
112113
if "mixed" in self.precision:
113-
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
114+
return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
114115
return self.tensor_init_context()
115116

116117
@override

src/lightning/fabric/strategies/ddp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
124124
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
125125
device_ids = self._determine_ddp_device_ids()
126126
# https://pytorch.org/docs/stable/notes/cuda.html#id5
127-
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
127+
ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
128128
with ctx:
129129
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
130130

src/lightning/fabric/strategies/deepspeed.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def load_checkpoint(
506506

507507
optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
508508

509-
torch.cuda.empty_cache()
509+
getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
510510
_, client_state = engine.load_checkpoint(
511511
path,
512512
tag="checkpoint",

src/lightning/fabric/strategies/strategy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def load_checkpoint(
325325
given, the full checkpoint will be returned.
326326
327327
"""
328-
torch.cuda.empty_cache()
328+
getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
329329
checkpoint = self.checkpoint_io.load_checkpoint(path)
330330
if not state:
331331
return checkpoint

src/lightning/pytorch/accelerators/accelerator.py

+5
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
4545
4646
"""
4747
raise NotImplementedError
48+
49+
@staticmethod
50+
def get_device() -> str:
51+
"""Get the device for the current process."""
52+
raise NotImplementedError

src/lightning/pytorch/accelerators/cpu.py

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
8080
description=cls.__name__,
8181
)
8282

83+
@staticmethod
84+
@override
85+
def get_device() -> str:
86+
return "cpu"
87+
8388

8489
# CPU device metrics
8590
_CPU_VM_PERCENT = "cpu_vm_percent"

src/lightning/pytorch/accelerators/cuda.py

+5
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
113113
description=cls.__name__,
114114
)
115115

116+
@staticmethod
117+
@override
118+
def get_device() -> str:
119+
return "cuda"
120+
116121

117122
def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover
118123
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

src/lightning/pytorch/accelerators/mps.py

+5
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
8787
description=cls.__name__,
8888
)
8989

90+
@staticmethod
91+
@override
92+
def get_device() -> str:
93+
return "mps"
94+
9095

9196
# device metrics
9297
_VM_PERCENT = "M1_vm_percent"

src/lightning/pytorch/plugins/precision/amp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050

5151
self.precision = precision
5252
if scaler is None and self.precision == "16-mixed":
53-
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
53+
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else getattr(torch, f"{device.split(':')[0]}").amp.GradScaler()
5454
if scaler is not None and self.precision == "bf16-mixed":
5555
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
5656
self.device = device

src/lightning/pytorch/plugins/precision/fsdp.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ class FSDPPrecision(Precision):
4747
4848
"""
4949

50-
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
50+
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: str = None) -> None:
5151
supported_precision = get_args(_PRECISION_INPUT)
5252
if precision not in supported_precision:
5353
raise ValueError(
5454
f"`precision={precision!r})` is not supported in FSDP."
5555
f" `precision` must be one of: {supported_precision}."
5656
)
57+
self.device = device if device is not None else "cuda"
5758

5859
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
5960

@@ -119,7 +120,7 @@ def module_init_context(self) -> ContextManager:
119120
@override
120121
def forward_context(self) -> ContextManager:
121122
if "mixed" in self.precision:
122-
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
123+
return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
123124
return _DtypeContextManager(self._desired_input_dtype)
124125

125126
@override

src/lightning/pytorch/strategies/ddp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
190190
device_ids = self.determine_ddp_device_ids()
191191
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
192192
# https://pytorch.org/docs/stable/notes/cuda.html#id5
193-
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
193+
ctx = getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()) if device_ids is not None else nullcontext()
194194
with ctx:
195195
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
196196

src/lightning/pytorch/strategies/strategy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
363363
return self._lightning_module
364364

365365
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
366-
torch.cuda.empty_cache()
366+
getattr(torch, f"{self.root_device.type.split(':')[0]}").empty_cache() if self.root_device.type != "cpu" else None
367367
return self.checkpoint_io.load_checkpoint(checkpoint_path)
368368

369369
def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None:

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def __init__(
141141
self._accelerator_flag = self._choose_auto_accelerator()
142142
elif self._accelerator_flag == "gpu":
143143
self._accelerator_flag = self._choose_gpu_accelerator_backend()
144+
elif isinstance(self._accelerator_flag, Accelerator):
145+
pass # for 3rd party accelerator, just do nothing
144146

145147
self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
146148
self._set_parallel_devices_and_init_accelerator()
@@ -301,15 +303,18 @@ def _check_config_and_set_final_flags(
301303
f" but accelerator set to {self._accelerator_flag}, please choose one device type"
302304
)
303305
self._accelerator_flag = "cpu"
304-
if self._strategy_flag.parallel_devices[0].type == "cuda":
306+
elif self._strategy_flag.parallel_devices[0].type == "cuda":
305307
if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"):
306308
raise MisconfigurationException(
307309
f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class,"
308310
f" but accelerator set to {self._accelerator_flag}, please choose one device type"
309311
)
310312
self._accelerator_flag = "cuda"
313+
else:
314+
pass # 3rd party accelerator
311315
self._parallel_devices = self._strategy_flag.parallel_devices
312316

317+
313318
def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
314319
if not isinstance(num_nodes, int) or num_nodes < 1:
315320
raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.")
@@ -458,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:
458463

459464
if (
460465
strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
461-
) and self._accelerator_flag not in ("cuda", "gpu"):
466+
) and self._accelerator_flag not in ("cuda", "gpu") and isinstance(self._accelerator_flag, str):
462467
raise ValueError(
463468
f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
464469
f" {self._accelerator_flag}"
465470
)
471+
elif isinstance(self._accelerator_flag, Accelerator):
472+
Warning(
473+
f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
474+
f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
475+
)
466476
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
467477
raise ValueError(
468478
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
@@ -496,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
496506
if isinstance(self.strategy, DeepSpeedStrategy):
497507
return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type]
498508
if isinstance(self.strategy, FSDPStrategy):
499-
return FSDPPrecision(self._precision_flag) # type: ignore[arg-type]
509+
return FSDPPrecision(precision=self._precision_flag, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None) # type: ignore[arg-type]
500510
if self._precision_flag in ("16-true", "bf16-true"):
501511
return HalfPrecision(self._precision_flag) # type: ignore
502512
if self._precision_flag == "32-true":
@@ -520,6 +530,8 @@ def _check_and_init_precision(self) -> Precision:
520530
f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
521531
)
522532
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
533+
if isinstance(self._accelerator_flag, Accelerator):
534+
device = self._accelerator_flag.get_device()
523535
return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type]
524536

525537
raise RuntimeError("No precision set")

0 commit comments

Comments
 (0)