diff --git a/ignite/contrib/metrics/precision_recall_curve.py b/ignite/contrib/metrics/precision_recall_curve.py
index 5021315904b3..8d4aebe245ef 100644
--- a/ignite/contrib/metrics/precision_recall_curve.py
+++ b/ignite/contrib/metrics/precision_recall_curve.py
@@ -92,7 +92,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # type: i
         _prediction_tensor = torch.cat(self._predictions, dim=0)
         _target_tensor = torch.cat(self._targets, dim=0)
 
-        ws = idist.get_world_size()
+        ws = idist.get_metrics_computation_world_size()
         if ws > 1:
             # All gather across all processes
             _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index 88ddca2287c8..e2c2113f60e5 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -1,3 +1,4 @@
+import os
 import socket
 from functools import wraps
 from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
@@ -20,6 +21,7 @@
     "available_backends",
     "model_name",
     "get_world_size",
+    "get_metrics_computation_world_size",
     "get_rank",
     "get_local_rank",
     "get_nproc_per_node",
@@ -141,6 +143,14 @@ def get_world_size() -> int:
     return _model.get_world_size()
 
 
+def get_metrics_computation_world_size() -> int:
+    """Returns world size of current distributed configuration for metrics computation. Returns 1 if no distributed configuration."""
+    if os.environ.get("IGNITE_DISABLE_DISTRIBUTED_METRICS") == "1":
+        return 1
+
+    return get_world_size()
+
+
 def get_rank() -> int:
     """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration."""
     if _need_to_sync and isinstance(_model, _SerialModel):
diff --git a/ignite/metrics/epoch_metric.py b/ignite/metrics/epoch_metric.py
index 21b199bfd542..cb26ec2f65c5 100644
--- a/ignite/metrics/epoch_metric.py
+++ b/ignite/metrics/epoch_metric.py
@@ -144,7 +144,7 @@ def compute(self) -> float:
         _prediction_tensor = torch.cat(self._predictions, dim=0)
         _target_tensor = torch.cat(self._targets, dim=0)
 
-        ws = idist.get_world_size()
+        ws = idist.get_metrics_computation_world_size()
         if ws > 1:
             # All gather across all processes
             _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py
index 8c63edd1ec97..930c770a946d 100644
--- a/ignite/metrics/frequency.py
+++ b/ignite/metrics/frequency.py
@@ -61,8 +61,8 @@ def update(self, output: int) -> None:
     def compute(self) -> float:
         time_divisor = 1.0
 
-        if idist.get_world_size() > 1:
-            time_divisor *= idist.get_world_size()
+        if idist.get_metrics_computation_world_size() > 1:
+            time_divisor *= idist.get_metrics_computation_world_size()
 
         # Returns the average processed objects per second across all workers
         return self._n / self._elapsed * time_divisor
diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py
index 26cb3c12560d..effc616f1768 100644
--- a/ignite/metrics/metric.py
+++ b/ignite/metrics/metric.py
@@ -554,7 +554,7 @@ def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
                 raise RuntimeError(
                     "Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only"
                 )
-            ws = idist.get_world_size()
+            ws = idist.get_metrics_computation_world_size()
             unreduced_attrs = {}
             if len(attrs) > 0 and ws > 1:
                 for attr in attrs:
diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py
index 468838a9908c..50b8f7caa9a1 100644
--- a/ignite/metrics/running_average.py
+++ b/ignite/metrics/running_average.py
@@ -154,7 +154,7 @@ def _get_metric_value(self) -> Union[torch.Tensor, float]:
     @sync_all_reduce("src")
     def _get_output_value(self) -> Union[torch.Tensor, float]:
         # we need to compute average instead of sum produced by @sync_all_reduce("src")
-        output = cast(Union[torch.Tensor, float], self.src) / idist.get_world_size()
+        output = cast(Union[torch.Tensor, float], self.src) / idist.get_metrics_computation_world_size()
         return output
 
     def _metric_iteration_completed(self, engine: Engine) -> None:
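
A minimal sketch of how the new switch could be exercised once the patch above is applied. The environment variable IGNITE_DISABLE_DISTRIBUTED_METRICS and the helper get_metrics_computation_world_size come from the diff; the surrounding setup is illustrative only and not part of the change.

    import os

    # Assumption for this sketch: setting the variable before metrics are computed
    # makes get_metrics_computation_world_size() report 1, so metric code paths
    # skip the all_gather/all_reduce branches even in a distributed run.
    os.environ["IGNITE_DISABLE_DISTRIBUTED_METRICS"] = "1"

    import ignite.distributed as idist

    # With the variable set, the helper returns 1 regardless of the actual
    # distributed configuration; metrics then compute on per-process data only.
    assert idist.get_metrics_computation_world_size() == 1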