diff --git a/src/discovery/matrix.py b/src/discovery/matrix.py index e90b780..cd1e5c7 100644 --- a/src/discovery/matrix.py +++ b/src/discovery/matrix.py @@ -1123,6 +1123,21 @@ class WoodburyKernel_varNP(VariableKernel): def __init__(self, N_var, F, P_var): self.N_var, self.F, self.P_var = N_var, F, P_var + def make_sample(self): + N_sample = self.N_var.make_sample() + P_sample = self.P_var.make_sample() + F = jnparray(self.F) + + def sample(key, params): + key, n = N_sample(key, params) + key, c = P_sample(key, params) + + return key, n + jnp.dot(F, c) + + sample.params = sorted(set(N_sample.params + P_sample.params)) + + return sample + def make_kernelproduct_vary(self, y): y_var = y @@ -1725,7 +1740,7 @@ def __init__(self, N_var, F, P): self.Pinv, self.ldP = P.inv() self.params = N_var.params - def make_sample(self, params): + def make_sample(self): N_sample = self.N_var.make_sample() P_sample = self.P.make_sample() F = jnparray(self.F)