@@ -108,7 +108,7 @@ def test_different_option_combinations(self) -> None:
             # Only async staging
             test_cases.append(
                 CheckpointStagerConfig(
-                    use_pinned_memory=torch.cuda.is_available(),
+                    use_pinned_memory=torch.accelerator.is_available(),
                     use_shared_memory=False,
                     use_async_staging=True,
                     use_non_blocking_copy=False,
@@ -117,7 +117,7 @@ def test_different_option_combinations(self) -> None:
             # Only CUDA non-blocking copy
             test_cases.append(
                 CheckpointStagerConfig(
-                    use_pinned_memory=torch.cuda.is_available(),
+                    use_pinned_memory=torch.accelerator.is_available(),
                     use_shared_memory=False,
                     use_async_staging=False,
                     use_non_blocking_copy=torch.accelerator.is_available(),
@@ -129,7 +129,7 @@ def test_different_option_combinations(self) -> None:
                 stager = DefaultStager(options)

                 # Test staging works with these options
-                if options.use_async_staging and torch.cuda.is_available():
+                if options.use_async_staging and torch.accelerator.is_available():
                     result = stager.stage(self.state_dict)
                     self.assertIsInstance(result, Future)
                     staged_dict = result.result()
@@ -183,7 +183,7 @@ def test_multiple_staging_operations(self) -> None:
         """Test multiple staging operations with the same stager."""
         options = CheckpointStagerConfig(
             use_async_staging=False,
-            use_pinned_memory=torch.cuda.is_available(),
+            use_pinned_memory=torch.accelerator.is_available(),
             use_shared_memory=False,
             use_non_blocking_copy=torch.accelerator.is_available(),
         )
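
A minimal sketch (not part of the diff) of the device-agnostic check the change switches to, assuming a PyTorch version that ships the `torch.accelerator` namespace (2.6+). `torch.accelerator.is_available()` reports True for any supported accelerator backend rather than CUDA only, which is why it replaces `torch.cuda.is_available()` in the config values above; the dict below only mirrors the `CheckpointStagerConfig` fields for illustration.

```python
import torch

# Device-agnostic availability check: True for CUDA, XPU, MPS, etc.,
# whereas torch.cuda.is_available() is CUDA-specific.
accel_ok = torch.accelerator.is_available()

# Illustrative staging options keyed off the accelerator check, mirroring the
# fields used by CheckpointStagerConfig in the tests (values here are an
# assumption, not the tests' required settings).
staging_options = dict(
    use_pinned_memory=accel_ok,      # pinned host memory only helps when a device exists
    use_shared_memory=False,
    use_async_staging=False,
    use_non_blocking_copy=accel_ok,  # non-blocking copies require an accelerator device
)
print(staging_options)
```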