Skip to content

Commit 7bb0069

Browse files
adamomainzfacebook-github-bot
authored andcommitted
adding arc lint opt in for benchmarking (#2448)
Summary: Pull Request resolved: #2448 lets see if arc lint runs here Reviewed By: zertosh Differential Revision: D61929352 fbshipit-source-id: 4651d3dfa2eb743f5b0e30954765c43c77447ee2
1 parent 8c5abe8 commit 7bb0069

File tree

622 files changed

+21337
-11530
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

622 files changed

+21337
-11530
lines changed

.clang-format

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
Language: ObjC
3+
DisableFormat: true
4+
SortIncludes: false
5+
...

install.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66

77
from userbenchmark import list_userbenchmarks
8-
from utils import get_pkg_versions, TORCH_DEPS, generate_pkg_constraints
8+
from utils import generate_pkg_constraints, get_pkg_versions, TORCH_DEPS
99
from utils.python_utils import pip_install_requirements
1010

1111
REPO_ROOT = Path(__file__).parent
@@ -24,26 +24,21 @@
2424
action="store_true",
2525
help="Run in test mode and check package versions",
2626
)
27-
parser.add_argument(
28-
"--skip",
29-
nargs="*",
30-
default=[],
31-
help="Skip models to install."
32-
)
27+
parser.add_argument("--skip", nargs="*", default=[], help="Skip models to install.")
3328
parser.add_argument(
3429
"--torch",
3530
action="store_true",
36-
help="Only require torch to be installed, ignore torchvision and torchaudio."
31+
help="Only require torch to be installed, ignore torchvision and torchaudio.",
3732
)
3833
parser.add_argument(
3934
"--numpy",
4035
action="store_true",
41-
help="Only require numpy to be installed, ignore torch, torchvision and torchaudio."
36+
help="Only require numpy to be installed, ignore torch, torchvision and torchaudio.",
4237
)
4338
parser.add_argument(
4439
"--check-only",
4540
action="store_true",
46-
help="Only run the version check and generate the contraints"
41+
help="Only run the version check and generate the contraints",
4742
)
4843
parser.add_argument("--canary", action="store_true", help="Install canary model.")
4944
parser.add_argument("--continue_on_fail", action="store_true")
@@ -86,14 +81,18 @@
8681
# Install userbenchmark dependencies if exists
8782
userbenchmark_dir = REPO_ROOT.joinpath("userbenchmark", args.userbenchmark)
8883
cmd = [sys.executable, "install.py"]
89-
print(f"Installing userbenchmark {args.userbenchmark} with extra args: {extra_args}")
84+
print(
85+
f"Installing userbenchmark {args.userbenchmark} with extra args: {extra_args}"
86+
)
9087
cmd.extend(extra_args)
9188
if userbenchmark_dir.joinpath("install.py").is_file():
9289
# add the current run env to PYTHONPATH to load framework install utils
9390
run_env = os.environ.copy()
9491
run_env["PYTHONPATH"] = Path(REPO_ROOT).as_posix()
9592
subprocess.check_call(
96-
cmd, cwd=userbenchmark_dir.absolute(), env=run_env,
93+
cmd,
94+
cwd=userbenchmark_dir.absolute(),
95+
env=run_env,
9796
)
9897
sys.exit(0)
9998

run_benchmark.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def list_benchmarks() -> Dict[str, str]:
2424

