diff --git a/clmr/modules/linear_evaluation.py b/clmr/modules/linear_evaluation.py index 4f22c09..ea6df76 100644 --- a/clmr/modules/linear_evaluation.py +++ b/clmr/modules/linear_evaluation.py @@ -28,8 +28,15 @@ def __init__(self, args, encoder: nn.Module, hidden_dim: int, output_dim: int): self.model = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim)) self.criterion = self.configure_criterion() - self.accuracy = torchmetrics.Accuracy() - self.average_precision = torchmetrics.AveragePrecision(pos_label=1) + self.accuracy = torchmetrics.Accuracy( + task="multilabel", + num_labels=output_dim + ) + self.average_precision = torchmetrics.AveragePrecision( + task='multilabel', + num_labels=output_dim, + pos_label=1 + ) def forward(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: preds = self._forward_representations(x, y)