diff --git a/probe_lens/probes.py b/probe_lens/probes.py index 6302de9..382d208 100644 --- a/probe_lens/probes.py +++ b/probe_lens/probes.py @@ -1,11 +1,11 @@ -from tqdm import tqdm import torch import torch.nn as nn +from tqdm.autonotebook import tqdm class LinearProbe(nn.Module): def __init__(self, input_dim, output_dim=1, device="cpu"): - super(LinearProbe, self).__init__() + super().__init__() self.linear = nn.Linear(input_dim, output_dim, device=device) def forward(self, x): diff --git a/tests/test_probes.py b/tests/test_probes.py index a35cc28..d45f76f 100644 --- a/tests/test_probes.py +++ b/tests/test_probes.py @@ -1,7 +1,8 @@ -from probe_lens.probes import LinearProbe import torch import torch.nn as nn +from probe_lens.probes import LinearProbe + def test_linear_probe(): input_dim = 10