@@ -110,7 +110,7 @@ class StagingOptions:
         use_async_staging (bool): Enable asynchronous staging using a
             background thread pool. Allows overlapping computation with
             staging operations. Requires CUDA. Default: True
-        use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory
+        use_non_blocking_copy (bool): Use non-blocking CUDA memory
             copies with stream synchronization. Improves performance by
             allowing CPU work to continue during GPU transfers. Default: True

@@ -121,7 +121,7 @@ class StagingOptions:
     use_pinned_memory: bool = True
     use_shared_memory: bool = True
     use_async_staging: bool = True
-    use_cuda_non_blocking_copy: bool = True
+    use_non_blocking_copy: bool = True


 class DefaultStager(AsyncStager):
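The rename is visible at call sites: code that previously set `use_cuda_non_blocking_copy` now sets `use_non_blocking_copy`. A minimal usage sketch, assuming `StagingOptions` and `DefaultStager` are importable from `torch.distributed.checkpoint.staging` (the file this diff touches) and that `DefaultStager` takes the options object as its first argument, as the `self._config` usage below suggests:

```python
from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions

# Disable the accelerator-dependent paths on a CPU-only host; both flags
# default to True and require a CUDA/XPU device.
options = StagingOptions(
    use_async_staging=False,
    use_non_blocking_copy=False,
)
stager = DefaultStager(options)
```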
@@ -177,15 +177,17 @@ def __init__(
         self._staging_stream = None
         if self._config.use_async_staging:
             self._staging_executor = ThreadPoolExecutor(max_workers=1)
-            if torch.cuda.is_available():
+            if torch.accelerator.is_available():
                 # Note: stream needs to be initialized on the main thread after default cuda
                 # stream is setup/used to avoid the risk of accidentally reusing the main
                 # compute stream or in other cases kernels actually launching from the
                 # main thread.
-                self._staging_stream = torch.cuda.Stream()
+                self._staging_stream = torch.Stream()

-        if self._config.use_cuda_non_blocking_copy:
-            assert torch.cuda.is_available(), "Non-blocking copy requires CUDA"
+        if self._config.use_non_blocking_copy:
+            assert torch.accelerator.is_available(), (
+                "Non-blocking copy requires CUDA/XPU"
+            )

         self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None

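The substance of this hunk is the swap from CUDA-specific calls (`torch.cuda.is_available()`, `torch.cuda.Stream()`) to the device-agnostic accelerator API, which dispatches to whichever backend (CUDA, XPU, ...) is active. A rough sketch of the pattern, assuming a recent PyTorch build that ships `torch.accelerator`:

```python
import torch

# Device-agnostic variant of the old torch.cuda-only pattern: the same code
# path now works on any available accelerator backend.
if torch.accelerator.is_available():
    stream = torch.Stream()  # side stream on the current accelerator
    with stream:
        # Work launched here is enqueued on `stream`, not the default stream.
        t = torch.ones(4, device=torch.accelerator.current_accelerator())
    stream.synchronize()  # block until the enqueued work completes
```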
@@ -216,20 +218,20 @@ def stage(
             return self._stage(state_dict, **kwargs)

     def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE:
-        if self._config.use_cuda_non_blocking_copy:
+        if self._config.use_non_blocking_copy:
             assert self._staging_stream or not self._config.use_async_staging, (
-                "Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized."
+                "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
             )
             with (
                 self._staging_stream
                 if self._staging_stream is not None
                 else nullcontext()
             ):
                 state_dict = self._state_dict_stager.stage(
-                    state_dict, non_blocking=self._config.use_cuda_non_blocking_copy
+                    state_dict, non_blocking=self._config.use_non_blocking_copy
                 )
             # waits for the enqueued copy operations to finish.
-            self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize()
+            self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize()
         else:
             state_dict = self._state_dict_stager.stage(state_dict, non_blocking=False)
         return state_dict
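`_stage` relies on the usual non-blocking copy contract: a `copy_(..., non_blocking=True)` from device memory to pinned host memory returns before the transfer finishes, so the staging stream (or the whole accelerator, when no side stream exists) must be synchronized before the CPU copy is read. A standalone sketch of that contract; `src` and `dst` are illustrative names, not part of the PR:

```python
import torch

# Hypothetical demonstration of the copy/synchronize contract _stage depends on.
if torch.accelerator.is_available():
    src = torch.randn(1024, device=torch.accelerator.current_accelerator())
    dst = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)

    copy_stream = torch.Stream()
    with copy_stream:
        dst.copy_(src, non_blocking=True)  # enqueued, may not be done yet
    copy_stream.synchronize()  # after this, dst is safe to read on the CPU
```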