Skip to content

Commit

Permalink
fixed rbf kernel in cka
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMut committed May 17, 2024
1 parent 8b3e592 commit 40099bb
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions thingsvision/core/cka/cka_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ def linear_kernel(self, X: Array) -> Array:
"""Use a linear kernel for computing the gram matrix."""
return X @ X.T

def rbf_kernel(self, X: Array, sigma: float = 1.0) -> Array:
def rbf_kernel(self, X: Array) -> Array:
"""Use an rbf kernel for computing the gram matrix. Sigma defines the width."""
GX = X @ X.T
KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
if sigma is None:
if self.sigma is None:
mdist = np.median(KX[KX != 0])
sigma = math.sqrt(mdist)
else:
sigma = self.sigma
KX *= -0.5 / sigma**2
KX = np.exp(KX)
return KX
Expand Down

0 comments on commit 40099bb

Please sign in to comment.