33import time
44
55from ftlib .consensus .consensus_status import ConsensusMode , ConsensusStatus
6- from ftlib .ftlib_status import FTAllReduceStatus , FTRebuildStatus
6+ from ftlib .ftlib_status import FTCollectiveStatus , FTRebuildStatus
77from ftlib .rank_assign_scheme import get_rank_size
88
99# in the ftlib package, user is able to initialize the package
@@ -161,7 +161,7 @@ def execute(self, func, *args, **kwargs):
161161 # TODO: we should consider retrying the rebuild process
162162 # multiple times
163163 rebuild_result = self ._rebuild ()
164- if rebuild_result != FTAllReduceStatus .SUCCESS :
164+ if rebuild_result != FTCollectiveStatus .SUCCESS :
165165 if rebuild_result == FTRebuildStatus .FAIL :
166166 raise Exception ("rebuild process returns failed" )
167167 elif rebuild_result == FTRebuildStatus .ABORT :
@@ -317,8 +317,7 @@ def func(*argc, **kwargs):
317317 # **kwargs, kwargs passed to the function
318318 # Returns:
319319 # if collective ops returns successfully, returns the value
320- # otherwise, returns FTAllReduceStatus
321- # TODO: rename FTAllReduceStatus to FTCollectiveStatus
320+ # otherwise, returns FTCollectiveStatus
322321
323322 # the api here is
324323 ops = (
@@ -330,7 +329,7 @@ def func(*argc, **kwargs):
330329 # if skip_allreduce == True, then any collective ops
331330 # shouldn't be called
332331 if self .skip_allreduce ():
333- return FTAllReduceStatus .NO_NEED
332+ return FTCollectiveStatus .NO_NEED , None
334333
335334 # if the instance is not initialized, then start rebuild
336335 # TODO: put rebuild into a try loop?
@@ -341,13 +340,13 @@ def func(*argc, **kwargs):
341340 rebuild_result = self ._rebuild ()
342341 if rebuild_result == FTRebuildStatus .ABORT :
343342 logging .warning ("rebuild process returns abort" )
344- return FTAllReduceStatus .ABORT
343+ return FTCollectiveStatus .ABORT , None
345344 if rebuild_result == FTRebuildStatus .FAIL :
346345 logging .warning ("rebuild process returns fail" )
347- return FTAllReduceStatus .ABORT
346+ return FTCollectiveStatus .ABORT , None
348347 if rebuild_result == FTRebuildStatus .SKIP_ALLREDUCE :
349348 logging .warning ("rebuild process returns skip allreduce" )
350- return FTAllReduceStatus .NO_NEED
349+ return FTCollectiveStatus .NO_NEED , None
351350
352351 try :
353352 self .lock ()
@@ -366,10 +365,10 @@ def func(*argc, **kwargs):
366365 except Exception as e :
367366 logging .exception (str (e ))
368367 self .set_initialized (False )
369- return FTAllReduceStatus .FAIL
368+ return FTCollectiveStatus .FAIL , None
370369 else :
371370 self .consensus .average_success ()
372- return result
371+ return FTCollectiveStatus . SUCCESS , result
373372 finally :
374373 self .unlock ()
375374
0 commit comments