Skip to content

Commit

Permalink
Update trainer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rhcsky committed Jan 30, 2021
1 parent 9920fc6 commit 0d9d9fb
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,14 @@ def train(self):

# Train & Validation
main_pbar = tqdm(range(start_epoch, self.config.epochs), initial=start_epoch, position=0,
total=self.config.epochs, ncols=100, desc="Process")
total=self.config.epochs, desc="Process")
for epoch in main_pbar:
train_losses = AverageMeter()
valid_losses = AverageMeter()

# TRAIN
model.train()
train_pbar = tqdm(enumerate(train_loader), total=num_train, desc="Train", ncols=100, position=1,
leave=False)
train_pbar = tqdm(enumerate(train_loader), total=num_train, desc="Train", position=1, leave=False)
for i, (x1, x2, y) in train_pbar:
if self.config.use_gpu:
x1, x2, y = x1.to(self.device), x2.to(self.device), y.to(self.device)
Expand All @@ -112,8 +111,7 @@ def train(self):
model.eval()
valid_acc = 0
correct_sum = 0
valid_pbar = tqdm(enumerate(valid_loader), total=num_valid, desc="Valid", ncols=100, position=1,
leave=False)
valid_pbar = tqdm(enumerate(valid_loader), total=num_valid, desc="Valid", position=1, leave=False)
with torch.no_grad():
for i, (x1, x2, y) in valid_pbar:

Expand Down

0 comments on commit 0d9d9fb

Please sign in to comment.