Skip to content

Commit

Permalink
added verbose flag for PyTorch CKA
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMut committed May 19, 2024
1 parent 40099bb commit 2308914
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
9 changes: 5 additions & 4 deletions thingsvision/core/cka/cka_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
kernel: str,
unbiased: bool = False,
device: str = "cpu",
verbose: bool = False,
sigma: Optional[float] = 1.0,
) -> None:
"""
Expand All @@ -30,7 +31,7 @@ def __init__(
sigma (float) - for 'rbf' kernel sigma defines the width of the Gaussian;
"""
super().__init__(m=m, kernel=kernel, unbiased=unbiased, sigma=sigma)
device = self._check_device(device)
device = self._check_device(device, verbose)
if device == "cpu":
self.hsic = self._hsic
else:
Expand All @@ -39,7 +40,7 @@ def __init__(
self.device = torch.device(device)

@staticmethod
def _check_device(device: str) -> str:
def _check_device(device: str, verbose: bool) -> str:
"""Check whether the selected device is available on current compute node."""
if device.startswith("cuda"):
gpu_index = re.search(r"cuda:(\d+)", device)
Expand All @@ -58,8 +59,8 @@ def _check_device(device: str) -> str:
category=UserWarning,
)
device = "cuda:0"

print(f"\nUsing device: {device}\n")
if verbose:
print(f"\nUsing device: {device}\n")
return device

def centering(self, K: TensorType["m", "m"]) -> TensorType["m", "m"]:
Expand Down
8 changes: 7 additions & 1 deletion thingsvision/core/cka/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def get_cka(
unbiased: bool = False,
sigma: Optional[float] = 1.0,
device: Optional[str] = None,
verbose: Optional[bool] = False,
) -> Union[CKANumPy, CKATorch]:
"""Return a NumPy or PyTorch implementation of CKA."""
assert backend in BACKENDS, f"\nSupported backends are: {BACKENDS}\n"
Expand All @@ -23,6 +24,11 @@ def get_cka(
device, str
), "\nDevice must be set for using PyTorch backend.\n"
cka = CKATorch(
m=m, kernel=kernel, unbiased=unbiased, device=device, sigma=sigma
m=m,
kernel=kernel,
unbiased=unbiased,
device=device,
sigma=sigma,
verbose=verbose,
)
return cka

0 comments on commit 2308914

Please sign in to comment.