Skip to content

Commit 5c3a054

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Add FLOPS support to the new profiler API. (pytorch#51734)
Summary: The new profiler API was added in PR#48280. This PR is to add FLOPS support to the new profiler API. Pull Request resolved: pytorch#51734 Test Plan: ```python python test/test_profiler.py -k test_flops ``` Reviewed By: xuzhao9 Differential Revision: D26261851 Pulled By: ilia-cher fbshipit-source-id: dbeba4c197e6f51a9a8e640e8bb60ec38df87f73
1 parent 430329e commit 5c3a054

File tree

5 files changed

+23
-5
lines changed

5 files changed

+23
-5
lines changed

test/test_profiler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,20 @@ def test_flops(self):
385385
profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)
386386
self.assertIn("FLOPS", profiler_output)
387387

388+
if not (kineto_available() and torch.cuda.is_available()):
389+
return
390+
391+
with profile(activities=[
392+
torch.profiler.ProfilerActivity.CPU,
393+
torch.profiler.ProfilerActivity.CUDA],
394+
record_shapes=True,
395+
with_flops=True,
396+
) as kineto_profiler:
397+
model(inputs)
398+
profiler_output = kineto_profiler.key_averages().table(
399+
sort_by="self_cuda_time_total", row_limit=-1)
400+
self.assertIn("FLOPS", profiler_output)
401+
388402
@unittest.skipIf(not kineto_available(), "Kineto is required")
389403
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
390404
def test_kineto_profiler_api(self):

torch/autograd/profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ class profile(object):
365365
366366
with_flops (bool, optional): If with_flops is set, the profiler will estimate
367367
the FLOPS (floating pointer operations per second) value using the operator's input shape
368-
and total CPU time. This allows one to estimate the hardware performance. Currently,
369-
this option only works for the matrix multiplication and convolution functions.
368+
and total time. This allows one to estimate the hardware performance. Currently,
369+
this option only works for the matrix multiplication and 2D convolution operators.
370370
371371
profile_memory (bool, optional): track tensor memory allocation/deallocation.
372372

torch/csrc/autograd/profiler_kineto.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ struct TORCH_API KinetoEvent {
170170
uint8_t activity_type_;
171171
c10::optional<std::vector<std::vector<int64_t>>> shapes_;
172172
c10::optional<std::vector<std::string>> stack_;
173-
uint64_t flops_;
173+
uint64_t flops_ = 0;
174174

175175
std::string name_;
176176
uint64_t device_index_ = 0;

torch/csrc/autograd/profiler_legacy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ struct TORCH_API LegacyEvent {
331331
uint64_t correlation_id_;
332332
// Extra arguments for computing op flops
333333
std::unordered_map<std::string, c10::IValue> extra_args_;
334-
uint64_t flops_;
334+
uint64_t flops_ = 0;
335335
};
336336

337337
// a linked-list of fixed sized vectors, to avoid

torch/profiler/profiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ class profile(object):
9292
during the profiling;
9393
- ``record_shapes`` - save information about operator's input shapes;
9494
- ``profile_memory`` - track tensor memory allocation/deallocation;
95-
- ``with_stack`` - record source information (file and line number) for the ops.
95+
- ``with_stack`` - record source information (file and line number) for the ops;
96+
- ``with_flops`` - use formula to estimate the FLOPS of specific operators (matrix multiplication and 2D convolution);
9697
- ``use_cuda`` - (deprecated, use ``activities``).
9798
9899
.. note::
@@ -178,6 +179,7 @@ def __init__(
178179
record_shapes: bool = False,
179180
profile_memory: bool = False,
180181
with_stack: bool = False,
182+
with_flops: bool = False,
181183
# deprecated:
182184
use_cuda: Optional[bool] = None):
183185
if activities:
@@ -207,6 +209,7 @@ def __init__(
207209
self.record_steps = False
208210
self.on_trace_ready = on_trace_ready
209211
self.record_shapes = record_shapes
212+
self.with_flops = with_flops
210213
self.profile_memory = profile_memory
211214
self.with_stack = with_stack
212215
self.step_num = 0
@@ -353,6 +356,7 @@ def _start_warmup(self):
353356
use_cuda=(ProfilerActivity.CUDA in self.activities),
354357
use_cpu=(ProfilerActivity.CPU in self.activities),
355358
record_shapes=self.record_shapes,
359+
with_flops=self.with_flops,
356360
profile_memory=self.profile_memory,
357361
with_stack=self.with_stack,
358362
use_kineto=True,

0 commit comments

Comments
 (0)