|
| 1 | +# encoding: utf-8 |
| 2 | +# ref: https://github.com/CaoWGG/RepVGG/blob/develop/repvgg.py |
| 3 | + |
| 4 | + |
| 5 | +import logging |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import torch |
| 9 | +import torch.nn as nn |
| 10 | + |
| 11 | +from fastreid.layers import * |
| 12 | +from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message |
| 13 | +from .build import BACKBONE_REGISTRY |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +def deploy(self, mode=False): |
| 19 | + self.deploying = mode |
| 20 | + for module in self.children(): |
| 21 | + if hasattr(module, 'deploying'): |
| 22 | + module.deploy(mode) |
| 23 | + |
| 24 | + |
| 25 | +nn.Sequential.deploying = False |
| 26 | +nn.Sequential.deploy = deploy |
| 27 | + |
| 28 | + |
| 29 | +def conv_bn(norm_type, in_channels, out_channels, kernel_size, stride, padding, groups=1): |
| 30 | + result = nn.Sequential() |
| 31 | + result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, |
| 32 | + kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, |
| 33 | + bias=False)) |
| 34 | + result.add_module('bn', get_norm(norm_type, out_channels)) |
| 35 | + return result |
| 36 | + |
| 37 | + |
| 38 | +class RepVGGBlock(nn.Module): |
| 39 | + |
| 40 | + def __init__(self, in_channels, out_channels, norm_type, kernel_size, |
| 41 | + stride=1, padding=0, groups=1): |
| 42 | + super(RepVGGBlock, self).__init__() |
| 43 | + self.deploying = False |
| 44 | + |
| 45 | + self.groups = groups |
| 46 | + self.in_channels = in_channels |
| 47 | + |
| 48 | + assert kernel_size == 3 |
| 49 | + assert padding == 1 |
| 50 | + |
| 51 | + padding_11 = padding - kernel_size // 2 |
| 52 | + |
| 53 | + self.nonlinearity = nn.ReLU() |
| 54 | + |
| 55 | + self.in_channels = in_channels |
| 56 | + self.in_channels = in_channels |
| 57 | + self.kernel_size = kernel_size |
| 58 | + self.stride = stride |
| 59 | + self.padding = padding |
| 60 | + self.groups = groups |
| 61 | + |
| 62 | + self.register_parameter('fused_weight', None) |
| 63 | + self.register_parameter('fused_bias', None) |
| 64 | + |
| 65 | + self.rbr_identity = get_norm(norm_type, in_channels) if out_channels == in_channels and stride == 1 else None |
| 66 | + self.rbr_dense = conv_bn(norm_type, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, |
| 67 | + stride=stride, padding=padding, groups=groups) |
| 68 | + self.rbr_1x1 = conv_bn(norm_type, in_channels=in_channels, out_channels=out_channels, kernel_size=1, |
| 69 | + stride=stride, padding=padding_11, groups=groups) |
| 70 | + |
| 71 | + def forward(self, inputs): |
| 72 | + if self.deploying: |
| 73 | + assert self.fused_weight is not None and self.fused_bias is not None, \ |
| 74 | + "Make deploy mode=True to generate fused weight and fused bias first" |
| 75 | + fused_out = self.nonlinearity(torch.nn.functional.conv2d( |
| 76 | + inputs, self.fused_weight, self.fused_bias, self.stride, self.padding, 1, self.groups)) |
| 77 | + return fused_out |
| 78 | + |
| 79 | + if self.rbr_identity is None: |
| 80 | + id_out = 0 |
| 81 | + else: |
| 82 | + id_out = self.rbr_identity(inputs) |
| 83 | + out = self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out) |
| 84 | + |
| 85 | + return out |
| 86 | + |
| 87 | + def get_equivalent_kernel_bias(self): |
| 88 | + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) |
| 89 | + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) |
| 90 | + kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) |
| 91 | + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid |
| 92 | + |
| 93 | + def _pad_1x1_to_3x3_tensor(self, kernel1x1): |
| 94 | + if kernel1x1 is None: |
| 95 | + return 0 |
| 96 | + else: |
| 97 | + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) |
| 98 | + |
| 99 | + def _fuse_bn_tensor(self, branch): |
| 100 | + if branch is None: |
| 101 | + return 0, 0 |
| 102 | + if isinstance(branch, nn.Sequential): |
| 103 | + kernel = branch.conv.weight |
| 104 | + running_mean = branch.bn.running_mean |
| 105 | + running_var = branch.bn.running_var |
| 106 | + gamma = branch.bn.weight |
| 107 | + beta = branch.bn.bias |
| 108 | + eps = branch.bn.eps |
| 109 | + else: |
| 110 | + assert branch.__class__.__name__.find('BatchNorm') != -1 |
| 111 | + if not hasattr(self, 'id_tensor'): |
| 112 | + input_dim = self.in_channels // self.groups |
| 113 | + kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) |
| 114 | + for i in range(self.in_channels): |
| 115 | + kernel_value[i, i % input_dim, 1, 1] = 1 |
| 116 | + self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) |
| 117 | + kernel = self.id_tensor |
| 118 | + running_mean = branch.running_mean |
| 119 | + running_var = branch.running_var |
| 120 | + gamma = branch.weight |
| 121 | + beta = branch.bias |
| 122 | + eps = branch.eps |
| 123 | + std = (running_var + eps).sqrt() |
| 124 | + t = (gamma / std).reshape(-1, 1, 1, 1) |
| 125 | + return kernel * t, beta - running_mean * gamma / std |
| 126 | + |
| 127 | + def deploy(self, mode=False): |
| 128 | + self.deploying = mode |
| 129 | + if mode: |
| 130 | + fused_weight, fused_bias = self.get_equivalent_kernel_bias() |
| 131 | + self.register_parameter('fused_weight', nn.Parameter(fused_weight)) |
| 132 | + self.register_parameter('fused_bias', nn.Parameter(fused_bias)) |
| 133 | + del self.rbr_identity, self.rbr_1x1, self.rbr_dense |
| 134 | + |
| 135 | + |
| 136 | +class RepVGG(nn.Module): |
| 137 | + |
| 138 | + def __init__(self, last_stride, norm_type, num_blocks, width_multiplier=None, override_groups_map=None): |
| 139 | + super(RepVGG, self).__init__() |
| 140 | + |
| 141 | + assert len(width_multiplier) == 4 |
| 142 | + |
| 143 | + self.deploying = False |
| 144 | + self.override_groups_map = override_groups_map or dict() |
| 145 | + |
| 146 | + assert 0 not in self.override_groups_map |
| 147 | + |
| 148 | + self.in_planes = min(64, int(64 * width_multiplier[0])) |
| 149 | + |
| 150 | + self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, norm_type=norm_type, |
| 151 | + kernel_size=3, stride=2, padding=1) |
| 152 | + self.cur_layer_idx = 1 |
| 153 | + self.stage1 = self._make_stage(int(64 * width_multiplier[0]), norm_type, num_blocks[0], stride=2) |
| 154 | + self.stage2 = self._make_stage(int(128 * width_multiplier[1]), norm_type, num_blocks[1], stride=2) |
| 155 | + self.stage3 = self._make_stage(int(256 * width_multiplier[2]), norm_type, num_blocks[2], stride=2) |
| 156 | + self.stage4 = self._make_stage(int(512 * width_multiplier[3]), norm_type, num_blocks[3], stride=last_stride) |
| 157 | + |
| 158 | + def _make_stage(self, planes, norm_type, num_blocks, stride): |
| 159 | + strides = [stride] + [1] * (num_blocks - 1) |
| 160 | + blocks = [] |
| 161 | + for stride in strides: |
| 162 | + cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1) |
| 163 | + blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, norm_type=norm_type, |
| 164 | + kernel_size=3, stride=stride, padding=1, groups=cur_groups)) |
| 165 | + self.in_planes = planes |
| 166 | + self.cur_layer_idx += 1 |
| 167 | + return nn.Sequential(*blocks) |
| 168 | + |
| 169 | + def deploy(self, mode=False): |
| 170 | + self.deploying = mode |
| 171 | + for module in self.children(): |
| 172 | + if hasattr(module, 'deploying'): |
| 173 | + module.deploy(mode) |
| 174 | + |
| 175 | + def forward(self, x): |
| 176 | + out = self.stage0(x) |
| 177 | + out = self.stage1(out) |
| 178 | + out = self.stage2(out) |
| 179 | + out = self.stage3(out) |
| 180 | + out = self.stage4(out) |
| 181 | + return out |
| 182 | + |
| 183 | + |
| 184 | +optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] |
| 185 | +g2_map = {l: 2 for l in optional_groupwise_layers} |
| 186 | +g4_map = {l: 4 for l in optional_groupwise_layers} |
| 187 | + |
| 188 | + |
| 189 | +def create_RepVGG_A0(last_stride, norm_type): |
| 190 | + return RepVGG(last_stride, norm_type, num_blocks=[2, 4, 14, 1], |
| 191 | + width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None) |
| 192 | + |
| 193 | + |
| 194 | +def create_RepVGG_A1(last_stride, norm_type): |
| 195 | + return RepVGG(last_stride, norm_type, num_blocks=[2, 4, 14, 1], |
| 196 | + width_multiplier=[1, 1, 1, 2.5], override_groups_map=None) |
| 197 | + |
| 198 | + |
| 199 | +def create_RepVGG_A2(last_stride, norm_type): |
| 200 | + return RepVGG(last_stride, norm_type, num_blocks=[2, 4, 14, 1], |
| 201 | + width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None) |
| 202 | + |
| 203 | + |
| 204 | +def create_RepVGG_B0(last_stride, norm_type): |
| 205 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 206 | + width_multiplier=[1, 1, 1, 2.5], override_groups_map=None) |
| 207 | + |
| 208 | + |
| 209 | +def create_RepVGG_B1(last_stride, norm_type): |
| 210 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 211 | + width_multiplier=[2, 2, 2, 4], override_groups_map=None) |
| 212 | + |
| 213 | + |
| 214 | +def create_RepVGG_B1g2(last_stride, norm_type): |
| 215 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 216 | + width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map) |
| 217 | + |
| 218 | + |
| 219 | +def create_RepVGG_B1g4(last_stride, norm_type): |
| 220 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 221 | + width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map) |
| 222 | + |
| 223 | + |
| 224 | +def create_RepVGG_B2(last_stride, norm_type): |
| 225 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 226 | + width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None) |
| 227 | + |
| 228 | + |
| 229 | +def create_RepVGG_B2g2(last_stride, norm_type): |
| 230 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 231 | + width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map) |
| 232 | + |
| 233 | + |
| 234 | +def create_RepVGG_B2g4(last_stride, norm_type): |
| 235 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 236 | + width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map) |
| 237 | + |
| 238 | + |
| 239 | +def create_RepVGG_B3(last_stride, norm_type): |
| 240 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 241 | + width_multiplier=[3, 3, 3, 5], override_groups_map=None) |
| 242 | + |
| 243 | + |
| 244 | +def create_RepVGG_B3g2(last_stride, norm_type): |
| 245 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 246 | + width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map) |
| 247 | + |
| 248 | + |
| 249 | +def create_RepVGG_B3g4(last_stride, norm_type): |
| 250 | + return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1], |
| 251 | + width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map) |
| 252 | + |
| 253 | + |
| 254 | +@BACKBONE_REGISTRY.register() |
| 255 | +def build_repvgg_backbone(cfg): |
| 256 | + """ |
| 257 | + Create a RepVGG instance from config. |
| 258 | + Returns: |
| 259 | + RepVGG: a :class: `RepVGG` instance. |
| 260 | + """ |
| 261 | + |
| 262 | + # fmt: off |
| 263 | + pretrain = cfg.MODEL.BACKBONE.PRETRAIN |
| 264 | + pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH |
| 265 | + last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE |
| 266 | + bn_norm = cfg.MODEL.BACKBONE.NORM |
| 267 | + depth = cfg.MODEL.BACKBONE.DEPTH |
| 268 | + # fmt: on |
| 269 | + |
| 270 | + func_dict = { |
| 271 | + 'A0': create_RepVGG_A0, |
| 272 | + 'A1': create_RepVGG_A1, |
| 273 | + 'A2': create_RepVGG_A2, |
| 274 | + 'B0': create_RepVGG_B0, |
| 275 | + 'B1': create_RepVGG_B1, |
| 276 | + 'B1g2': create_RepVGG_B1g2, |
| 277 | + 'B1g4': create_RepVGG_B1g4, |
| 278 | + 'B2': create_RepVGG_B2, |
| 279 | + 'B2g2': create_RepVGG_B2g2, |
| 280 | + 'B2g4': create_RepVGG_B2g4, |
| 281 | + 'B3': create_RepVGG_B3, |
| 282 | + 'B3g2': create_RepVGG_B3g2, |
| 283 | + 'B3g4': create_RepVGG_B3g4, |
| 284 | + } |
| 285 | + |
| 286 | + model = func_dict[depth](last_stride, bn_norm) |
| 287 | + |
| 288 | + if pretrain: |
| 289 | + try: |
| 290 | + state_dict = torch.load(pretrain_path, map_location=torch.device("cpu")) |
| 291 | + logger.info(f"Loading pretrained model from {pretrain_path}") |
| 292 | + except FileNotFoundError as e: |
| 293 | + logger.info(f'{pretrain_path} is not found! Please check this path.') |
| 294 | + raise e |
| 295 | + except KeyError as e: |
| 296 | + logger.info("State dict keys error! Please check the state dict.") |
| 297 | + raise e |
| 298 | + |
| 299 | + incompatible = model.load_state_dict(state_dict, strict=False) |
| 300 | + if incompatible.missing_keys: |
| 301 | + logger.info( |
| 302 | + get_missing_parameters_message(incompatible.missing_keys) |
| 303 | + ) |
| 304 | + if incompatible.unexpected_keys: |
| 305 | + logger.info( |
| 306 | + get_unexpected_parameters_message(incompatible.unexpected_keys) |
| 307 | + ) |
| 308 | + |
| 309 | + return model |
0 commit comments