diff --git a/clip/model.py b/clip/model.py index 232b7792e..b70c1e921 100644 --- a/clip/model.py +++ b/clip/model.py @@ -159,7 +159,8 @@ class LayerNorm(nn.LayerNorm): def forward(self, x: torch.Tensor): orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) + ret = F.layer_norm( + x.type(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) return ret.type(orig_type)