Skip to content

Commit 243fea7

Browse files
jeffkbkimfacebook-github-bot
authored andcommitted
4/N: Integration (#3460)
Summary: Pull Request resolved: #3460 Differential Revision: D84082461
1 parent 781603c commit 243fea7

File tree

4 files changed

+14
-25
lines changed

4 files changed

+14
-25
lines changed

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
MetricUpdateJob,
2323
SynchronizationMarker,
2424
)
25-
from torchrec.metrics.metric_module import MetricValue, RecMetricModule
25+
from torchrec.metrics.metric_module import MetricsFuture, MetricValue, RecMetricModule
2626
from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
2727
from torchrec.metrics.model_utils import parse_task_model_outputs
2828
from torchrec.metrics.rec_metric import RecMetricException
@@ -254,24 +254,24 @@ def compute(self) -> Dict[str, MetricValue]:
254254
)
255255

256256
@override
257-
def async_compute(
258-
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
259-
) -> None:
257+
def async_compute(self) -> MetricsFuture:
260258
"""
261259
Entry point for asynchronous metric compute. It enqueues a synchronization marker
262260
to the update queue.
263261
264262
Args:
265263
future: Pre-created future where the computed metrics will be set.
266264
"""
265+
metrics_future = concurrent.futures.Future()
267266
if self._shutdown_event.is_set():
268-
future.set_exception(
267+
metrics_future.set_exception(
269268
RecMetricException("metric processor thread is shut down.")
270269
)
271-
return
270+
return metrics_future
272271

273-
self.update_queue.put_nowait(SynchronizationMarker(future))
272+
self.update_queue.put_nowait(SynchronizationMarker(metrics_future))
274273
self.update_queue_size_logger.add(self.update_queue.qsize())
274+
return metrics_future
275275

276276
def _process_synchronization_marker(
277277
self, synchronization_marker: SynchronizationMarker

torchrec/metrics/metric_module.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115

116116

117117
MetricValue = Union[torch.Tensor, float]
118+
MetricsFuture = concurrent.futures.Future[Dict[str, MetricValue]]
118119

119120

120121
class StateMetric(abc.ABC):
@@ -490,9 +491,7 @@ def load_pre_compute_states(
490491
def shutdown(self) -> None:
491492
logger.info("Initiating graceful shutdown...")
492493

493-
def async_compute(
494-
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
495-
) -> None:
494+
def async_compute(self) -> MetricsFuture:
496495
raise RecMetricException("async_compute is not supported in RecMetricModule")
497496

498497

torchrec/metrics/tests/test_cpu_offloaded_metric_module.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,6 @@ def test_async_compute_synchronization_marker(self) -> None:
207207
208208
Note that the comms module's metrics are actually the ones that are computed.
209209
"""
210-
future: concurrent.futures.Future[Dict[str, MetricValue]] = (
211-
concurrent.futures.Future()
212-
)
213-
214210
model_out = {
215211
"task1-prediction": torch.tensor([0.5]),
216212
"task1-label": torch.tensor([0.7]),
@@ -220,7 +216,7 @@ def test_async_compute_synchronization_marker(self) -> None:
220216
for _ in range(10):
221217
self.cpu_module.update(model_out)
222218

223-
self.cpu_module.async_compute(future)
219+
self.cpu_module.async_compute()
224220

225221
comms_mock_metric = cast(
226222
MockRecMetric, self.cpu_module.comms_module.rec_metrics.rec_metrics[0]
@@ -234,10 +230,7 @@ def test_async_compute_synchronization_marker(self) -> None:
234230
def test_async_compute_after_shutdown(self) -> None:
235231
self.cpu_module.shutdown()
236232

237-
future: concurrent.futures.Future[Dict[str, MetricValue]] = (
238-
concurrent.futures.Future()
239-
)
240-
self.cpu_module.async_compute(future)
233+
future = self.cpu_module.async_compute()
241234

242235
self.assertRaisesRegex(
243236
RecMetricException, "metric processor thread is shut down.", future.result
@@ -275,7 +268,7 @@ def test_wait_until_queue_is_empty(self) -> None:
275268
"task1-weight": torch.tensor([1.0]),
276269
}
277270
self.cpu_module.update(model_out)
278-
self.cpu_module.async_compute(concurrent.futures.Future())
271+
self.cpu_module.async_compute()
279272

280273
self.cpu_module.wait_until_queue_is_empty(self.cpu_module.update_queue)
281274
self.cpu_module.wait_until_queue_is_empty(self.cpu_module.compute_queue)
@@ -576,10 +569,7 @@ def _compare_metric_results_worker(
576569

577570
standard_results = standard_module.compute()
578571

579-
future: concurrent.futures.Future[Dict[str, MetricValue]] = (
580-
concurrent.futures.Future()
581-
)
582-
cpu_offloaded_module.async_compute(future)
572+
future = cpu_offloaded_module.async_compute()
583573

584574
# Wait for async compute to finish. Compare the input to each update()
585575
offloaded_results = future.result(timeout=10.0)

torchrec/metrics/tests/test_metric_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def test_async_compute_raises_exception(self) -> None:
662662
RecMetricException,
663663
"async_compute is not supported in RecMetricModule",
664664
):
665-
metric_module.async_compute(concurrent.futures.Future())
665+
metric_module.async_compute()
666666

667667

668668
def metric_module_gather_state(

0 commit comments

Comments
 (0)