@@ -108,7 +108,7 @@ def test_different_option_combinations(self) -> None:
108108 # Only async staging
109109 test_cases .append (
110110 CheckpointStagerConfig (
111- use_pinned_memory = torch .cuda .is_available (),
111+ use_pinned_memory = torch .accelerator .is_available (),
112112 use_shared_memory = False ,
113113 use_async_staging = True ,
114114 use_non_blocking_copy = False ,
@@ -117,7 +117,7 @@ def test_different_option_combinations(self) -> None:
117117 # Only CUDA non-blocking copy
118118 test_cases .append (
119119 CheckpointStagerConfig (
120- use_pinned_memory = torch .cuda .is_available (),
120+ use_pinned_memory = torch .accelerator .is_available (),
121121 use_shared_memory = False ,
122122 use_async_staging = False ,
123123 use_non_blocking_copy = torch .accelerator .is_available (),
@@ -129,7 +129,7 @@ def test_different_option_combinations(self) -> None:
129129 stager = DefaultStager (options )
130130
131131 # Test staging works with these options
132- if options .use_async_staging and torch .cuda .is_available ():
132+ if options .use_async_staging and torch .accelerator .is_available ():
133133 result = stager .stage (self .state_dict )
134134 self .assertIsInstance (result , Future )
135135 staged_dict = result .result ()
@@ -183,7 +183,7 @@ def test_multiple_staging_operations(self) -> None:
183183 """Test multiple staging operations with the same stager."""
184184 options = CheckpointStagerConfig (
185185 use_async_staging = False ,
186- use_pinned_memory = torch .cuda .is_available (),
186+ use_pinned_memory = torch .accelerator .is_available (),
187187 use_shared_memory = False ,
188188 use_non_blocking_copy = torch .accelerator .is_available (),
189189 )
0 commit comments