Skip to content

Commit 5221fe1

Browse files
revise option to control number of resubmit fail jobs #545 (#554)
reopen PR #545 due to branch removed <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added configurable retry_count to machine instances (default: 3), persisted through serialization and visible in machine arguments. * **Bug Fixes / Improvements** * Improved machine constructors to accept flexible initialization options for various machine types. * **Tests** * Updated test suite to validate retry_count configuration and persistence. * **Chores** * Project configuration formatting refined. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0b5dfd1 commit 5221fe1

File tree

6 files changed

+36
-37
lines changed

6 files changed

+36
-37
lines changed

dpdispatcher/machine.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def __init__(
8282
local_root=None,
8383
remote_root=None,
8484
remote_profile={},
85+
retry_count=3,
8586
*,
8687
context=None,
8788
):
@@ -96,6 +97,7 @@ def __init__(
9697
else:
9798
pass
9899
self.bind_context(context=context)
100+
self.retry_count = retry_count
99101

100102
def bind_context(self, context):
101103
self.context = context
@@ -148,7 +150,8 @@ def load_from_dict(cls, machine_dict):
148150
base.check_value(machine_dict, strict=False)
149151

150152
context = BaseContext.load_from_dict(machine_dict)
151-
machine = machine_class(context=context)
153+
retry_count = machine_dict.get("retry_count", 3)
154+
machine = machine_class(context=context, retry_count=retry_count)
152155
return machine
153156

154157
def serialize(self, if_empty_remote_profile=False):
@@ -161,6 +164,7 @@ def serialize(self, if_empty_remote_profile=False):
161164
machine_dict["remote_profile"] = self.context.remote_profile
162165
else:
163166
machine_dict["remote_profile"] = {}
167+
machine_dict["retry_count"] = self.retry_count
164168
# normalize the dict
165169
base = self.arginfo()
166170
machine_dict = base.normalize_value(machine_dict, trim_pattern="_*")
@@ -396,6 +400,7 @@ def arginfo(cls):
396400
doc_clean_asynchronously = (
397401
"Clean the remote directory asynchronously after the job finishes."
398402
)
403+
doc_retry_count = "Number of retries to resubmit failed jobs."
399404

400405
machine_args = [
401406
Argument("batch_type", str, optional=False, doc=doc_batch_type),
@@ -413,6 +418,7 @@ def arginfo(cls):
413418
default=False,
414419
doc=doc_clean_asynchronously,
415420
),
421+
Argument("retry_count", int, optional=True, default=3, doc=doc_retry_count),
416422
]
417423

418424
context_variant = Variant(

dpdispatcher/machines/dp_cloud_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
class Bohrium(Machine):
2020
alias = ("Lebesgue", "DpCloudServer")
2121

22-
def __init__(self, context):
22+
def __init__(self, context, **kwargs):
23+
super().__init__(context=context, **kwargs)
2324
self.context = context
2425
self.input_data = context.remote_profile["input_data"].copy()
2526
self.api_version = 2
@@ -32,7 +33,6 @@ def __init__(self, context):
3233
phone = context.remote_profile.get("phone", None)
3334
username = context.remote_profile.get("username", None)
3435
password = context.remote_profile.get("password", None)
35-
self.retry_count = context.remote_profile.get("retry_count", 3)
3636
self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True)
3737

3838
ticket = os.environ.get("BOHR_TICKET", None)

dpdispatcher/machines/openapi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def unzip_file(zip_file, out_dir="./"):
2929

3030

3131
class OpenAPI(Machine):
32-
def __init__(self, context):
32+
def __init__(self, context, **kwargs):
33+
super().__init__(context=context, **kwargs)
3334
if not found_bohriumsdk:
3435
raise ModuleNotFoundError(
3536
"bohriumsdk not installed. Install dpdispatcher with `pip install dpdispatcher[bohrium]`"
@@ -38,7 +39,6 @@ def __init__(self, context):
3839
self.remote_profile = context.remote_profile.copy()
3940

4041
self.grouped = self.remote_profile.get("grouped", True)
41-
self.retry_count = self.remote_profile.get("retry_count", 3)
4242
self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True)
4343

