diff --git a/clip/model.py b/clip/model.py
index 232b7792e..e388d402d 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -154,15 +154,16 @@ def stem(x):
         return x
 
 
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
+class LayerNorm(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(LayerNorm, self).__init__()
+        self.inner_layernorm = nn.LayerNorm(*args, **kwargs)
 
     def forward(self, x: torch.Tensor):
         orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
+        ret = self.inner_layernorm(x.type(torch.float32))
         return ret.type(orig_type)
 
-
 class QuickGELU(nn.Module):
     def forward(self, x: torch.Tensor):
         return x * torch.sigmoid(1.702 * x)
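
For context, here is a minimal standalone sketch of the behaviour the patched class implements: the wrapped `nn.LayerNorm` runs in float32, and the result is cast back to the caller's dtype (e.g. fp16). The `inner_layernorm` attribute name mirrors the diff; the `WrappedLayerNorm` class name and the toy tensor below are illustrative only, not part of the PR.

```python
import torch
import torch.nn as nn


class WrappedLayerNorm(nn.Module):
    """Composition-based LayerNorm: normalize in fp32, return the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Hold a regular nn.LayerNorm instead of subclassing it, as in the diff.
        self.inner_layernorm = nn.LayerNorm(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # Compute the normalization in float32 for numerical stability ...
        ret = self.inner_layernorm(x.type(torch.float32))
        # ... then cast back so callers see their original dtype (e.g. fp16).
        return ret.type(orig_type)


if __name__ == "__main__":
    ln = WrappedLayerNorm(8)
    x = torch.randn(2, 8, dtype=torch.float16)
    y = ln(x)
    assert y.dtype == torch.float16  # output keeps the input's half precision
```

Note that because the parameters now live under `inner_layernorm.weight` / `inner_layernorm.bias` rather than directly on the module, wrapping (as opposed to subclassing) changes the state_dict key layout relative to the original `LayerNorm(nn.LayerNorm)`.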