Skip to content

Commit 81c7bd3

Browse files
authored
accessing UNet depth from config.py
1 parent 01b3d0a commit 81c7bd3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
8787
print('len(val_loader) : {} @bs={}'.format(len(val_loader), batch_size))
8888

8989
# defining the model
90-
model = UNet(n_classes = 1, depth = 5, padding = True).to(device) # try decreasing the depth value if there is a memory error
90+
model = UNet(n_classes = 1, depth = cfg.depth, padding = True).to(device) # try decreasing the depth value if there is a memory error
9191

9292
resume = cfg.resume
9393

0 commit comments

Comments
 (0)