-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_sdes.py
61 lines (41 loc) · 1.3 KB
/
test_sdes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import jax
import jax.numpy as jnp
import jax.random as jr
from sbgm.sde import VESDE, VPSDE, SubVPSDE
def test_sdes():
    """Smoke-test ``VESDE``, ``VPSDE`` and ``SubVPSDE`` on a 2-D input.

    For each SDE, a perturbed sample ``xt = mu + std * eps`` is built from
    ``sde.marginal_prob(x0, t)``, and the ``(f, g)`` pair is taken from
    ``sde.sde(x0, t)`` (presumably drift and diffusion coefficient). The test
    checks that ``f`` has the data's shape, ``g`` is a scalar, ``xt`` keeps
    ``x0``'s shape, and that all three contain only finite values.
    """
    key = jr.key(0)
    data_dim = 2
    x0 = jnp.ones((data_dim,))
    eps = jr.normal(key, (data_dim,))
    t = jnp.array(0.5)

    def diffuse(sde, x, t, eps):
        """Return ``mu + std * eps`` with ``(mu, std) = sde.marginal_prob(x, t)``."""
        mu, std = sde.marginal_prob(x, t)
        return mu + std * eps

    # One instance per SDE family, each configured with the identity function
    # (``sigma_fn`` for VE, ``beta_integral_fn`` for VP / sub-VP).
    sdes = {
        "VESDE": VESDE(sigma_fn=lambda t: t),
        "VPSDE": VPSDE(beta_integral_fn=lambda t: t),
        "SubVPSDE": SubVPSDE(beta_integral_fn=lambda t: t),
    }

    for name, sde in sdes.items():
        xt = diffuse(sde, x0, t, eps)
        f, g = sde.sde(x0, t)

        # Messages carry the SDE name so a failure pinpoints which family broke.
        assert f.shape == (data_dim,), f"{name}: drift shape {f.shape}"
        assert g.shape == (), f"{name}: diffusion shape {g.shape}"
        assert xt.shape == x0.shape, f"{name}: sample shape {xt.shape}"
        assert jnp.all(jnp.isfinite(xt)), f"{name}: non-finite sample"
        assert jnp.all(jnp.isfinite(f)), f"{name}: non-finite drift"
        assert jnp.all(jnp.isfinite(g)), f"{name}: non-finite diffusion"