Commit
Showing 11 changed files with 821 additions and 5 deletions.
@@ -1,2 +1,3 @@
from .rest import *
from .autoformer2 import *
from .pyramid_vig import *
@@ -0,0 +1,5 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
from .torch_nn import *
from .torch_edge import *
from .torch_vertex import *
@@ -0,0 +1,84 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
# modified from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np


# --------------------------------------------------------
# relative position embedding
# References: https://arxiv.org/abs/2009.13658
# --------------------------------------------------------
def get_2d_relative_pos_embed(embed_dim, grid_size):
    """
    grid_size: int of the grid height and width
    return:
    relative_pos: [grid_size*grid_size, grid_size*grid_size]
    """
    pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
    relative_pos = 2 * np.matmul(pos_embed, pos_embed.transpose()) / pos_embed.shape[1]
    return relative_pos


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] (without cls_token) or [1+grid_size*grid_size, embed_dim] (with cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # np.float was removed in NumPy 1.24; plain float keeps the original float64 behaviour
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
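
As a quick sanity check of the helpers above (a hypothetical usage sketch, not part of this commit), note that the relative-position matrix is just a scaled Gram matrix of the 2D sine-cosine embedding, so it is square in the number of grid cells; embed_dim and grid_size below are arbitrary example values:

# Hypothetical usage sketch (not from the repository); example values only.
embed_dim, grid_size = 192, 14  # e.g. a 14x14 grid of patches
pos = get_2d_sincos_pos_embed(embed_dim, grid_size)    # (196, 192)
rel = get_2d_relative_pos_embed(embed_dim, grid_size)  # (196, 196)

assert pos.shape == (grid_size * grid_size, embed_dim)
assert rel.shape == (grid_size * grid_size, grid_size * grid_size)
# rel[i, j] = 2 * <pos_i, pos_j> / embed_dim, a similarity score that callers
# can add as a bias to pairwise feature distances.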
@@ -0,0 +1,159 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
import math
import torch
from torch import nn
import torch.nn.functional as F


def pairwise_distance(x):
    """
    Compute pairwise distance of a point cloud.
    Args:
        x: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    with torch.no_grad():
        x_inner = -2 * torch.matmul(x, x.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        return x_square + x_inner + x_square.transpose(2, 1)


def part_pairwise_distance(x, start_idx=0, end_idx=1):
    """
    Compute pairwise distance between a slice of the points and the full point cloud.
    Args:
        x: tensor (batch_size, num_points, num_dims)
        start_idx, end_idx: slice of query points
    Returns:
        pairwise distance: (batch_size, end_idx - start_idx, num_points)
    """
    with torch.no_grad():
        x_part = x[:, start_idx:end_idx]
        x_square_part = torch.sum(torch.mul(x_part, x_part), dim=-1, keepdim=True)
        x_inner = -2 * torch.matmul(x_part, x.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        return x_square_part + x_inner + x_square.transpose(2, 1)


def xy_pairwise_distance(x, y):
    """
    Compute pairwise distance between two point clouds.
    Args:
        x: tensor (batch_size, num_points, num_dims)
        y: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    with torch.no_grad():
        xy_inner = -2 * torch.matmul(x, y.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
        return x_square + xy_inner + y_square.transpose(2, 1)


def dense_knn_matrix(x, k=16, relative_pos=None):
    """Get KNN based on the pairwise distance.
    Args:
        x: (batch_size, num_dims, num_points, 1)
        k: int
    Returns:
        nearest neighbors: stacked (nn_idx, center_idx), each (batch_size, num_points, k)
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        ### memory efficient implementation ###
        n_part = 10000
        if n_points > n_part:
            nn_idx_list = []
            groups = math.ceil(n_points / n_part)
            for i in range(groups):
                start_idx = n_part * i
                end_idx = min(n_points, n_part * (i + 1))
                dist = part_pairwise_distance(x.detach(), start_idx, end_idx)
                if relative_pos is not None:
                    dist += relative_pos[:, start_idx:end_idx]
                _, nn_idx_part = torch.topk(-dist, k=k)
                nn_idx_list += [nn_idx_part]
            nn_idx = torch.cat(nn_idx_list, dim=1)
        else:
            dist = pairwise_distance(x.detach())
            if relative_pos is not None:
                dist += relative_pos
            _, nn_idx = torch.topk(-dist, k=k)  # b, n, k
        ######
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)


def xy_dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Get KNN based on the pairwise distance between two sets of features.
    Args:
        x: (batch_size, num_dims, num_points, 1)
        y: (batch_size, num_dims, num_points, 1)
        k: int
    Returns:
        nearest neighbors: stacked (nn_idx, center_idx), each (batch_size, num_points, k)
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        y = y.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        dist = xy_pairwise_distance(x.detach(), y.detach())
        if relative_pos is not None:
            dist += relative_pos
        _, nn_idx = torch.topk(-dist, k=k)
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)


class DenseDilated(nn.Module):
    """
    Find dilated neighbor from neighbor list
    edge_index: (2, batch_size, num_points, k)
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilated, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index):
        if self.stochastic:
            if torch.rand(1) < self.epsilon and self.training:
                num = self.k * self.dilation
                randnum = torch.randperm(num)[:self.k]
                edge_index = edge_index[:, :, :, randnum]
            else:
                edge_index = edge_index[:, :, :, ::self.dilation]
        else:
            edge_index = edge_index[:, :, :, ::self.dilation]
        return edge_index


class DenseDilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        if y is not None:
            #### normalize
            x = F.normalize(x, p=2.0, dim=1)
            y = F.normalize(y, p=2.0, dim=1)
            ####
            edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
        else:
            #### normalize
            x = F.normalize(x, p=2.0, dim=1)
            ####
            edge_index = dense_knn_matrix(x, self.k * self.dilation, relative_pos)
        return self._dilated(edge_index)
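
As a shape-level smoke test (hypothetical, not part of this commit), DenseDilatedKnnGraph can be driven with a dummy feature map laid out as (batch_size, num_dims, num_points, 1), matching the docstrings above:

# Hypothetical usage sketch (not from the repository); dummy shapes only.
import torch

B, C, N, k = 2, 64, 196, 9
x = torch.randn(B, C, N, 1)  # vertex features: (batch, channels, points, 1)

graph = DenseDilatedKnnGraph(k=k, dilation=2, stochastic=False, epsilon=0.0)
edge_index = graph(x)  # stacked (nn_idx, center_idx) after dilation

# (2, B, N, k): edge_index[0] holds neighbor indices, edge_index[1] the centers.
assert edge_index.shape == (2, B, N, k)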
@@ -0,0 +1,102 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d


##############################
# Basic layers
##############################
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    # activation layer
    act = act.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act == 'gelu':
        layer = nn.GELU()
    elif act == 'hswish':
        layer = nn.Hardswish(inplace)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer


def norm_layer(norm, nc):
    # normalization layer 2d
    norm = norm.lower()
    if norm == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MLP(Seq):
    def __init__(self, channels, act='relu', norm=None, bias=True):
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if norm is not None and norm.lower() != 'none':
                m.append(norm_layer(norm, channels[-1]))
        super(MLP, self).__init__(*m)


class BasicConv(Seq):
    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
        m = []
        for i in range(1, len(channels)):
            m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias, groups=4))
            if norm is not None and norm.lower() != 'none':
                m.append(norm_layer(norm, channels[-1]))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if drop > 0:
                m.append(nn.Dropout2d(drop))

        super(BasicConv, self).__init__(*m)

        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def batched_index_select(x, idx):
    r"""fetches neighbors features from a given neighbor idx
    Args:
        x (Tensor): input feature Tensor
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
        idx (Tensor): edge_idx
            :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`.
    Returns:
        Tensor: output neighbors features
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`.
    """
    batch_size, num_dims, num_vertices_reduced = x.shape[:3]
    _, num_vertices, k = idx.shape
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
    idx = idx + idx_base
    idx = idx.contiguous().view(-1)

    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :]
    feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
    return feature
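
A hypothetical usage sketch (not part of this commit) for batched_index_select: given an index tensor shaped like the nn_idx half of the edge_index produced by the KNN helpers, it gathers the neighbor features for every vertex:

# Hypothetical usage sketch (not from the repository); random indices stand in
# for edge_index[0] from the KNN graph construction.
import torch

B, C, N, k = 2, 32, 196, 9
x = torch.randn(B, C, N, 1)            # vertex features
idx = torch.randint(0, N, (B, N, k))   # neighbor indices per vertex

neigh = batched_index_select(x, idx)
assert neigh.shape == (B, C, N, k)     # k neighbor feature vectors per vertex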