|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | +import numpy as np |
| 5 | +from util.mesh import laplacian_smooth |
| 6 | + |
| 7 | + |
| 8 | +# segmentation U-Net |
| 9 | +class Unet(nn.Module): |
| 10 | + def __init__(self, c_in=1, c_out=2): |
| 11 | + super(Unet, self).__init__() |
| 12 | + |
| 13 | + self.conv1 = nn.Conv3d(in_channels=c_in, out_channels=16, kernel_size=3, |
| 14 | + stride=1, padding=1) |
| 15 | + self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, |
| 16 | + stride=2, padding=1) |
| 17 | + self.conv3 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, |
| 18 | + stride=2, padding=1) |
| 19 | + self.conv4 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, |
| 20 | + stride=2, padding=1) |
| 21 | + self.conv5 = nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, |
| 22 | + stride=2, padding=1) |
| 23 | + |
| 24 | + self.deconv4 = nn.Conv3d(in_channels=256, out_channels=64, kernel_size=3, |
| 25 | + stride=1, padding=1) |
| 26 | + self.deconv3 = nn.Conv3d(in_channels=128, out_channels=32, kernel_size=3, |
| 27 | + stride=1, padding=1) |
| 28 | + self.deconv2 = nn.Conv3d(in_channels=64, out_channels=16, kernel_size=3, |
| 29 | + stride=1, padding=1) |
| 30 | + self.deconv1 = nn.Conv3d(in_channels=32, out_channels=16, kernel_size=3, |
| 31 | + stride=1, padding=1) |
| 32 | + |
| 33 | + self.lastconv1 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3, |
| 34 | + stride=1, padding=1) |
| 35 | + self.lastconv2 = nn.Conv3d(in_channels=16, out_channels=c_out, kernel_size=3, |
| 36 | + stride=1, padding=1) |
| 37 | + self.up = nn.Upsample(scale_factor=2, mode='trilinear') |
| 38 | + |
| 39 | + def forward(self, x): |
| 40 | + |
| 41 | + x1 = F.leaky_relu(self.conv1(x), 0.2) |
| 42 | + x2 = F.leaky_relu(self.conv2(x1), 0.2) |
| 43 | + x3 = F.leaky_relu(self.conv3(x2), 0.2) |
| 44 | + x4 = F.leaky_relu(self.conv4(x3), 0.2) |
| 45 | + x = F.leaky_relu(self.conv5(x4), 0.2) |
| 46 | + x = self.up(x) |
| 47 | + |
| 48 | + x = torch.cat([x, x4], dim=1) |
| 49 | + x = F.leaky_relu(self.deconv4(x), 0.2) |
| 50 | + x = self.up(x) |
| 51 | + |
| 52 | + x = torch.cat([x, x3], dim=1) |
| 53 | + x = F.leaky_relu(self.deconv3(x), 0.2) |
| 54 | + x = self.up(x) |
| 55 | + |
| 56 | + x = torch.cat([x, x2], dim=1) |
| 57 | + x = F.leaky_relu(self.deconv2(x), 0.2) |
| 58 | + x = self.up(x) |
| 59 | + |
| 60 | + x = torch.cat([x, x1], dim=1) |
| 61 | + x = F.leaky_relu(self.deconv1(x), 0.2) |
| 62 | + |
| 63 | + x = F.leaky_relu(self.lastconv1(x), 0.2) |
| 64 | + x = self.lastconv2(x) |
| 65 | + |
| 66 | + return x |
| 67 | + |
| 68 | + |
| 69 | +class CortexODE(nn.Module): |
| 70 | + def __init__(self, dim_in=3, |
| 71 | + dim_h=128, |
| 72 | + kernel_size=5, |
| 73 | + n_scale=3): |
| 74 | + |
| 75 | + super(CortexODE, self).__init__() |
| 76 | + """ |
| 77 | + dim_in (3): input dimension |
| 78 | + dim_h (C): hidden dimension |
| 79 | + kernel_size (K): size of convolutional kernels |
| 80 | + n_scale (Q): number of scales of the multi-scale input |
| 81 | + """ |
| 82 | + |
| 83 | + C = dim_h # hidden dimension |
| 84 | + K = kernel_size # kernel size |
| 85 | + Q = n_scale # number of scales |
| 86 | + |
| 87 | + self.C = C |
| 88 | + self.K = K |
| 89 | + self.Q = Q |
| 90 | + |
| 91 | + # FC layers |
| 92 | + self.fc1 = nn.Linear(dim_in, C) |
| 93 | + self.fc2 = nn.Linear(C*2, C*4) |
| 94 | + self.fc3 = nn.Linear(C*4, C*2) |
| 95 | + self.fc4 = nn.Linear(C*2, dim_in) |
| 96 | + |
| 97 | + # local convolution |
| 98 | + self.localconv = nn.Conv3d(Q, C, (K, K, K)) |
| 99 | + self.localfc = nn.Linear(C, C) |
| 100 | + |
| 101 | + # for cube sampling |
| 102 | + self.initialized = False |
| 103 | + |
| 104 | + grid = np.linspace(-K//2, K//2, K) |
| 105 | + grid_3d = np.stack(np.meshgrid(grid, grid, grid), axis=0).transpose(2,1,3,0) |
| 106 | + self.x_shift = torch.Tensor(grid_3d.copy()).view(-1,3) |
| 107 | + |
| 108 | + |
| 109 | + def _initialize(self, V): |
| 110 | + # intialize rescale |
| 111 | + D1,D2,D3 = V[0,0].shape |
| 112 | + D = max([D1,D2,D3]) |
| 113 | + self.rescale = torch.Tensor([D3/D, D2/D, D1/D]).to(V.device) |
| 114 | + self.D = D |
| 115 | + # initialize coordinates shift |
| 116 | + self.x_shift = self.x_shift.to(V.device) |
| 117 | + self.initialized == True |
| 118 | + |
| 119 | + |
| 120 | + def forward(self, x, V, solver='euler', step_size=0.2, T=1.0): |
| 121 | + if not self.initialized: |
| 122 | + self._initialize(V) |
| 123 | + |
| 124 | + h = step_size |
| 125 | + N = int(T/h) |
| 126 | + |
| 127 | + if solver == 'euler': |
| 128 | + # forward Euler method |
| 129 | + for n in range(N): |
| 130 | + dx = self.deform(x, V) |
| 131 | + x = x + h * dx |
| 132 | + |
| 133 | + if solver == 'midpoint': |
| 134 | + # midpoint method |
| 135 | + for n in range(N): |
| 136 | + dx1 = self.deform(x, V) |
| 137 | + dx2 = self.deform(x + h*dx1/2, V) |
| 138 | + x = x + h * dx2 |
| 139 | + |
| 140 | + if solver == 'heun': |
| 141 | + # Heun's method |
| 142 | + for n in range(N): |
| 143 | + dx1 = self.deform(x, V) |
| 144 | + dx2 = self.deform(x + h*dx1, V) |
| 145 | + x = x + h * (dx1 + dx2) / 2 |
| 146 | + |
| 147 | + if solver == 'rk4': |
| 148 | + # fourth-order RK method |
| 149 | + for n in range(N): |
| 150 | + dx1 = self.deform(x, V) |
| 151 | + dx2 = self.deform(x + h*dx1/2, V) |
| 152 | + dx3 = self.deform(x + h*dx2/2, V) |
| 153 | + dx4 = self.deform(x + h*dx3, V) |
| 154 | + x = x + h * (dx1 + 2*dx2 + 2*dx3 + dx4) / 6 |
| 155 | + |
| 156 | + return x |
| 157 | + |
| 158 | + |
| 159 | + def deform(self, x, V): |
| 160 | + m = x.shape[1] |
| 161 | + |
| 162 | + # local feature |
| 163 | + z_local = self.cube_sampling(x, V) |
| 164 | + z_local = self.localconv(z_local) |
| 165 | + z_local = z_local.view(-1, m, self.C) |
| 166 | + z_local = self.localfc(z_local) |
| 167 | + |
| 168 | + # point feature |
| 169 | + z_point = F.leaky_relu(self.fc1(x), 0.2) |
| 170 | + |
| 171 | + # feature fusion |
| 172 | + z = torch.cat([z_point, z_local], 2) |
| 173 | + z = F.leaky_relu(self.fc2(z), 0.2) |
| 174 | + z = F.leaky_relu(self.fc3(z), 0.2) |
| 175 | + dx = self.fc4(z) |
| 176 | + |
| 177 | + return dx |
| 178 | + |
| 179 | + |
| 180 | + def cube_sampling(self, x, V): |
| 181 | + """ |
| 182 | + x: coordinates |
| 183 | + V: volumetric input |
| 184 | + """ |
| 185 | + with torch.no_grad(): |
| 186 | + m = x.shape[1] |
| 187 | + |
| 188 | + # initialize all cubes (m,Q,K,K,K) |
| 189 | + self.v = torch.zeros([m, self.Q, |
| 190 | + self.K, self.K, self.K]).to(V.device) |
| 191 | + Vq = V # set the initial scale |
| 192 | + for q in range(self.Q): |
| 193 | + if q >= 1: |
| 194 | + Vq = F.avg_pool3d(Vq, 2) # downsampling |
| 195 | + |
| 196 | + # make sure the cubes have the same size |
| 197 | + xq = x.unsqueeze(-2) + self.x_shift / self.D * 2 * (2**q) |
| 198 | + xq = xq.contiguous().view(1,-1,3).unsqueeze(-2).unsqueeze(-2) |
| 199 | + xq = xq / self.rescale # normalize to [-1,1] |
| 200 | + # sample the q-th cube |
| 201 | + vq = F.grid_sample(Vq, xq, mode='bilinear', padding_mode='border', align_corners=True) |
| 202 | + self.v[:,q] = vq[0,0].view(m, self.K, self.K, self.K) |
| 203 | + |
| 204 | + return self.v |
0 commit comments