From 8ce0d7f1607bc2ea380856406c18f1ea0b858fad Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 27 Oct 2024 11:18:05 +0000 Subject: [PATCH] Add ruff fixes --- probe_lens/probes.py | 4 ++-- tests/test_probes.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/probe_lens/probes.py b/probe_lens/probes.py index 6302de9..382d208 100644 --- a/probe_lens/probes.py +++ b/probe_lens/probes.py @@ -1,11 +1,11 @@ -from tqdm import tqdm import torch import torch.nn as nn +from tqdm.autonotebook import tqdm class LinearProbe(nn.Module): def __init__(self, input_dim, output_dim=1, device="cpu"): - super(LinearProbe, self).__init__() + super().__init__() self.linear = nn.Linear(input_dim, output_dim, device=device) def forward(self, x): diff --git a/tests/test_probes.py b/tests/test_probes.py index a35cc28..d45f76f 100644 --- a/tests/test_probes.py +++ b/tests/test_probes.py @@ -1,7 +1,8 @@ -from probe_lens.probes import LinearProbe import torch import torch.nn as nn +from probe_lens.probes import LinearProbe + def test_linear_probe(): input_dim = 10