diff --git a/affinity_loss.py b/affinity_loss.py index 180f74e..b08b23b 100644 --- a/affinity_loss.py +++ b/affinity_loss.py @@ -78,8 +78,8 @@ def forward(self, logits, labels): lbedge = labels[:, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]].detach() igncenter = ignore_mask[:, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]].detach() ignedge = ignore_mask[:, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]].detach() - lgp_center = probs[:, :, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]] - lgp_edge = probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]] + lgp_center = log_probs[:, :, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]] + lgp_edge = log_probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]] prob_edge = probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]] kldiv = (prob_edge * (lgp_edge - lgp_center)).sum(dim=1)