@@ -141,6 +141,8 @@ def __init__(
141
141
self ._accelerator_flag = self ._choose_auto_accelerator ()
142
142
elif self ._accelerator_flag == "gpu" :
143
143
self ._accelerator_flag = self ._choose_gpu_accelerator_backend ()
144
+ elif isinstance (self ._accelerator_flag , Accelerator ):
145
+ pass # for 3rd party accelerator, just do nothing
144
146
145
147
self ._check_device_config_and_set_final_flags (devices = devices , num_nodes = num_nodes )
146
148
self ._set_parallel_devices_and_init_accelerator ()
@@ -301,15 +303,18 @@ def _check_config_and_set_final_flags(
301
303
f" but accelerator set to { self ._accelerator_flag } , please choose one device type"
302
304
)
303
305
self ._accelerator_flag = "cpu"
304
- if self ._strategy_flag .parallel_devices [0 ].type == "cuda" :
306
+ elif self ._strategy_flag .parallel_devices [0 ].type == "cuda" :
305
307
if self ._accelerator_flag and self ._accelerator_flag not in ("auto" , "cuda" , "gpu" ):
306
308
raise MisconfigurationException (
307
309
f"GPU parallel_devices set through { self ._strategy_flag .__class__ .__name__ } class,"
308
310
f" but accelerator set to { self ._accelerator_flag } , please choose one device type"
309
311
)
310
312
self ._accelerator_flag = "cuda"
313
+ else :
314
+ pass # 3rd party accelerator
311
315
self ._parallel_devices = self ._strategy_flag .parallel_devices
312
316
317
+
313
318
def _check_device_config_and_set_final_flags (self , devices : Union [List [int ], str , int ], num_nodes : int ) -> None :
314
319
if not isinstance (num_nodes , int ) or num_nodes < 1 :
315
320
raise ValueError (f"`num_nodes` must be a positive integer, but got { num_nodes } ." )
@@ -458,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:
458
463
459
464
if (
460
465
strategy_flag in FSDPStrategy .get_registered_strategies () or type (self ._strategy_flag ) is FSDPStrategy
461
- ) and self ._accelerator_flag not in ("cuda" , "gpu" ):
466
+ ) and self ._accelerator_flag not in ("cuda" , "gpu" ) and isinstance ( self . _accelerator_flag , str ) :
462
467
raise ValueError (
463
468
f"The strategy `{ FSDPStrategy .strategy_name } ` requires a GPU accelerator, but got:"
464
469
f" { self ._accelerator_flag } "
465
470
)
471
+ elif isinstance (self ._accelerator_flag , Accelerator ):
472
+ Warning (
473
+ f"Using a custom accelerator `{ self ._accelerator_flag .__class__ .__name__ } `."
474
+ f" Please ensure it is compatible with the selected strategy `{ strategy_flag } `."
475
+ )
466
476
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch .multiprocessing .get_all_start_methods ():
467
477
raise ValueError (
468
478
f"You selected `Trainer(strategy='{ strategy_flag } ')` but process forking is not supported on this"
@@ -496,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
496
506
if isinstance (self .strategy , DeepSpeedStrategy ):
497
507
return DeepSpeedPrecision (self ._precision_flag ) # type: ignore[arg-type]
498
508
if isinstance (self .strategy , FSDPStrategy ):
499
- return FSDPPrecision (self ._precision_flag ) # type: ignore[arg-type]
509
+ return FSDPPrecision (precision = self ._precision_flag , device = self . _accelerator_flag . get_device () if isinstance ( self . _accelerator_flag , Accelerator ) else None ) # type: ignore[arg-type]
500
510
if self ._precision_flag in ("16-true" , "bf16-true" ):
501
511
return HalfPrecision (self ._precision_flag ) # type: ignore
502
512
if self ._precision_flag == "32-true" :
@@ -520,6 +530,8 @@ def _check_and_init_precision(self) -> Precision:
520
530
f"Using { '16bit' if self ._precision_flag == '16-mixed' else 'bfloat16' } Automatic Mixed Precision (AMP)"
521
531
)
522
532
device = "cpu" if self ._accelerator_flag == "cpu" else "cuda"
533
+ if isinstance (self ._accelerator_flag , Accelerator ):
534
+ device = self ._accelerator_flag .get_device ()
523
535
return MixedPrecision (self ._precision_flag , device ) # type: ignore[arg-type]
524
536
525
537
raise RuntimeError ("No precision set" )
0 commit comments