1
1
import pandas as pd
2
2
from sklearn .model_selection import train_test_split
3
+ from dataprocess .utils import file_name_path
3
4
import torch
4
5
import os
5
6
from model import *
6
7
import cv2
8
+ import SimpleITK as sitk
7
9
8
10
# Use CUDA
9
11
os .environ ['CUDA_VISIBLE_DEVICES' ] = '0'
@@ -82,8 +84,8 @@ def trainbinaryvnet3d():
82
84
vallabels = csv_data2 .iloc [:, 1 ].values
83
85
84
86
vnet3d = BinaryVNet3dModel (image_depth = 80 , image_height = 112 , image_width = 176 , image_channel = 1 , numclass = 1 ,
85
- batch_size = 1 , loss_name = 'BinaryDiceLoss ' )
86
- vnet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/BinaryVNet3d/dice ' ,
87
+ batch_size = 1 , loss_name = 'BinaryCrossEntropyDiceLoss ' )
88
+ vnet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/BinaryVNet3d/CED ' ,
87
89
epochs = 50 , showwind = [8 , 10 ])
88
90
89
91
@@ -98,8 +100,8 @@ def trainbinaryunet3d():
98
100
vallabels = csv_data2 .iloc [:, 1 ].values
99
101
100
102
unet3d = BinaryUNet3dModel (image_depth = 80 , image_height = 112 , image_width = 176 , image_channel = 1 , numclass = 1 ,
101
- batch_size = 1 , loss_name = 'BinaryDiceLoss ' )
102
- unet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/BinaryUNet3d/dice ' ,
103
+ batch_size = 1 , loss_name = 'BinaryCrossEntropyDiceLoss ' )
104
+ unet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/BinaryUNet3d/CED ' ,
103
105
epochs = 50 , showwind = [8 , 10 ])
104
106
105
107
@@ -114,9 +116,9 @@ def trainmutilvnet3d():
114
116
vallabels = csv_data2 .iloc [:, 1 ].values
115
117
116
118
vnet3d = MutilVNet3dModel (image_depth = 80 , image_height = 112 , image_width = 176 , image_channel = 1 , numclass = 16 ,
117
- batch_size = 1 , loss_name = 'MutilFocalLoss ' )
118
- vnet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/MutilVNet3d/focal ' ,
119
- epochs = 50 , showwind = [8 , 10 ])
119
+ batch_size = 1 , loss_name = 'MutilCrossEntropyLoss ' )
120
+ vnet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/MutilVNet3d/CE ' ,
121
+ epochs = 100 , showwind = [8 , 10 ])
120
122
121
123
122
124
def trainmutilunet3d ():
@@ -130,20 +132,9 @@ def trainmutilunet3d():
130
132
vallabels = csv_data2 .iloc [:, 1 ].values
131
133
132
134
unet3d = MutilUNet3dModel (image_depth = 80 , image_height = 112 , image_width = 176 , image_channel = 1 , numclass = 16 ,
133
- batch_size = 1 , loss_name = 'MutilFocalLoss' )
134
- unet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/MutilUNet3d/focal' ,
135
- epochs = 50 , showwind = [8 , 10 ])
136
-
137
-
138
- def trainmutilResNet2d ():
139
- data_dir = 'dataprocess/data/trainlabels.csv'
140
- csv_data = pd .read_csv (data_dir )
141
- images = csv_data .iloc [:, 0 ].values
142
- labels = csv_data .iloc [:, 1 ].values
143
- trainimages , valimages , trainlabels , vallabels = train_test_split (images , labels , test_size = 0.2 )
144
- unet3d = MutilResNet2dModel (image_height = 256 , image_width = 256 , image_channel = 1 , numclass = 120 ,
145
- batch_size = 32 , loss_name = 'MutilCrossEntropyLoss' )
146
- unet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/MutilResNet2d/CE' , epochs = 50 )
135
+ batch_size = 1 , loss_name = 'MutilCrossEntropyLoss' )
136
+ unet3d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/MutilUNet3d/CE' ,
137
+ epochs = 100 , showwind = [8 , 10 ])
147
138
148
139
149
140
def inferencebinaryvnet2d ():
@@ -154,25 +145,93 @@ def inferencebinaryvnet2d():
154
145
155
146
vnet2d = BinaryVNet2dModel (image_height = 512 , image_width = 512 , image_channel = 1 , numclass = 1 , batch_size = 8 ,
156
147
loss_name = 'BinaryDiceLoss' , inference = True ,
157
- model_path = r'log/BinaryVNet2d/BCED\BinaryVNet2dSegModel .pth' )
148
+ model_path = r'log/BinaryVNet2d/dice\BinaryVNet2dModel .pth' )
158
149
outpath = r"D:\cjq\data\GlandCeildata\test\pd2"
159
150
for index in range (len (valimages )):
160
151
image = cv2 .imread (valimages [index ], 0 )
161
152
mask = vnet2d .inference (image )
162
153
cv2 .imwrite (outpath + "/" + str (index ) + ".png" , mask )
163
154
164
155
156
+ def inferencemutilvnet2d ():
157
+ data_dir = 'dataprocess/data/testseg.csv'
158
+ csv_data = pd .read_csv (data_dir )
159
+ valimages = csv_data .iloc [:, 0 ].values
160
+ vallabels = csv_data .iloc [:, 1 ].values
161
+
162
+ vnet2d = MutilVNet2dModel (image_height = 512 , image_width = 512 , image_channel = 1 , numclass = 2 , batch_size = 8 ,
163
+ loss_name = 'MutilDiceLoss' , inference = True ,
164
+ model_path = r'log/MutilVNet2d/dice\MutilVNet2d.pth' )
165
+ outpath = r"D:\cjq\data\GlandCeildata\test\pd2"
166
+ for index in range (len (valimages )):
167
+ image = cv2 .imread (valimages [index ], 0 )
168
+ mask = vnet2d .inference (image )
169
+ cv2 .imwrite (outpath + "/" + str (index ) + ".png" , mask )
170
+
171
+
172
+ def inferencebinaryvnet3d ():
173
+ data_dir = r'D:\cjq\data\Amos2022\ROIprocess\validation\Image'
174
+
175
+ vnet3d = BinaryVNet3dModel (image_depth = 80 , image_height = 112 , image_width = 176 , image_channel = 1 , numclass = 1 ,
176
+ batch_size = 1 , loss_name = 'BinaryDiceLoss' , inference = True ,
177
+ model_path = r'log\BinaryVNet3d\dice\BinaryVNet3d.pth' )
178
+ outpath = r"D:\cjq\data\Amos2022\ROIprocess\validation\Maskpd"
179
+ image_files = file_name_path (data_dir , False , True )
180
+ for index in range (len (image_files )):
181
+ image_path = data_dir + '/' + image_files [index ]
182
+ sitkimage = sitk .ReadImage (image_path , sitk .sitkInt16 )
183
+ sitkmask = vnet3d .inference (sitkimage , newSize = (176 , 112 , 80 ))
184
+ output_path = outpath + '/' + image_files [index ]
185
+ sitk .WriteImage (sitkmask , output_path )
186
+
187
+
188
+ def inferencemutilvnet3d ():
189
+ data_dir = r'D:\cjq\data\Amos2022\ROIprocess\validation\Image'
190
+
191
+ vnet3d = MutilVNet3dModel (image_depth = 80 , image_height = 112 , image_width = 176 , image_channel = 1 , numclass = 16 ,
192
+ batch_size = 1 , loss_name = 'MutilFocalLoss' , inference = True ,
193
+ model_path = r'log\MutilVNet3d\dice\MutilVNet3d.pth' )
194
+ outpath = r"D:\cjq\data\Amos2022\ROIprocess\validation\Maskpd"
195
+ image_files = file_name_path (data_dir , False , True )
196
+ for index in range (len (image_files )):
197
+ image_path = data_dir + '/' + image_files [index ]
198
+ sitkimage = sitk .ReadImage (image_path , sitk .sitkInt16 )
199
+ sitkmask = vnet3d .inference (sitkimage , newSize = (176 , 112 , 80 ))
200
+ output_path = outpath + '/' + image_files [index ]
201
+ sitk .WriteImage (sitkmask , output_path )
202
+
203
+
204
+ def trainmutilResNet2d ():
205
+ data_dir = 'dataprocess/data/mnisttrain.csv'
206
+ csv_data = pd .read_csv (data_dir )
207
+ trainimages = csv_data .iloc [:, 1 ].values
208
+ trainlabels = csv_data .iloc [:, 0 ].values
209
+ data_dir2 = 'dataprocess/data/mnistvalidation.csv'
210
+ csv_data2 = pd .read_csv (data_dir2 )
211
+ valimages = csv_data2 .iloc [:, 1 ].values
212
+ vallabels = csv_data2 .iloc [:, 0 ].values
213
+ resnet2d = MutilResNet2dModel (image_height = 64 , image_width = 64 , image_channel = 1 , numclass = 10 ,
214
+ batch_size = 128 , loss_name = 'MutilCrossEntropyLoss' )
215
+ resnet2d .trainprocess (trainimages , trainlabels , valimages , vallabels , model_dir = 'log/MutilResNet2d/CE' , epochs = 50 ,
216
+ lr = 0.001 )
217
+
218
+
165
219
if __name__ == '__main__' :
166
220
# trainbinaryvnet2d()
167
221
# trainbinaryunet2d()
168
222
169
223
# trainmutilvnet2d()
170
224
# trainmutilunet2d()
171
225
172
- trainbinaryvnet3d ()
173
- trainbinaryunet3d ()
226
+ # trainbinaryvnet3d()
227
+ # trainbinaryunet3d()
228
+
229
+ trainmutilvnet3d ()
230
+ trainmutilunet3d ()
174
231
175
- # trainmutilvnet3d()
176
- # trainmutilunet3d()
232
+ # inferencebinaryvnet2d()
233
+ # inferencemutilvnet2d()
234
+ # inferencebinaryvnet3d()
235
+ # inferencemutilvnet3d()
177
236
178
- # inferencebinaryvnet2dseg ()
237
+ # trainmutilResNet2d ()
0 commit comments