From 62aa121412cf0bbbd979828f8a5bcd71d13273db Mon Sep 17 00:00:00 2001 From: Chenyang Yuan Date: Thu, 15 Feb 2024 09:47:27 -0500 Subject: [PATCH] Updated ideal denoiser --- src/smalldiffusion/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/smalldiffusion/model.py b/src/smalldiffusion/model.py index 9fec8b7..1571c41 100644 --- a/src/smalldiffusion/model.py +++ b/src/smalldiffusion/model.py @@ -45,12 +45,13 @@ def __init__(self, dataset): def __call__(self, x, sigma): assert sigma.shape == tuple(), 'Only singleton sigma supported' + data = self.data.to(x) x_flat = x.flatten(start_dim=1) - d_flat = self.data.flatten(start_dim=1) + d_flat = data.flatten(start_dim=1) xb, xr = x_flat.shape db, dr = d_flat.shape assert xr == dr, 'Input x must have same dimension as data!' # ||x - x0||^2 ,shape xb x db sq_diffs = sq_norm(x_flat, db) + sq_norm(d_flat, xb).T - 2 * x_flat @ d_flat.T weights = torch.nn.functional.softmax(-sq_diffs/2/sigma**2, dim=1) - return (x - weights @ self.data)/sigma + return (x - torch.einsum('ij,j...->i...', weights, data))/sigma