forked from TencentARC/GFPGAN
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconvert_gfpganv_to_clean.py
164 lines (155 loc) · 6.8 KB
/
convert_gfpganv_to_clean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import argparse
import math
import torch
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
    """Convert a GFPGANv1 (bilinear) state dict into the GFPGANv1Clean layout.

    Each entry of ``checkpoint_bilinear`` whose key matches one of the patterns
    handled below is renamed where needed and written into ``checkpoint_clean``.
    Renaming happens where the clean layout uses a different key. The written
    tensor is multiplied by the constant factors hard-coded in each branch:
    ``1 / sqrt(fan_in)``, the style-MLP ``lr_mul`` and ``2**0.5``. The comment
    on the ``activate`` branch notes ``2**0.5`` as the FusedLeakyReLU gain.

    Args:
        checkpoint_bilinear (dict[str, torch.Tensor]): Source state dict, e.g.
            the ``params_ema`` entry of a GFPGANv1 checkpoint. It is not
            modified. Some output entries may share storage with its tensors.
        checkpoint_clean (dict[str, torch.Tensor]): Destination state dict,
            e.g. ``GFPGANv1Clean(...).state_dict()``. It is updated in place.

    Returns:
        dict[str, torch.Tensor]: ``checkpoint_clean`` (the same object).

    Raises:
        ValueError: If a ``stylegan_decoder`` activation key does not have
            4 or 5 dot-separated parts.

    Source keys that match none of the handled patterns are not written, so
    the corresponding entries of ``checkpoint_clean`` keep their values.
    """
    for ori_k, ori_v in checkpoint_bilinear.items():
        if 'stylegan_decoder' in ori_k:
            if 'style_mlp' in ori_k:  # style_mlp_layers
                lr_mul = 0.01
                prefix, name, idx, var = ori_k.split('.')
                # Linear layer i of the source maps to index 2*i-1 of the target
                # (presumably the clean MLP interleaves activation modules).
                idx = (int(idx) * 2) - 1
                crt_k = f'{prefix}.{name}.{idx}.{var}'
                if var == 'weight':
                    _, c_in = ori_v.size()
                    scale = (1 / math.sqrt(c_in)) * lr_mul
                    crt_v = ori_v * scale * 2**0.5
                else:
                    crt_v = ori_v * lr_mul * 2**0.5
                checkpoint_clean[crt_k] = crt_v
            elif 'modulation' in ori_k:  # modulation in StyleConv
                lr_mul = 1
                crt_k = ori_k
                var = ori_k.split('.')[-1]
                if var == 'weight':
                    _, c_in = ori_v.size()
                    scale = (1 / math.sqrt(c_in)) * lr_mul
                    crt_v = ori_v * scale
                else:
                    crt_v = ori_v * lr_mul
                checkpoint_clean[crt_k] = crt_v
            elif 'style_conv' in ori_k:
                # StyleConv in style_conv1 and style_convs
                if 'activate' in ori_k:  # FusedLeakyReLU
                    # eg. style_conv1.activate.bias
                    # eg. style_convs.13.activate.bias
                    split_rlt = ori_k.split('.')
                    if len(split_rlt) == 4:
                        prefix, name, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{var}'
                    elif len(split_rlt) == 5:
                        prefix, name, idx, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{idx}.{var}'
                    else:
                        # Without this, crt_k would be unbound or, worse, still
                        # hold the previous iteration's key and overwrite it.
                        raise ValueError(f'Unexpected activation key format: {ori_k}')
                    crt_v = ori_v * 2**0.5  # 2**0.5 used in FusedLeakyReLU
                    c = crt_v.size(0)
                    # The bias is stored with shape (1, C, 1, 1) in the clean layout.
                    checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
                elif 'modulated_conv' in ori_k:
                    # eg. style_conv1.modulated_conv.weight
                    # eg. style_convs.13.modulated_conv.weight
                    _, c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * scale
                elif 'weight' in ori_k:
                    # Presumably the noise-injection strength; the key is copied
                    # unchanged. NOTE(review): confirm it matches the clean arch.
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif 'to_rgb' in ori_k:  # StyleConv in to_rgb1 and to_rgbs
                if 'modulated_conv' in ori_k:
                    # eg. to_rgb1.modulated_conv.weight
                    # eg. to_rgbs.5.modulated_conv.weight
                    _, c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * scale
                else:
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v
            else:
                crt_k = ori_k
                checkpoint_clean[crt_k] = ori_v
            # end of 'stylegan_decoder'
        elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
            # key name: drop the middle (module index) component
            name, _, var = ori_k.split('.')
            crt_k = f'{name}.{var}'
            # weight and bias
            if var == 'weight':
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
            else:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
        elif 'conv_body' in ori_k:
            if 'conv_body_up' in ori_k:
                # Normalize up-block keys to the same 5-part form as down-block keys.
                ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
                ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
            name1, idx1, name2, _, var = ori_k.split('.')
            crt_k = f'{name1}.{idx1}.{name2}.{var}'
            if name2 == 'skip':
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale / 2**0.5
            else:
                if var == 'weight':
                    c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    crt_v = ori_v * scale
                else:
                    crt_v = ori_v
                if 'conv1' in ori_k:
                    # Out-of-place multiply: crt_v may be the source tensor itself,
                    # and an in-place `*=` would corrupt checkpoint_bilinear.
                    crt_v = crt_v * 2**0.5
                checkpoint_clean[crt_k] = crt_v
        elif 'toRGB' in ori_k:
            crt_k = ori_k
            if 'weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'final_linear' in ori_k:
            crt_k = ori_k
            if 'weight' in ori_k:
                _, c_in = ori_v.size()
                scale = 1 / math.sqrt(c_in)
                checkpoint_clean[crt_k] = ori_v * scale
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'condition' in ori_k:
            crt_k = ori_k
            if '0.weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
            elif '0.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif '2.weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale
            elif '2.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v

    return checkpoint_clean
if __name__ == '__main__':
    # Command-line entry point: load a GFPGANv1 checkpoint, remap its
    # 'params_ema' weights onto a freshly built GFPGANv1Clean state dict, and
    # save the result under the same 'params_ema' key.
    parser = argparse.ArgumentParser(
        description='Convert a GFPGANv1 (bilinear) checkpoint to the GFPGANv1Clean format.')
    # Both paths are mandatory. Previously a missing --save_path only failed
    # after the whole conversion had already run.
    parser.add_argument('--ori_path', type=str, required=True, help='Path to the original model')
    parser.add_argument('--narrow', type=float, default=1, help='Channel width factor passed to GFPGANv1Clean')
    parser.add_argument(
        '--channel_multiplier', type=float, default=2, help='Decoder channel multiplier passed to GFPGANv1Clean')
    parser.add_argument('--save_path', type=str, required=True, help='Path to write the converted model')
    args = parser.parse_args()

    # map_location='cpu' lets checkpoints saved from GPU be converted on
    # CPU-only machines; the target network below is also built on CPU.
    ori_ckpt = torch.load(args.ori_path, map_location='cpu')['params_ema']
    net = GFPGANv1Clean(
        512,
        num_style_feat=512,
        channel_multiplier=args.channel_multiplier,
        decoder_load_path=None,
        fix_decoder=False,
        # for stylegan decoder
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=args.narrow,
        sft_half=True)
    # Use the clean network's own state dict as the template so that keys not
    # produced by the conversion keep their initialized values.
    crt_ckpt = net.state_dict()
    crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
    print(f'Save to {args.save_path}.')
    # Legacy (non-zipfile) serialization, kept from the original script.
    torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)