Skip to content

Commit 0ab60c3

Browse files
authored
Merge pull request #224 from ASUS-AICS/predict
Set data_loader_idx to 0 in predict_step()
2 parents ede4e49 + 265ffe6 commit 0ab60c3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

libmultilabel/nn/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _shared_eval_epoch_end(self, step_outputs, split):
152152
self.eval_metric.reset()
153153
return metric_dict
154154

155-
def predict_step(self, batch, batch_idx, dataloader_idx):
155+
def predict_step(self, batch, batch_idx, dataloader_idx=0):
156156
"""`predict_step` is triggered when calling `trainer.predict()`.
157157
This function is used to get the top-k labels and their prediction scores.
158158

0 commit comments

Comments
 (0)