-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_sampling.py
105 lines (76 loc) · 2.55 KB
/
test_sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import jax
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
from sbgm.sde import VPSDE
from sbgm._sample import get_eu_sample_fn, get_ode_sample_fn
def test_sampling():
    """Smoke-test the ``eu`` and ``ode`` samplers with a tiny MLP model.

    Each sampler is built with ``get_eu_sample_fn`` / ``get_ode_sample_fn``
    and vmapped over ``n_sample`` PRNG keys, twice:

    1. unconditionally (sample fn called with keys only), and
    2. with a conditioning input ``Q`` of size ``q_dim`` and an extra input
       ``A`` of size ``a_dim`` (sample fn called as ``(key, Q, A)``).

    For every run the test asserts the batch has shape
    ``(n_sample, *data_shape)`` and that all values are finite.
    """
    import math

    import diffrax

    # Version report to help diagnose dependency-related CI failures.
    print(jax.__version__, eqx.__version__, diffrax.__version__)

    key = jr.key(0)
    # Independent keys for each model's weights and for sampling. The
    # original reused one key for both, which JAX's PRNG design warns against.
    uncond_model_key, cond_model_key, sample_key = jr.split(key, 3)

    sde = VPSDE(beta_integral_fn=lambda t: t)
    n_sample = 10
    data_shape = (1, 32, 32)
    # Flattened size of one data sample; derived so it tracks data_shape.
    data_size = math.prod(data_shape)
    sample_keys = jr.split(sample_key, n_sample)

    class MLP(eqx.Module):
        """Wrap ``eqx.nn.MLP`` so it maps ``(t, x, q, a)`` to an x-shaped output."""

        net: eqx.nn.MLP

        def __init__(self, *args, **kwargs):
            self.net = eqx.nn.MLP(*args, **kwargs)

        def __call__(self, t, x, q=None, a=None, key=None):
            # `key` is accepted but unused; presumably the sampler passes it
            # as part of its model calling convention — TODO confirm in sbgm.
            # Network input is the concatenation [t, flat x, flat a?, flat q?].
            inputs = [jnp.atleast_1d(t), x.flatten()]
            if a is not None:
                inputs.append(jnp.ravel(a))
            if q is not None:
                inputs.append(q.flatten())
            out = self.net(jnp.concatenate(inputs))
            return out.reshape(data_shape)

    def make_model(model_key, q_dim=None, a_dim=None):
        """Build an MLP whose input size matches the optional q/a dims."""
        # 1 scalar time + flattened x + optional conditioning vectors.
        in_size = 1 + data_size + (q_dim or 0) + (a_dim or 0)
        return MLP(
            in_size,
            out_size=data_size,
            width_size=1024,
            depth=1,
            activation=jax.nn.tanh,
            key=model_key,
        )

    def check_samplers(model, *cond):
        """Run both samplers on `model` and assert shape and finiteness."""
        for get_sample_fn in (get_eu_sample_fn, get_ode_sample_fn):
            sample_fn = get_sample_fn(model, sde, data_shape)
            # `cond` is batched along axis 0 alongside the keys.
            samples = jax.vmap(sample_fn)(sample_keys, *cond)
            assert samples.shape == (n_sample,) + data_shape
            assert jnp.all(jnp.isfinite(samples))

    # Unconditional model: no q / a inputs.
    check_samplers(make_model(uncond_model_key))

    # Conditional model: sample fns are called as (key, Q, A).
    q_dim, a_dim = 10, 5
    Q = jnp.ones((n_sample, q_dim))
    A = jnp.ones((n_sample, a_dim))
    check_samplers(make_model(cond_model_key, q_dim=q_dim, a_dim=a_dim), Q, A)