2525
def run():
2626
available_benchmarks = list_benchmarks()
27-
parser = argparse.ArgumentParser(description="Run a TorchBench user benchmark", add_help=False)
27+
parser = argparse.ArgumentParser(
28+
description="Run a TorchBench user benchmark", add_help=False
29+
)
2830
parser.add_argument(
2931
"bm_name",
3032
choices=available_benchmarks.keys(),

scripts/userbenchmark/upload_s3_csv.py

+26-19
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import argparse
2-
import sys
32
import os
43
import re
5-
from pathlib import Path
4+
import sys
65
from datetime import datetime
6+
from pathlib import Path
77

88
REPO_ROOT = Path(__file__).parent.parent.parent.resolve()
99

10+
1011
class add_path:
1112
def __init__(self, path):
1213
self.path = path
@@ -22,18 +23,17 @@ def __exit__(self, exc_type, exc_value, traceback):
2223

2324

2425
with add_path(str(REPO_ROOT)):
25-
from utils.s3_utils import (
26-
S3Client,
27-
USERBENCHMARK_S3_BUCKET,
28-
)
26+
from utils.s3_utils import S3Client, USERBENCHMARK_S3_BUCKET
2927

3028

31-
def upload_s3(s3_object: str,
32-
ub_name: str,
33-
workflow_run_id: str,
34-
workflow_run_attempt: str,
35-
file_path: Path,
36-
dryrun: bool):
29+
def upload_s3(
30+
s3_object: str,
31+
ub_name: str,
32+
workflow_run_id: str,
33+
workflow_run_attempt: str,
34+
file_path: Path,
35+
dryrun: bool,
36+
):
3737
"""S3 path:
3838
s3://ossci-metrics/<s3_object>/<ub_name>/<workflow_run_id>/<workflow_run_attempt>/file_name
3939
"""
@@ -46,7 +46,12 @@ def upload_s3(s3_object: str,
4646

4747
def _get_files_to_upload(file_path: str, match_filename: str):
4848
filename_regex = re.compile(match_filename)
49-
return [ file_name for file_name in os.listdir(file_path) if filename_regex.match(file_name) ]
49+
return [
50+
file_name
51+
for file_name in os.listdir(file_path)
52+
if filename_regex.match(file_name)
53+
]
54+
5055

5156
if __name__ == "__main__":
5257
parser = argparse.ArgumentParser(description=__doc__)
@@ -83,9 +88,11 @@ def _get_files_to_upload(file_path: str, match_filename: str):
8388

8489
for file in files_to_upload:
8590
file_path = Path(args.upload_path).joinpath(file)
86-
upload_s3(s3_object=args.s3_prefix,
87-
ub_name=args.userbenchmark,
88-
workflow_run_id=workflow_run_id,
89-
workflow_run_attempt=workflow_run_attempt,
90-
file_path=file_path,
91-
dryrun=args.dryrun)
91+
upload_s3(
92+
s3_object=args.s3_prefix,
93+
ub_name=args.userbenchmark,
94+
workflow_run_id=workflow_run_id,
95+
workflow_run_attempt=workflow_run_attempt,
96+
file_path=file_path,
97+
dryrun=args.dryrun,
98+
)

test_bench.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,9 @@ def test_eval(self, model_path, device, benchmark, pytestconfig):
113113

114114
benchmark(task.invoke)
115115
benchmark.extra_info["machine_state"] = get_machine_state()
116-
benchmark.extra_info["batch_size"] = task.get_model_attribute(
117-
"batch_size"
118-
)
116+
benchmark.extra_info["batch_size"] = task.get_model_attribute("batch_size")
119117
benchmark.extra_info["precision"] = task.get_model_attribute(
120-
"dargs", "precision"
118+
"dargs", "precision"
121119
)
122120
benchmark.extra_info["test"] = "eval"
123121

torchbenchmark/__init__.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __exit__(self, exc_type, exc_value, traceback):
4242
except ValueError:
4343
pass
4444

45+
4546
class add_ld_library_path:
4647
def __init__(self, path):
4748
self.path = path
@@ -57,6 +58,7 @@ def __enter__(self):
5758
def __exit__(self, exc_type, exc_value, traceback):
5859
os.environ = self.os_environ.copy()
5960

61+
6062
with add_path(str(REPO_PATH)):
6163
from utils import get_pkg_versions, TORCH_DEPS
6264

@@ -192,7 +194,7 @@ def setup(
192194
model_paths = list(model_paths)
193195
model_paths.extend(canary_model_paths)
194196
skip_models = [] if not skip_models else skip_models
195-
model_paths = [ x for x in model_paths if os.path.basename(x) not in skip_models ]
197+
model_paths = [x for x in model_paths if os.path.basename(x) not in skip_models]
196198
for model_path in model_paths:
197199
print(f"running setup for {model_path}...", end="", flush=True)
198200
if test_mode:
@@ -346,6 +348,7 @@ def _maybe_import_model(package: str, model_name: str) -> Dict[str, Any]:
346348
import importlib
347349
import os
348350
import traceback
351+
349352
from torchbenchmark import load_model_by_name
350353

351354
diagnostic_msg = ""
@@ -589,7 +592,10 @@ def watch_cuda_memory(
589592

590593

591594
def list_models_details(workers: int = 1) -> List[ModelDetails]:
592-
return [ModelTask(os.path.basename(model_path)).model_details for model_path in _list_model_paths()]
595+
return [
596+
ModelTask(os.path.basename(model_path)).model_details
597+
for model_path in _list_model_paths()
598+
]
593599

594600

595601
def list_models(model_match=None):

torchbenchmark/_components/_impl/tasks/base.py

+34-20
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Add Task abstraction to reduce the friction of controlling a remote worker."""
2+
23
import abc
34
import ast
45
import functools
@@ -19,8 +20,7 @@ class TaskBase(abc.ABC):
1920
"""
2021

2122
@abc.abstractproperty
22-
def worker(self) -> base.WorkerBase:
23-
...
23+
def worker(self) -> base.WorkerBase: ...
2424

2525

2626
def parse_f(f: typing.Callable) -> typing.Tuple[inspect.Signature, str]:
@@ -41,12 +41,14 @@ def parse_f(f: typing.Callable) -> typing.Tuple[inspect.Signature, str]:
4141
if arg_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
4242
raise TypeError(
4343
f"Variadic positional argument `*{arg}` not permitted "
44-
"for `run_in_worker` function.")
44+
"for `run_in_worker` function."
45+
)
4546

4647
if arg_parameter.kind == inspect.Parameter.VAR_KEYWORD:
4748
raise TypeError(
4849
f"Variadic keywork argument `**{arg}` not permitted "
49-
"for `run_in_worker` function.")
50+
"for `run_in_worker` function."
51+
)
5052

5153
if arg_parameter.annotation == inspect.Parameter.empty:
5254
raise TypeError(f"Missing type annotation for parameter `{arg}`")
@@ -84,7 +86,9 @@ def parse_f(f: typing.Callable) -> typing.Tuple[inspect.Signature, str]:
8486
# `functools.wraps` is not detectable (by design) and is thus caveat
8587
# emptor.
8688
if getattr(f, "__wrapped__", None):
87-
raise TypeError(textwrap.dedent("""
89+
raise TypeError(
90+
textwrap.dedent(
91+
"""
8892
`f` cannot be decorated below `@run_in_worker` (except for
8993
@staticmethod) because the extraction logic would not carry through
9094
said decorator(s).
@@ -100,7 +104,9 @@ def foo() -> None:
100104
@my_decorator
101105
def foo() -> None:
102106
...
103-
""").strip())
107+
"""
108+
).strip()
109+
)
104110

105111
# Dedent, as `f` may have been defined in a scoped context.
106112
f_src = textwrap.dedent(inspect.getsource(f))
@@ -127,10 +133,12 @@ def foo() -> None:
127133
# line comment), we simply elect to skip over them and index on the
128134
# first node that will give valid indices.
129135
if node.col_offset == -1:
130-
assert isinstance(node.value, ast.Str), f"Expected `ast.Str`, got {type(node)}. ({node}) {node.lineno}"
136+
assert isinstance(
137+
node.value, ast.Str
138+
), f"Expected `ast.Str`, got {type(node)}. ({node}) {node.lineno}"
131139
continue
132140

133-
raw_body_lines = src_lines[node.lineno - 1:]
141+
raw_body_lines = src_lines[node.lineno - 1 :]
134142
col_offset = node.col_offset
135143
break
136144

@@ -229,16 +237,15 @@ def outer(f: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.An
229237
pass
230238

231239
signature, f_body = parse_f(f)
232-
has_return_value = (signature.return_annotation is not None)
240+
has_return_value = signature.return_annotation is not None
233241
if has_return_value and not scoped:
234242
raise TypeError(
235-
"Unscoped (globally executed) call can not have a return value.")
243+
"Unscoped (globally executed) call can not have a return value."
244+
)
236245

237246
@functools.wraps(f)
238247
def inner(
239-
self: TaskBase,
240-
*args: typing.Any,
241-
**kwargs: typing.Any
248+
self: TaskBase, *args: typing.Any, **kwargs: typing.Any
242249
) -> typing.Any:
243250
bound_signature = signature.bind(*args, **kwargs)
244251
bound_signature.apply_defaults()
@@ -250,13 +257,17 @@ def inner(
250257
except ValueError:
251258
raise ValueError(f"unmarshallable arg {arg_name}: {arg_value}")
252259

253-
body.append(f"{arg_name} = marshal.loads(bytes.fromhex({repr(arg_bytes.hex())})) # {arg_value}")
260+
body.append(
261+
f"{arg_name} = marshal.loads(bytes.fromhex({repr(arg_bytes.hex())})) # {arg_value}"
262+
)
254263
body.extend(["", "# Wrapped source"] + f_body.splitlines(keepends=False))
255264

256-
src = "\n".join([
257-
"def _run_in_worker_f():",
258-
textwrap.indent("\n".join(body), " " * 4),
259-
textwrap.dedent("""
265+
src = "\n".join(
266+
[
267+
"def _run_in_worker_f():",
268+
textwrap.indent("\n".join(body), " " * 4),
269+
textwrap.dedent(
270+
"""
260271
try:
261272
# Clear prior value if it exists.
262273
del _run_in_worker_result
@@ -265,8 +276,10 @@ def inner(
265276
pass
266277
267278
_run_in_worker_result = _run_in_worker_f()
268-
""")
269-
])
279+
"""
280+
),
281+
]
282+
)
270283

271284
# `worker.load` is not free, so for void functions we skip it.
272285
if has_return_value:
@@ -278,4 +291,5 @@ def inner(
278291
self.worker.run(src)
279292

280293
return inner
294+
281295
return outer

0 commit comments

Comments
 (0)