diff --git a/flagai/trainer_v1.py b/flagai/trainer_v1.py index adbf5618..94ab0fc2 100755 --- a/flagai/trainer_v1.py +++ b/flagai/trainer_v1.py @@ -876,7 +876,7 @@ def train_step_pytorch(self, optimizer.zero_grad() else: optimizer.step() - # optimizer.zero_grad() + optimizer.zero_grad() self.accumulate_count = 0 else: self.accumulate_count += 1