diff --git a/pytorch-ddp-accelerate-transformers.md b/pytorch-ddp-accelerate-transformers.md index 8815b7f5f2..5a876e3ade 100644 --- a/pytorch-ddp-accelerate-transformers.md +++ b/pytorch-ddp-accelerate-transformers.md @@ -41,7 +41,7 @@ class BasicNet(nn.Module): self.fc2 = nn.Linear(128, 10) self.act = F.relu - def forward(self, x): + def forward(self, x, labels=None): x = self.act(self.conv1(x)) x = self.act(self.conv2(x)) x = F.max_pool2d(x, 2) @@ -54,7 +54,9 @@ class BasicNet(nn.Module): return output ``` -We define the training device (`cuda`): +Note we specified a `labels=None`; this avoids an error when we try to pass a `labels` keyword argument later. However, it is not being used here. + +Now, we define the training device (`cuda`): ```python device = "cuda"