-
Notifications
You must be signed in to change notification settings - Fork 151
Open
Description
已知
训练时 根据以上公式,可以从
eps = torch.randn_like(src).to(device) # 根据 x的shape 随机生成eps 数据 数据符合正态分布随机数
x_t = ddpm.sample_forward(src, t, eps) # 原图 步数t 噪声图
eps_theta = net(x_t, t.reshape(current_batch_size, 1))
loss = loss_fn(eps_theta, eps)
而推理时
但是在如下代码实现中为什么 在倒数第二行,返回的是 mean + noise,即 输入
def sample_backward_step(self, x_t, t, net, simple_var=False):
n = x_t.shape[0]
t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device).unsqueeze(1)
eps = net(x_t, t_tensor)
if t == 0:
noise = 0
else:
if simple_var:
var = self.betas[t]
else:
var = (1 - self.alpha_bars[t - 1]) / (1 - self.alpha_bars[t]) * self.betas[t]
noise = torch.randn_like(x_t)
noise *= torch.sqrt(var)
mean = (x_t - (1 - self.alphas[t]) / torch.sqrt(1 - self.alpha_bars[t]) * eps) / torch.sqrt(self.alphas[t])
x_t = mean + noise
return x_t
望解答
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels