Commit: Add GNN Vision
chakkritte committed May 17, 2023
1 parent a718a45 commit 1aa9276
Showing 11 changed files with 821 additions and 5 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -10,6 +10,9 @@ Darts family models pre-trained

## What's New

May 17, 2023
- Add Vision GNN ImageNet-1k models, thanks to [Vision GNN](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch)

Feb 2, 2022
- Add AutoFormerV2 ImageNet-1k models, thanks to [AutoFormerV2](https://github.com/silent-chen/AutoFormerV2-model-zoo)

@@ -55,6 +58,7 @@ Oct 27, 2020
- [rest_lite, rest_small, rest_base, rest_large](https://github.com/Alibaba-MIIL/TResNet)
- [pnas5](https://github.com/samyak0210/saliency/)
- [autoformerv2_tiny, autoformerv2_small, autoformerv2_base](https://github.com/silent-chen/AutoFormerV2-model-zoo)
- [pvig_ti_224_gelu, pvig_s_224_gelu, pvig_m_224_gelu, pvig_b_224_gelu](https://arxiv.org/abs/2206.00272)
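
The new ViG variants load like the other models in the zoo. A minimal sketch, assuming `darmo` exposes a timm-style `create_model` factory (the function name here is an assumption; see the Install section for the actual entry point):

```python
import torch
import darmo

# Hypothetical factory name: adjust to darmo's actual API.
model = darmo.create_model("pvig_s_224_gelu", pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000]) for ImageNet-1k
```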

## Install

3 changes: 2 additions & 1 deletion darmo/models/__init__.py
@@ -1,2 +1,3 @@
from .rest import *
from .autoformer2 import *
from .pyramid_vig import *
5 changes: 5 additions & 0 deletions darmo/models/gcn_lib/__init__.py
@@ -0,0 +1,5 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
from .torch_nn import *
from .torch_edge import *
from .torch_vertex import *
84 changes: 84 additions & 0 deletions darmo/models/gcn_lib/pos_embed.py
@@ -0,0 +1,84 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
# modified from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np


# --------------------------------------------------------
# relative position embedding
# References: https://arxiv.org/abs/2009.13658
# --------------------------------------------------------
def get_2d_relative_pos_embed(embed_dim, grid_size):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, grid_size*grid_size]
    """
    pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
    relative_pos = 2 * np.matmul(pos_embed, pos_embed.transpose()) / pos_embed.shape[1]
    return relative_pos


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy >= 1.24
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
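
For reference, a quick shape check of the helpers above. A minimal sketch using an illustrative 14x14 grid and 192-dim embedding (these values are assumptions, not fixed by the code):

```python
embed_dim, grid_size = 192, 14  # illustrative values

pe = get_2d_sincos_pos_embed(embed_dim, grid_size)
print(pe.shape)   # (196, 192): one sin-cos vector per grid cell

rel = get_2d_relative_pos_embed(embed_dim, grid_size)
print(rel.shape)  # (196, 196): scaled pairwise similarity of position embeddings
```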
159 changes: 159 additions & 0 deletions darmo/models/gcn_lib/torch_edge.py
@@ -0,0 +1,159 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
import math
import torch
from torch import nn
import torch.nn.functional as F


def pairwise_distance(x):
    """
    Compute pairwise distance of a point cloud.
    Args:
        x: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    with torch.no_grad():
        x_inner = -2*torch.matmul(x, x.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        return x_square + x_inner + x_square.transpose(2, 1)


def part_pairwise_distance(x, start_idx=0, end_idx=1):
    """
    Compute pairwise distance between a slice of a point cloud and the full cloud.
    Args:
        x: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, end_idx - start_idx, num_points)
    """
    with torch.no_grad():
        x_part = x[:, start_idx:end_idx]
        x_square_part = torch.sum(torch.mul(x_part, x_part), dim=-1, keepdim=True)
        x_inner = -2*torch.matmul(x_part, x.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        return x_square_part + x_inner + x_square.transpose(2, 1)


def xy_pairwise_distance(x, y):
    """
    Compute pairwise distance between two point clouds.
    Args:
        x: tensor (batch_size, num_points, num_dims)
        y: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    with torch.no_grad():
        xy_inner = -2*torch.matmul(x, y.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
        return x_square + xy_inner + y_square.transpose(2, 1)


def dense_knn_matrix(x, k=16, relative_pos=None):
    """Get KNN based on the pairwise distance.
    Args:
        x: (batch_size, num_dims, num_points, 1)
        k: int
    Returns:
        edge index: (2, batch_size, num_points, k), neighbor and center indices stacked on dim 0
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        ### memory efficient implementation ###
        n_part = 10000
        if n_points > n_part:
            nn_idx_list = []
            groups = math.ceil(n_points / n_part)
            for i in range(groups):
                start_idx = n_part * i
                end_idx = min(n_points, n_part * (i + 1))
                dist = part_pairwise_distance(x.detach(), start_idx, end_idx)
                if relative_pos is not None:
                    dist += relative_pos[:, start_idx:end_idx]
                _, nn_idx_part = torch.topk(-dist, k=k)
                nn_idx_list += [nn_idx_part]
            nn_idx = torch.cat(nn_idx_list, dim=1)
        else:
            dist = pairwise_distance(x.detach())
            if relative_pos is not None:
                dist += relative_pos
            _, nn_idx = torch.topk(-dist, k=k)  # b, n, k
        ######
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)


def xy_dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Get KNN based on the pairwise distance.
    Args:
        x: (batch_size, num_dims, num_points, 1)
        y: (batch_size, num_dims, num_points, 1)
        k: int
    Returns:
        edge index: (2, batch_size, num_points, k), neighbor and center indices stacked on dim 0
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        y = y.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        dist = xy_pairwise_distance(x.detach(), y.detach())
        if relative_pos is not None:
            dist += relative_pos
        _, nn_idx = torch.topk(-dist, k=k)
        center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)


class DenseDilated(nn.Module):
    """
    Find dilated neighbor from neighbor list

    edge_index: (2, batch_size, num_points, k)
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilated, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index):
        if self.stochastic:
            if torch.rand(1) < self.epsilon and self.training:
                num = self.k * self.dilation
                randnum = torch.randperm(num)[:self.k]
                edge_index = edge_index[:, :, :, randnum]
            else:
                edge_index = edge_index[:, :, :, ::self.dilation]
        else:
            edge_index = edge_index[:, :, :, ::self.dilation]
        return edge_index


class DenseDilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super(DenseDilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        if y is not None:
            #### normalize
            x = F.normalize(x, p=2.0, dim=1)
            y = F.normalize(y, p=2.0, dim=1)
            ####
            edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
        else:
            #### normalize
            x = F.normalize(x, p=2.0, dim=1)
            ####
            edge_index = dense_knn_matrix(x, self.k * self.dilation, relative_pos)
        return self._dilated(edge_index)
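
To make the shapes concrete, a small smoke test of the graph builder. A minimal sketch with illustrative sizes (batch 2, 192 channels, a 14x14 grid of 196 nodes):

```python
import torch

x = torch.randn(2, 192, 196, 1)  # (B, C, N, 1) node features

graph = DenseDilatedKnnGraph(k=9, dilation=1)
edge_index = graph(x)
print(edge_index.shape)  # torch.Size([2, 2, 196, 9]): neighbor and center indices stacked on dim 0
```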
102 changes: 102 additions & 0 deletions darmo/models/gcn_lib/torch_nn.py
@@ -0,0 +1,102 @@
# 2022.06.17-Changed for building ViG model
# Huawei Technologies Co., Ltd. <[email protected]>
import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d


##############################
# Basic layers
##############################
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    # activation layer
    act = act.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act == 'gelu':
        layer = nn.GELU()
    elif act == 'hswish':
        layer = nn.Hardswish(inplace)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer


def norm_layer(norm, nc):
    # normalization layer 2d
    norm = norm.lower()
    if norm == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return layer


class MLP(Seq):
    def __init__(self, channels, act='relu', norm=None, bias=True):
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if norm is not None and norm.lower() != 'none':
                m.append(norm_layer(norm, channels[-1]))
        super(MLP, self).__init__(*m)


class BasicConv(Seq):
    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
        m = []
        for i in range(1, len(channels)):
            m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias, groups=4))
            if norm is not None and norm.lower() != 'none':
                m.append(norm_layer(norm, channels[-1]))
            if act is not None and act.lower() != 'none':
                m.append(act_layer(act))
            if drop > 0:
                m.append(nn.Dropout2d(drop))

        super(BasicConv, self).__init__(*m)

        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def batched_index_select(x, idx):
    r"""fetches neighbors features from a given neighbor idx

    Args:
        x (Tensor): input feature Tensor
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
        idx (Tensor): edge_idx
            :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times k}`.
    Returns:
        Tensor: output neighbors features
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`.
    """
    batch_size, num_dims, num_vertices_reduced = x.shape[:3]
    _, num_vertices, k = idx.shape
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
    idx = idx + idx_base
    idx = idx.contiguous().view(-1)

    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :]
    feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
    return feature
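
And a shape check for the gather helper, reusing the edge index produced by `DenseDilatedKnnGraph`. A minimal sketch with the same illustrative sizes as above:

```python
import torch

x = torch.randn(2, 192, 196, 1)                     # (B, C, N, 1) node features
nn_idx, center_idx = DenseDilatedKnnGraph(k=9)(x)   # each (B, N, k)

neighbors = batched_index_select(x, nn_idx)
print(neighbors.shape)  # torch.Size([2, 192, 196, 9]): k neighbor features per node
```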
(5 more changed files not shown)