Skip to content

Commit af0e361

Browse files
authored
feat: 1) imporve api return value; 2) rename FTAllReduceStatus to FTCollectiveStatus (kleveross#73)
1 parent 78a62de commit af0e361

File tree

5 files changed

+11
-12
lines changed

5 files changed

+11
-12
lines changed

ftlib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
__version__ = "0.0.1"
22

3-
from ftlib.ftlib_status import FTAllReduceStatus # noqa: F401
3+
from ftlib.ftlib_status import FTCollectiveStatus # noqa: F401
44
from ftlib.impl import BasicFTLib # noqa: F401

ftlib/ftlib_status.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22

33

4-
class FTAllReduceStatus(Enum):
4+
class FTCollectiveStatus(Enum):
55
NO_NEED = -1
66
ABORT = 2
77
FAIL = 1

ftlib/impl.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44

55
from ftlib.consensus.consensus_status import ConsensusMode, ConsensusStatus
6-
from ftlib.ftlib_status import FTAllReduceStatus, FTRebuildStatus
6+
from ftlib.ftlib_status import FTCollectiveStatus, FTRebuildStatus
77
from ftlib.rank_assign_scheme import get_rank_size
88

99
# in the ftlib package, user is able to initialize the package
@@ -161,7 +161,7 @@ def execute(self, func, *args, **kwargs):
161161
# TODO: we should consider retrying rebuild process
162162
# for multiple times
163163
rebuild_result = self._rebuild()
164-
if rebuild_result != FTAllReduceStatus.SUCCESS:
164+
if rebuild_result != FTCollectiveStatus.SUCCESS:
165165
if rebuild_result == FTRebuildStatus.FAIL:
166166
raise Exception("rebuild process returns failed")
167167
elif rebuild_result == FTRebuildStatus.ABORT:
@@ -317,8 +317,7 @@ def func(*argc, **kwargs):
317317
# **kwargs, kwargs passed to the function
318318
# Returns:
319319
# if collective ops returns successfully, returns the value
320-
# otherwise, returns FTAllReduceStatus
321-
# TODO: rename FTAllReduceStatus to FTCollectiveStatus
320+
# otherwise, returns FTCollectiveStatus
322321

323322
# the api here is
324323
ops = (
@@ -330,7 +329,7 @@ def func(*argc, **kwargs):
330329
# if skil_allreduce == True, then any collective ops
331330
# shouldn't be called
332331
if self.skip_allreduce():
333-
return FTAllReduceStatus.NO_NEED
332+
return FTCollectiveStatus.NO_NEED, None
334333

335334
# if the instance is not initialized, then start rebuild
336335
# TODO: put rebuild into a try loop?
@@ -341,13 +340,13 @@ def func(*argc, **kwargs):
341340
rebuild_result = self._rebuild()
342341
if rebuild_result == FTRebuildStatus.ABORT:
343342
logging.warning("rebuild process returns abort")
344-
return FTAllReduceStatus.ABORT
343+
return FTCollectiveStatus.ABORT, None
345344
if rebuild_result == FTRebuildStatus.FAIL:
346345
logging.warning("rebuild process returns fail")
347-
return FTAllReduceStatus.ABORT
346+
return FTCollectiveStatus.ABORT, None
348347
if rebuild_result == FTRebuildStatus.SKIP_ALLREDUCE:
349348
logging.warning("rebuild process returns skip allreduce")
350-
return FTAllReduceStatus.NO_NEED
349+
return FTCollectiveStatus.NO_NEED, None
351350

352351
try:
353352
self.lock()
@@ -366,10 +365,10 @@ def func(*argc, **kwargs):
366365
except Exception as e:
367366
logging.exception(str(e))
368367
self.set_initialized(False)
369-
return FTAllReduceStatus.FAIL
368+
return FTCollectiveStatus.FAIL, None
370369
else:
371370
self.consensus.average_success()
372-
return result
371+
return FTCollectiveStatus.SUCCESS, result
373372
finally:
374373
self.unlock()
375374

test/tricky-data/pytorch-gossip-tricky-data.py renamed to test/deprecated-tests/tricky-data/pytorch-gossip-tricky-data.py

File renamed without changes.

0 commit comments

Comments
 (0)