Skip to content

Commit 9e34ec0

Browse files
authored
Update dataloader_ge.py
1 parent 759b129 commit 9e34ec0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

dataloader_ge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,8 @@ def compute_loss(self, model, inputs, return_outputs=False):
433433

434434
# objective gradient
435435
outputs = model(input_ids,labels=labels, attention_mask=attention_mask, output_hidden_states=True)
436-
# emb_dif = ((outputs.hidden_states[emb_idx][..., :-1, :] - emb_tar) ** 2).mean()
437-
emb_dif = ((outputs.hidden_states[emb_idx][..., :-1, :] - emb_tar).abs()).mean(-1)
436+
emb_dif = ((outputs.hidden_states[emb_idx][..., :-1, :] - emb_tar) ** 2).mean()
437+
# emb_dif = ((outputs.hidden_states[emb_idx][..., :-1, :] - emb_tar).abs()).mean(-1)
438438
forget_loss = emb_dif[labels[..., 1:] != -100].mean() + 0 * outputs.loss
439439
objective_gradient_fcn(model, forget_loss)
440440

0 commit comments

Comments
 (0)