Skip to content

Commit

Permalink
fixed rbf kernel in cka
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMut committed May 17, 2024
1 parent 16940b8 commit 8b3e592
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions thingsvision/core/cka/cka_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,16 @@ def linear_kernel(
return X @ X.T

def rbf_kernel(
self, X: TensorType["m", "d"], sigma: Optional[float] = 1.0
self, X: Union[TensorType["m", "d"], TensorType["m", "p"]]
) -> TensorType["m", "m"]:
"""Use an rbf kernel for computing the gram matrix. Sigma defines the width."""
GX = X @ X.T
KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T
if sigma is None:
if self.sigma is None:
mdist = torch.median(KX[KX != 0])
sigma = torch.sqrt(mdist)
else:
sigma = self.sigma
KX *= -0.5 / sigma**2
KX = KX.exp()
return KX
Expand Down

0 comments on commit 8b3e592

Please sign in to comment.