Skip to content

Commit

Permalink
Updated ideal denoiser
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanchenyang committed Feb 15, 2024
1 parent b7abc02 commit 62aa121
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/smalldiffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def __init__(self, dataset):

def __call__(self, x, sigma):
assert sigma.shape == tuple(), 'Only singleton sigma supported'
data = self.data.to(x)
x_flat = x.flatten(start_dim=1)
d_flat = self.data.flatten(start_dim=1)
d_flat = data.flatten(start_dim=1)
xb, xr = x_flat.shape
db, dr = d_flat.shape
assert xr == dr, 'Input x must have same dimension as data!'
# ||x - x0||^2 ,shape xb x db
sq_diffs = sq_norm(x_flat, db) + sq_norm(d_flat, xb).T - 2 * x_flat @ d_flat.T
weights = torch.nn.functional.softmax(-sq_diffs/2/sigma**2, dim=1)
return (x - weights @ self.data)/sigma
return (x - torch.einsum('ij,j...->i...', weights, data))/sigma

0 comments on commit 62aa121

Please sign in to comment.