diff --git a/convert_torch.py b/convert_torch.py index b639f73..894bffa 100644 --- a/convert_torch.py +++ b/convert_torch.py @@ -41,7 +41,9 @@ def forward(self, input): def copy_param(m,n): - if m.weight is not None: n.weight.data.copy_(m.weight) + if m.weight is not None: + m.weight.data = m.weight.view(n.weight.size()) + n.weight.data.copy_(m.weight) if m.bias is not None: n.bias.data.copy_(m.bias) if hasattr(n,'running_mean'): n.running_mean.copy_(m.running_mean) if hasattr(n,'running_var'): n.running_var.copy_(m.running_var)