ShiftCrossEntropy currently uses nn.CrossEntropyLoss as its backend, which expects its input to be unnormalized logits. However, ShiftCrossEntropy appears to pass input probabilities and target probabilities to the backend. Because nn.CrossEntropyLoss applies a log-softmax to its input internally, the input probabilities would be normalized a second time, so the result may deviate from the behavior described in equation (7) of the paper.
pesto-full/src/losses/entropy.py, line 49 in 229f78b:

    return self.criterion(x1, shift_x2)
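To make the concern concrete, here is a minimal, dependency-free sketch of the arithmetic involved. It does not use the repository's code: `ce_from_logits` is a hypothetical helper that mirrors what nn.CrossEntropyLoss computes for a single sample with probability targets (negative sum of target times log-softmax of the input), and `ce_between_probs` assumes that equation (7) is meant as a plain cross-entropy between two probability distributions. The example values are made up for illustration.

```python
import math


def log_softmax(values):
    """Numerically stable log-softmax of a list of floats."""
    m = max(values)
    log_sum = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_sum for v in values]


def ce_from_logits(inputs, target_probs):
    """Hypothetical stand-in for nn.CrossEntropyLoss on one sample with
    probability targets: -sum(target * log_softmax(input))."""
    return -sum(t * ls for t, ls in zip(target_probs, log_softmax(inputs)))


def ce_between_probs(probs, target_probs, eps=1e-12):
    """Cross-entropy between two distributions: -sum(target * log(prob)).
    Assumed here to be what equation (7) intends."""
    return -sum(t * math.log(max(p, eps)) for p, t in zip(probs, target_probs))


# Illustrative distributions (not taken from the repository).
probs = [0.7, 0.2, 0.1]
target = [0.1, 0.7, 0.2]

reference = ce_between_probs(probs, target)

# Probabilities passed where logits are expected: softmax is applied again.
as_passed = ce_from_logits(probs, target)

# One possible workaround: pass log-probabilities, since
# log_softmax(log p) == log p whenever p sums to 1.
fixed = ce_from_logits([math.log(p) for p in probs], target)

print(f"reference={reference:.4f} as_passed={as_passed:.4f} fixed={fixed:.4f}")
```

If this reading of equation (7) is right, two possible fixes would be to pass log-probabilities (or the raw logits, if they are available) to the backend, or to compute the sum over target times log-probability directly instead of going through nn.CrossEntropyLoss.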