A bug happens when I use the face parser to calculate a loss #30

@Sakura-ldx

Description

I created a loss that uses the detector and parser to get the face segmentation. The loss serves as a supervisory signal for a loop-based optimization/generation process. The loss is this:

import torch
import torch.nn as nn
import torch.nn.functional as F
import facer

class SegLoss(nn.Module):
    def __init__(self, device):
        super(SegLoss, self).__init__()
        self.face_detector = facer.face_detector('retinaface/mobilenet', device=device)
        self.face_parser = facer.face_parser('farl/lapa/448', device=device)

    def forward(self, x: torch.Tensor, segments: torch.Tensor):
        # Map the generator output from [-1, 1] to uint8 [0, 255],
        # the range the detector expects.
        x = x.clone()
        x = (x + 1) / 2
        x = x.clamp(0., 1.)
        x = (x * 255).type(torch.uint8)

        faces = self.face_detector(x)
        if not faces:
            return 0
        faces = self.face_parser(x, faces)
        # Per-pixel class probabilities for the first image in the batch.
        seg_probs = faces['seg']['logits'].softmax(dim=1)[0]

        loss = F.mse_loss(seg_probs, segments)
        return loss
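For reference, the preprocessing at the top of `forward` maps a generator output in [-1, 1] to the uint8 [0, 255] range; in isolation it behaves like this (a standalone sketch, not facer code):

```python
import torch

# Standalone check of the [-1, 1] -> uint8 [0, 255] mapping used in SegLoss.forward.
x = torch.tensor([-1.0, 0.0, 1.0, 1.5])  # 1.5 simulates an out-of-range value
x = ((x + 1) / 2).clamp(0.0, 1.0)        # shift to [0, 1] and clamp
x = (x * 255).type(torch.uint8)          # scale and truncate to uint8
print(x.tolist())                        # [0, 127, 255, 255]
```

Note that the uint8 cast truncates (127.5 becomes 127) and is non-differentiable, so no gradient flows back through `x` from this point.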

But when the loop reaches the third step, this error occurs:

  File "/home/ssd2/priv/miniconda3/envs/inversion/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudongxv/workplace/GANInverter-dev/criteria/seg_loss.py", line 34, in forward
    faces = self.face_parser(x, faces)
  File "/home/ssd2/priv/miniconda3/envs/inversion/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ssd2/priv/miniconda3/envs/inversion/lib/python3.9/site-packages/facer/face_parsing/farl.py", line 85, in forward
    w_seg_logits, _ = self.net(w_images)  # (b*n) x c x h x w
  File "/home/ssd2/priv/miniconda3/envs/inversion/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: vector::_M_range_check: __n (which is 18446744073709551615) >= this->size() (which is 3)
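For what it's worth, the out-of-range index in that message equals 2**64 - 1, i.e. -1 reinterpreted as an unsigned 64-bit integer, which hints that an index of -1 reached the C++ vector inside the scripted model:

```python
# The huge "index" in the range check is just -1 wrapped around to uint64.
print(2**64 - 1)     # 18446744073709551615, the value in the error message
print((-1) % 2**64)  # the same value: -1 as an unsigned 64-bit integer
```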

I tried to step into the code, but it seems to happen inside the JIT, so I can't tell what makes the index so large. Do you have any information?
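One thing worth ruling out (an assumption on my part, not a confirmed cause): the detector's result is a dict, and a dict with keys is truthy even when it holds zero detections, so `if not faces:` may never fire and the parser could be handed an empty batch. A safer emptiness check, assuming the detections are stored one box per face under a `'rects'` key, might look like:

```python
import torch

def has_faces(faces: dict) -> bool:
    """Check the number of detected boxes instead of the dict's truthiness.
    Assumes (hypothetically) that boxes live under the 'rects' key."""
    return 'rects' in faces and len(faces['rects']) > 0

# A dict simulating a detector result that contains zero detections.
no_detections = {'rects': torch.empty(0, 4)}
print(bool(no_detections))       # True  -- so `if not faces:` would not trigger
print(has_faces(no_detections))  # False -- the explicit check catches it
```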
