Skip to content

Commit 997f580

Browse files
committed
upload model
1 parent b9c4c2a commit 997f580

File tree

2 files changed

+207
-0
lines changed

2 files changed

+207
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
.DS_Store
2+
__pycache__
3+
.ipynb_checkpoints/

model/net.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import numpy as np
5+
from util.mesh import laplacian_smooth
6+
7+
8+
# segmentation U-Net
9+
class Unet(nn.Module):
10+
def __init__(self, c_in=1, c_out=2):
11+
super(Unet, self).__init__()
12+
13+
self.conv1 = nn.Conv3d(in_channels=c_in, out_channels=16, kernel_size=3,
14+
stride=1, padding=1)
15+
self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3,
16+
stride=2, padding=1)
17+
self.conv3 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3,
18+
stride=2, padding=1)
19+
self.conv4 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3,
20+
stride=2, padding=1)
21+
self.conv5 = nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3,
22+
stride=2, padding=1)
23+
24+
self.deconv4 = nn.Conv3d(in_channels=256, out_channels=64, kernel_size=3,
25+
stride=1, padding=1)
26+
self.deconv3 = nn.Conv3d(in_channels=128, out_channels=32, kernel_size=3,
27+
stride=1, padding=1)
28+
self.deconv2 = nn.Conv3d(in_channels=64, out_channels=16, kernel_size=3,
29+
stride=1, padding=1)
30+
self.deconv1 = nn.Conv3d(in_channels=32, out_channels=16, kernel_size=3,
31+
stride=1, padding=1)
32+
33+
self.lastconv1 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3,
34+
stride=1, padding=1)
35+
self.lastconv2 = nn.Conv3d(in_channels=16, out_channels=c_out, kernel_size=3,
36+
stride=1, padding=1)
37+
self.up = nn.Upsample(scale_factor=2, mode='trilinear')
38+
39+
def forward(self, x):
40+
41+
x1 = F.leaky_relu(self.conv1(x), 0.2)
42+
x2 = F.leaky_relu(self.conv2(x1), 0.2)
43+
x3 = F.leaky_relu(self.conv3(x2), 0.2)
44+
x4 = F.leaky_relu(self.conv4(x3), 0.2)
45+
x = F.leaky_relu(self.conv5(x4), 0.2)
46+
x = self.up(x)
47+
48+
x = torch.cat([x, x4], dim=1)
49+
x = F.leaky_relu(self.deconv4(x), 0.2)
50+
x = self.up(x)
51+
52+
x = torch.cat([x, x3], dim=1)
53+
x = F.leaky_relu(self.deconv3(x), 0.2)
54+
x = self.up(x)
55+
56+
x = torch.cat([x, x2], dim=1)
57+
x = F.leaky_relu(self.deconv2(x), 0.2)
58+
x = self.up(x)
59+
60+
x = torch.cat([x, x1], dim=1)
61+
x = F.leaky_relu(self.deconv1(x), 0.2)
62+
63+
x = F.leaky_relu(self.lastconv1(x), 0.2)
64+
x = self.lastconv2(x)
65+
66+
return x
67+
68+
69+
class CortexODE(nn.Module):
70+
def __init__(self, dim_in=3,
71+
dim_h=128,
72+
kernel_size=5,
73+
n_scale=3):
74+
75+
super(CortexODE, self).__init__()
76+
"""
77+
dim_in (3): input dimension
78+
dim_h (C): hidden dimension
79+
kernel_size (K): size of convolutional kernels
80+
n_scale (Q): number of scales of the multi-scale input
81+
"""
82+
83+
C = dim_h # hidden dimension
84+
K = kernel_size # kernel size
85+
Q = n_scale # number of scales
86+
87+
self.C = C
88+
self.K = K
89+
self.Q = Q
90+
91+
# FC layers
92+
self.fc1 = nn.Linear(dim_in, C)
93+
self.fc2 = nn.Linear(C*2, C*4)
94+
self.fc3 = nn.Linear(C*4, C*2)
95+
self.fc4 = nn.Linear(C*2, dim_in)
96+
97+
# local convolution
98+
self.localconv = nn.Conv3d(Q, C, (K, K, K))
99+
self.localfc = nn.Linear(C, C)
100+
101+
# for cube sampling
102+
self.initialized = False
103+
104+
grid = np.linspace(-K//2, K//2, K)
105+
grid_3d = np.stack(np.meshgrid(grid, grid, grid), axis=0).transpose(2,1,3,0)
106+
self.x_shift = torch.Tensor(grid_3d.copy()).view(-1,3)
107+
108+
109+
def _initialize(self, V):
110+
# intialize rescale
111+
D1,D2,D3 = V[0,0].shape
112+
D = max([D1,D2,D3])
113+
self.rescale = torch.Tensor([D3/D, D2/D, D1/D]).to(V.device)
114+
self.D = D
115+
# initialize coordinates shift
116+
self.x_shift = self.x_shift.to(V.device)
117+
self.initialized == True
118+
119+
120+
def forward(self, x, V, solver='euler', step_size=0.2, T=1.0):
121+
if not self.initialized:
122+
self._initialize(V)
123+
124+
h = step_size
125+
N = int(T/h)
126+
127+
if solver == 'euler':
128+
# forward Euler method
129+
for n in range(N):
130+
dx = self.deform(x, V)
131+
x = x + h * dx
132+
133+
if solver == 'midpoint':
134+
# midpoint method
135+
for n in range(N):
136+
dx1 = self.deform(x, V)
137+
dx2 = self.deform(x + h*dx1/2, V)
138+
x = x + h * dx2
139+
140+
if solver == 'heun':
141+
# Heun's method
142+
for n in range(N):
143+
dx1 = self.deform(x, V)
144+
dx2 = self.deform(x + h*dx1, V)
145+
x = x + h * (dx1 + dx2) / 2
146+
147+
if solver == 'rk4':
148+
# fourth-order RK method
149+
for n in range(N):
150+
dx1 = self.deform(x, V)
151+
dx2 = self.deform(x + h*dx1/2, V)
152+
dx3 = self.deform(x + h*dx2/2, V)
153+
dx4 = self.deform(x + h*dx3, V)
154+
x = x + h * (dx1 + 2*dx2 + 2*dx3 + dx4) / 6
155+
156+
return x
157+
158+
159+
def deform(self, x, V):
160+
m = x.shape[1]
161+
162+
# local feature
163+
z_local = self.cube_sampling(x, V)
164+
z_local = self.localconv(z_local)
165+
z_local = z_local.view(-1, m, self.C)
166+
z_local = self.localfc(z_local)
167+
168+
# point feature
169+
z_point = F.leaky_relu(self.fc1(x), 0.2)
170+
171+
# feature fusion
172+
z = torch.cat([z_point, z_local], 2)
173+
z = F.leaky_relu(self.fc2(z), 0.2)
174+
z = F.leaky_relu(self.fc3(z), 0.2)
175+
dx = self.fc4(z)
176+
177+
return dx
178+
179+
180+
def cube_sampling(self, x, V):
181+
"""
182+
x: coordinates
183+
V: volumetric input
184+
"""
185+
with torch.no_grad():
186+
m = x.shape[1]
187+
188+
# initialize all cubes (m,Q,K,K,K)
189+
self.v = torch.zeros([m, self.Q,
190+
self.K, self.K, self.K]).to(V.device)
191+
Vq = V # set the initial scale
192+
for q in range(self.Q):
193+
if q >= 1:
194+
Vq = F.avg_pool3d(Vq, 2) # downsampling
195+
196+
# make sure the cubes have the same size
197+
xq = x.unsqueeze(-2) + self.x_shift / self.D * 2 * (2**q)
198+
xq = xq.contiguous().view(1,-1,3).unsqueeze(-2).unsqueeze(-2)
199+
xq = xq / self.rescale # normalize to [-1,1]
200+
# sample the q-th cube
201+
vq = F.grid_sample(Vq, xq, mode='bilinear', padding_mode='border', align_corners=True)
202+
self.v[:,q] = vq[0,0].view(m, self.K, self.K, self.K)
203+
204+
return self.v

0 commit comments

Comments
 (0)