Skip to content

Commit 13baf43

Browse files
authored
Update example.py
1 parent 2d21a0e commit 13baf43

File tree

1 file changed

+86
-27
lines changed

1 file changed

+86
-27
lines changed

example.py

Lines changed: 86 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import pandas as pd
22
from sklearn.model_selection import train_test_split
3+
from dataprocess.utils import file_name_path
34
import torch
45
import os
56
from model import *
67
import cv2
8+
import SimpleITK as sitk
79

810
# Use CUDA
911
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
@@ -82,8 +84,8 @@ def trainbinaryvnet3d():
8284
vallabels = csv_data2.iloc[:, 1].values
8385

8486
vnet3d = BinaryVNet3dModel(image_depth=80, image_height=112, image_width=176, image_channel=1, numclass=1,
85-
batch_size=1, loss_name='BinaryDiceLoss')
86-
vnet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/BinaryVNet3d/dice',
87+
batch_size=1, loss_name='BinaryCrossEntropyDiceLoss')
88+
vnet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/BinaryVNet3d/CED',
8789
epochs=50, showwind=[8, 10])
8890

8991

@@ -98,8 +100,8 @@ def trainbinaryunet3d():
98100
vallabels = csv_data2.iloc[:, 1].values
99101

100102
unet3d = BinaryUNet3dModel(image_depth=80, image_height=112, image_width=176, image_channel=1, numclass=1,
101-
batch_size=1, loss_name='BinaryDiceLoss')
102-
unet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/BinaryUNet3d/dice',
103+
batch_size=1, loss_name='BinaryCrossEntropyDiceLoss')
104+
unet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/BinaryUNet3d/CED',
103105
epochs=50, showwind=[8, 10])
104106

105107

@@ -114,9 +116,9 @@ def trainmutilvnet3d():
114116
vallabels = csv_data2.iloc[:, 1].values
115117

116118
vnet3d = MutilVNet3dModel(image_depth=80, image_height=112, image_width=176, image_channel=1, numclass=16,
117-
batch_size=1, loss_name='MutilFocalLoss')
118-
vnet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/MutilVNet3d/focal',
119-
epochs=50, showwind=[8, 10])
119+
batch_size=1, loss_name='MutilCrossEntropyLoss')
120+
vnet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/MutilVNet3d/CE',
121+
epochs=100, showwind=[8, 10])
120122

121123

122124
def trainmutilunet3d():
@@ -130,20 +132,9 @@ def trainmutilunet3d():
130132
vallabels = csv_data2.iloc[:, 1].values
131133

132134
unet3d = MutilUNet3dModel(image_depth=80, image_height=112, image_width=176, image_channel=1, numclass=16,
133-
batch_size=1, loss_name='MutilFocalLoss')
134-
unet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/MutilUNet3d/focal',
135-
epochs=50, showwind=[8, 10])
136-
137-
138-
def trainmutilResNet2d():
139-
data_dir = 'dataprocess/data/trainlabels.csv'
140-
csv_data = pd.read_csv(data_dir)
141-
images = csv_data.iloc[:, 0].values
142-
labels = csv_data.iloc[:, 1].values
143-
trainimages, valimages, trainlabels, vallabels = train_test_split(images, labels, test_size=0.2)
144-
unet3d = MutilResNet2dModel(image_height=256, image_width=256, image_channel=1, numclass=120,
145-
batch_size=32, loss_name='MutilCrossEntropyLoss')
146-
unet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/MutilResNet2d/CE', epochs=50)
135+
batch_size=1, loss_name='MutilCrossEntropyLoss')
136+
unet3d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/MutilUNet3d/CE',
137+
epochs=100, showwind=[8, 10])
147138

148139

149140
def inferencebinaryvnet2d():
@@ -154,25 +145,93 @@ def inferencebinaryvnet2d():
154145

155146
vnet2d = BinaryVNet2dModel(image_height=512, image_width=512, image_channel=1, numclass=1, batch_size=8,
156147
loss_name='BinaryDiceLoss', inference=True,
157-
model_path=r'log/BinaryVNet2d/BCED\BinaryVNet2dSegModel.pth')
148+
model_path=r'log/BinaryVNet2d/dice\BinaryVNet2dModel.pth')
158149
outpath = r"D:\cjq\data\GlandCeildata\test\pd2"
159150
for index in range(len(valimages)):
160151
image = cv2.imread(valimages[index], 0)
161152
mask = vnet2d.inference(image)
162153
cv2.imwrite(outpath + "/" + str(index) + ".png", mask)
163154

164155