4444
access_key = (

dpdispatcher/machines/pbs.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818

1919
class PBS(Machine):
20+
# def __init__(self, **kwargs):
21+
# super().__init__(**kwargs)
22+
2023
def gen_script(self, job):
2124
pbs_script = super().gen_script(job)
2225
return pbs_script
@@ -188,24 +191,8 @@ def gen_script_header(self, job):
188191

189192

190193
class SGE(PBS):
191-
def __init__(
192-
self,
193-
batch_type=None,
194-
context_type=None,
195-
local_root=None,
196-
remote_root=None,
197-
remote_profile={},
198-
*,
199-
context=None,
200-
):
201-
super(PBS, self).__init__(
202-
batch_type,
203-
context_type,
204-
local_root,
205-
remote_root,
206-
remote_profile,
207-
context=context,
208-
)
194+
def __init__(self, **kwargs):
195+
super().__init__(**kwargs)
209196

210197
def gen_script_header(self, job):
211198
### Ref:https://softpanorama.org/HPC/PBS_and_derivatives/Reference/pbs_command_vs_sge_commands.shtml

pyproject.toml

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ build-backend = "setuptools.build_meta"
66
name = "dpdispatcher"
77
dynamic = ["version"]
88
description = "Generate HPC scheduler systems jobs input scripts, submit these scripts to HPC systems, and poke until they finish"
9-
authors = [
10-
{ name = "DeepModeling" },
11-
]
9+
authors = [{ name = "DeepModeling" }]
1210
license = { file = "LICENSE" }
1311
classifiers = [
1412
"Programming Language :: Python :: 3.7",
@@ -32,7 +30,15 @@ dependencies = [
3230
]
3331
requires-python = ">=3.7"
3432
readme = "README.md"
35-
keywords = ["dispatcher", "hpc", "slurm", "lsf", "pbs", "ssh", "jh_unischeduler"]
33+
keywords = [
34+
"dispatcher",
35+
"hpc",
36+
"slurm",
37+
"lsf",
38+
"pbs",
39+
"ssh",
40+
"jh_unischeduler",
41+
]
3642

3743
[project.urls]
3844
Homepage = "https://github.com/deepmodeling/dpdispatcher"
@@ -59,12 +65,8 @@ docs = [
5965
]
6066
cloudserver = ["oss2", "tqdm", "bohrium-sdk"]
6167
bohrium = ["oss2", "tqdm", "bohrium-sdk"]
62-
gui = [
63-
"dpgui",
64-
]
65-
test = [
66-
"dpgui",
67-
]
68+
gui = ["dpgui"]
69+
test = ["dpgui"]
6870

6971
[tool.setuptools.packages.find]
7072
include = ["dpdispatcher*"]
@@ -84,11 +86,11 @@ profile = "black"
8486

8587
[tool.ruff.lint]
8688
select = [
87-
"E", # errors
88-
"F", # pyflakes
89-
"D", # pydocstyle
89+
"E", # errors
90+
"F", # pyflakes
91+
"D", # pydocstyle
9092
"UP", # pyupgrade
91-
"I", # isort
93+
"I", # isort
9294
]
9395
ignore = [
9496
"E501", # line too long
@@ -113,3 +115,6 @@ ignore = [
113115

114116
[tool.ruff.lint.pydocstyle]
115117
convention = "numpy"
118+
119+
[tool.ruff]
120+
line-length = 88

tests/test_argcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_machine_argcheck(self):
2727
"symlink": True,
2828
},
2929
"clean_asynchronously": False,
30+
"retry_count": 3,
3031
}
3132
self.assertDictEqual(norm_dict, expected_dict)
3233

0 commit comments

Comments
 (0)