Skip to content

Commit 265ffe6

Browse files
committed
Set dataloader_idx to the default value.
1 parent 0a31c2d commit 265ffe6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

libmultilabel/nn/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,14 @@ def _shared_eval_epoch_end(self, step_outputs, split):
152152
self.eval_metric.reset()
153153
return metric_dict
154154

155-
def predict_step(self, batch, batch_idx):
155+
def predict_step(self, batch, batch_idx, dataloader_idx=0):
156156
"""`predict_step` is triggered when calling `trainer.predict()`.
157157
This function is used to get the top-k labels and their prediction scores.
158158
159159
Args:
160160
batch (dict): A batch of text and label.
161161
batch_idx (int): Index of current batch.
162+
dataloader_idx (int): Index of current dataloader.
162163
163164
Returns:
164165
dict: Top k label indexes and the prediction scores.

0 commit comments

Comments
 (0)