156+
def inferencemutilvnet2d():
157+
data_dir = 'dataprocess/data/testseg.csv'
158+
csv_data = pd.read_csv(data_dir)
159+
valimages = csv_data.iloc[:, 0].values
160+
vallabels = csv_data.iloc[:, 1].values
161+
162+
vnet2d = MutilVNet2dModel(image_height=512, image_width=512, image_channel=1, numclass=2, batch_size=8,
163+
loss_name='MutilDiceLoss', inference=True,
164+
model_path=r'log/MutilVNet2d/dice\MutilVNet2d.pth')
165+
outpath = r"D:\cjq\data\GlandCeildata\test\pd2"
166+
for index in range(len(valimages)):
167+
image = cv2.imread(valimages[index], 0)
168+
mask = vnet2d.inference(image)
169+
cv2.imwrite(outpath + "/" + str(index) + ".png", mask)
170+
171+
172+
def inferencebinaryvnet3d():
173+
data_dir = r'D:\cjq\data\Amos2022\ROIprocess\validation\Image'
174+
175+
vnet3d = BinaryVNet3dModel(image_depth=80, image_height=112, image_width=176, image_channel=1, numclass=1,
176+
batch_size=1, loss_name='BinaryDiceLoss', inference=True,
177+
model_path=r'log\BinaryVNet3d\dice\BinaryVNet3d.pth')
178+
outpath = r"D:\cjq\data\Amos2022\ROIprocess\validation\Maskpd"
179+
image_files = file_name_path(data_dir, False, True)
180+
for index in range(len(image_files)):
181+
image_path = data_dir + '/' + image_files[index]
182+
sitkimage = sitk.ReadImage(image_path, sitk.sitkInt16)
183+
sitkmask = vnet3d.inference(sitkimage, newSize=(176, 112, 80))
184+
output_path = outpath + '/' + image_files[index]
185+
sitk.WriteImage(sitkmask, output_path)
186+
187+
188+
def inferencemutilvnet3d():
189+
data_dir = r'D:\cjq\data\Amos2022\ROIprocess\validation\Image'
190+
191+
vnet3d = MutilVNet3dModel(image_depth=80, image_height=112, image_width=176, image_channel=1, numclass=16,
192+
batch_size=1, loss_name='MutilFocalLoss', inference=True,
193+
model_path=r'log\MutilVNet3d\dice\MutilVNet3d.pth')
194+
outpath = r"D:\cjq\data\Amos2022\ROIprocess\validation\Maskpd"
195+
image_files = file_name_path(data_dir, False, True)
196+
for index in range(len(image_files)):
197+
image_path = data_dir + '/' + image_files[index]
198+
sitkimage = sitk.ReadImage(image_path, sitk.sitkInt16)
199+
sitkmask = vnet3d.inference(sitkimage, newSize=(176, 112, 80))
200+
output_path = outpath + '/' + image_files[index]
201+
sitk.WriteImage(sitkmask, output_path)
202+
203+
204+
def trainmutilResNet2d():
205+
data_dir = 'dataprocess/data/mnisttrain.csv'
206+
csv_data = pd.read_csv(data_dir)
207+
trainimages = csv_data.iloc[:, 1].values
208+
trainlabels = csv_data.iloc[:, 0].values
209+
data_dir2 = 'dataprocess/data/mnistvalidation.csv'
210+
csv_data2 = pd.read_csv(data_dir2)
211+
valimages = csv_data2.iloc[:, 1].values
212+
vallabels = csv_data2.iloc[:, 0].values
213+
resnet2d = MutilResNet2dModel(image_height=64, image_width=64, image_channel=1, numclass=10,
214+
batch_size=128, loss_name='MutilCrossEntropyLoss')
215+
resnet2d.trainprocess(trainimages, trainlabels, valimages, vallabels, model_dir='log/MutilResNet2d/CE', epochs=50,
216+
lr=0.001)
217+
218+
165219
if __name__ == '__main__':
166220
# trainbinaryvnet2d()
167221
# trainbinaryunet2d()
168222

169223
# trainmutilvnet2d()
170224
# trainmutilunet2d()
171225

172-
trainbinaryvnet3d()
173-
trainbinaryunet3d()
226+
# trainbinaryvnet3d()
227+
# trainbinaryunet3d()
228+
229+
trainmutilvnet3d()
230+
trainmutilunet3d()
174231

175-
# trainmutilvnet3d()
176-
# trainmutilunet3d()
232+
# inferencebinaryvnet2d()
233+
# inferencemutilvnet2d()
234+
# inferencebinaryvnet3d()
235+
# inferencemutilvnet3d()
177236

178-
# inferencebinaryvnet2dseg()
237+
# trainmutilResNet2d()

0 commit comments

Comments
 (0)