forked from ilyassmoummad/ProtoCLR
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
36 lines (28 loc) · 942 Bytes
/
utils.py
File metadata and controls
36 lines (28 loc) · 942 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import torch.nn.functional as F
class Normalization(torch.nn.Module):
    """Min-max normalize a tensor to the range [0, 1].

    Uses the global min/max over the whole tensor (all dims), not per-sample
    statistics. A constant input (max == min) is mapped to all zeros instead
    of dividing by zero and producing NaNs.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Hoist the reductions: the original computed x.min() twice.
        x_min = x.min()
        x_max = x.max()
        denom = x_max - x_min
        if denom == 0:
            # Degenerate (constant) input — original code returned NaNs here.
            return torch.zeros_like(x)
        return (x - x_min) / denom
class Standardization(torch.nn.Module):
    """Affine rescaling: subtract a fixed mean, then divide by a fixed std.

    The statistics are supplied at construction time and stored as plain
    attributes — they are not learnable parameters and are never updated.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Fixed statistics, e.g. precomputed over a training set.
        self.mean = mean
        self.std = std

    def forward(self, x):
        centered = x - self.mean
        return centered / self.std
# Output embedding dimension of each supported backbone encoder.
# (Moved above Projector, which reads it in __init__ — previously it was
# defined after the class and worked only because the lookup is deferred.)
model_dim = {
    'cvt': 384,
}


class Projector(nn.Module):
    """Two-layer MLP projection head for contrastive training.

    Maps backbone features of size ``model_dim[model_name]`` to an
    ``out_dim``-dimensional embedding via Linear -> ReLU -> Linear.

    Args:
        model_name: key into ``model_dim`` selecting the input width
            (default ``'cvt'``).
        out_dim: dimensionality of the projected output (default 128).

    Raises:
        KeyError: if ``model_name`` is not present in ``model_dim``.
    """

    def __init__(self, model_name='cvt', out_dim=128):
        super(Projector, self).__init__()
        dim = model_dim[model_name]
        # NOTE: the original had a dead `out_dim = out_dim` self-assignment,
        # removed here.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(dim, out_dim),
        )

    def forward(self, features):
        # features: (..., dim) -> (..., out_dim)
        return self.mlp(features)