diff --git a/functorch/csrc/BatchRulesLoss.cpp b/functorch/csrc/BatchRulesLoss.cpp index 043b5bb91..3a0d18756 100644 --- a/functorch/csrc/BatchRulesLoss.cpp +++ b/functorch/csrc/BatchRulesLoss.cpp @@ -203,16 +203,22 @@ std::tuple nll_loss_forward_decomposition( // target can be [N, 1, ...] or [1] auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim); - auto total_weight = at::full( - {}, result.numel(), self_.scalar_type(), - self_.layout(), self_.device(), nullopt); bool has_ignore_index = ignore_index >= 0; - Tensor ignore_index_mask; + Tensor ignore_index_mask, total_weight; if (has_ignore_index) { ignore_index_mask = target != ignore_index; result = result * ignore_index_mask; - total_weight = ignore_index_mask.sum().to(self_); + if (!(reduction == Reduction::None && self.dim() >= 2)) { + total_weight = ignore_index_mask.sum().to(self_); + } + } + + if (!total_weight.defined()) { + auto init_value = (reduction == Reduction::None && self.dim() >= 2) ? 0.0 : 1.0 * result.numel(); + total_weight = at::full( + {}, init_value, self_.scalar_type(), + self_.layout(), self_.device(), nullopt); } // Apply the reduction