Skip to content

Commit ea286e6

Browse files
authored
Add conditional mean and cov (#39)
* Add conditional mean and cov * Change log_prob test to asymmetric A * Doc
1 parent 2a43ef6 commit ea286e6

File tree

5 files changed

+141
-12
lines changed

5 files changed

+141
-12
lines changed

tests/test_conditional.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import jax
2+
from jax import numpy as jnp
3+
4+
import thermox
5+
6+
7+
def test_mean_and_cov():
8+
jax.config.update("jax_enable_x64", True)
9+
dim = 2
10+
t = 1.0
11+
12+
A = jnp.array([[3, 2.5], [2, 4.0]])
13+
b = jax.random.normal(jax.random.PRNGKey(1), (dim,))
14+
x0 = jax.random.normal(jax.random.PRNGKey(2), (dim,))
15+
D = 2 * jnp.eye(dim)
16+
17+
mean = thermox.conditional.mean(t, x0, A, b, D)
18+
samples = jax.vmap(
19+
lambda k: thermox.sample(k, jnp.array([0.0, t]), x0, A, b, D)[-1]
20+
)(jax.random.split(jax.random.PRNGKey(0), 1000000))
21+
assert mean.shape == (dim,)
22+
assert jnp.allclose(mean, jnp.mean(samples, axis=0), atol=1e-2)
23+
24+
cov = thermox.conditional.covariance(t, A, D)
25+
assert cov.shape == (dim, dim)
26+
assert jnp.allclose(cov, jnp.cov(samples.T), atol=1e-3)
27+
28+
mean_and_cov = thermox.conditional.mean_and_covariance(t, x0, A, b, D)
29+
assert mean_and_cov[0].shape == (dim,)
30+
assert mean_and_cov[1].shape == (dim, dim)
31+
assert jnp.allclose(mean_and_cov[0], mean, atol=1e-5)
32+
assert jnp.allclose(mean_and_cov[1], cov, atol=1e-5)

tests/test_log_prob.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -91,42 +91,40 @@ def test_MLE():
9191
D_true = jnp.array([[1, 0.3, -0.1], [0.3, 1, 0.2], [-0.1, 0.2, 1.0]])
9292

9393
nts = 300
94-
ts = jnp.linspace(0, 10, nts)
94+
ts = jnp.linspace(0, 100, nts)
9595
x0 = jnp.zeros_like(b_true)
9696

97-
n_trajecs = 3
97+
n_trajecs = 5
9898
rks = jax.random.split(jax.random.PRNGKey(0), n_trajecs)
9999

100100
samps = jax.vmap(lambda key: thermox.sample(key, ts, x0, A_true, b_true, D_true))(
101101
rks
102102
)
103103

104-
A_sqrt_init = jnp.tril(jnp.eye(3) + jax.random.normal(rks[0], (3, 3)) * 1e-1)
104+
A_init = jnp.eye(3) + jax.random.normal(rks[0], (3, 3)) * 1e-1
105105
b_init = jnp.zeros(3)
106106
D_sqrt_init = jnp.eye(3)
107107

108108
log_prob_true = thermox.log_prob(ts, samps[0], A_true, b_true, D_true)
109109
log_prob_init = thermox.log_prob(
110-
ts, samps[0], A_sqrt_init @ A_sqrt_init.T, b_init, D_sqrt_init @ D_sqrt_init.T
110+
ts, samps[0], A_init, b_init, D_sqrt_init @ D_sqrt_init.T
111111
)
112112

113113
assert log_prob_true > log_prob_init
114114

115115
# Gradient descent
116116
def loss(params):
117-
A_sqrt, b, D_sqrt = params
118-
A_sqrt = jnp.tril(A_sqrt)
117+
A, b, D_sqrt = params
119118
D_sqrt = jnp.tril(D_sqrt)
120-
A = A_sqrt @ A_sqrt.T
121119
D = D_sqrt @ D_sqrt.T
122120
return -jax.vmap(lambda s: thermox.log_prob(ts, s, A, b, D))(
123121
samps
124122
).mean() / len(ts)
125123

126124
val_and_g = jax.jit(jax.value_and_grad(loss))
127125

128-
ps = (A_sqrt_init, b_init, D_sqrt_init)
129-
ps_true = (jnp.linalg.cholesky(A_true), b_true, jnp.linalg.cholesky(D_true))
126+
ps = (A_init, b_init, D_sqrt_init)
127+
ps_true = (A_true, b_true, jnp.linalg.cholesky(D_true))
130128

131129
v, g = val_and_g(ps)
132130
v_true, g_true = val_and_g(ps_true)
@@ -138,7 +136,7 @@ def loss(params):
138136
n_steps = 20000
139137
neg_log_probs = jnp.zeros(n_steps)
140138

141-
optimizer = optax.adam(1e-2)
139+
optimizer = optax.adam(1e-3)
142140
opt_state = optimizer.init(ps)
143141

144142
for i in range(n_steps):
@@ -149,7 +147,7 @@ def loss(params):
149147
ps = optax.apply_updates(ps, updates)
150148
neg_log_probs = neg_log_probs.at[i].set(neg_log_prob)
151149

152-
A_recover = ps[0] @ ps[0].T
150+
A_recover = ps[0]
153151
b_recover = ps[1]
154152
D_recover = ps[2] @ ps[2].T
155153

thermox/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from thermox import linalg
2+
from thermox import conditional
23
from thermox.sampler import sample
34
from thermox.prob import log_prob
45
from thermox.utils import preprocess

thermox/conditional.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from jax import numpy as jnp
2+
from jax import Array
3+
4+
from thermox.utils import (
5+
ProcessedDriftMatrix,
6+
ProcessedDiffusionMatrix,
7+
handle_matrix_inputs,
8+
)
9+
from thermox.sampler import expm_vp
10+
11+
12+
def mean(
13+
t: float,
14+
x0: Array,
15+
A: Array | ProcessedDriftMatrix,
16+
b: Array,
17+
D: Array | ProcessedDiffusionMatrix,
18+
) -> Array:
19+
"""Computes the mean of p(x_t | x_0)
20+
21+
For x_t evolving according to the SDE:
22+
23+
dx = - A * (x - b) dt + sqrt(D) dW
24+
25+
Args:
26+
ts: Times at which samples are collected. Includes time for x0.
27+
x0: Initial state of the process.
28+
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
29+
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
30+
must be the transformed drift matrix, A_y, given by thermox.preprocess,
31+
not thermox.utils.preprocess_drift_matrix.
32+
b: Drift displacement vector.
33+
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
34+
35+
"""
36+
A_y, D = handle_matrix_inputs(A, D)
37+
38+
y0 = D.sqrt_inv @ (x0 - b)
39+
return b + D.sqrt @ expm_vp(A_y, y0, t)
40+
41+
42+
def covariance(
43+
t: float,
44+
A: Array | ProcessedDriftMatrix,
45+
D: Array | ProcessedDiffusionMatrix,
46+
) -> Array:
47+
"""Computes the covariance of p(x_t | x_0)
48+
49+
For x evolving according to the SDE:
50+
51+
dx = - A * (x - b) dt + sqrt(D) dW
52+
53+
Args:
54+
ts: Times at which samples are collected. Includes time for x0.
55+
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
56+
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
57+
must be the transformed drift matrix, A_y, given by thermox.preprocess,
58+
not thermox.utils.preprocess_drift_matrix.
59+
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
60+
"""
61+
A_y, D = handle_matrix_inputs(A, D)
62+
63+
identity_diffusion_cov = (
64+
A_y.sym_eigvecs
65+
@ jnp.diag((1 - jnp.exp(-2 * A_y.sym_eigvals * t)) / (2 * A_y.sym_eigvals))
66+
@ A_y.sym_eigvecs.T
67+
)
68+
return D.sqrt @ identity_diffusion_cov @ D.sqrt.T
69+
70+
71+
def mean_and_covariance(
72+
t: float,
73+
x0: Array,
74+
A: Array | ProcessedDriftMatrix,
75+
b: Array,
76+
D: Array | ProcessedDiffusionMatrix,
77+
) -> tuple[Array, Array]:
78+
"""Computes the mean and covariance of p(x_t | x_0)
79+
80+
For x evolving according to the SDE:
81+
82+
dx = - A * (x - b) dt + sqrt(D) dW
83+
84+
Args:
85+
ts: Times at which samples are collected. Includes time for x0.
86+
x0: Initial state of the process.
87+
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
88+
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
89+
must be the transformed drift matrix, A_y, given by thermox.preprocess,
90+
not thermox.utils.preprocess_drift_matrix.
91+
b: Drift displacement vector.
92+
D: Diffusion matrix (Array or thermox.ProcessedDiffusionMatrix).
93+
94+
"""
95+
A, D = handle_matrix_inputs(A, D)
96+
mean_val = mean(t, x0, A, b, D)
97+
covariance_val = covariance(t, A, D)
98+
return mean_val, covariance_val

thermox/prob.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def log_prob(
3636
3737
Args:
3838
ts: Times at which samples are collected. Includes time for x0.
39-
xs: Initial state of the process.
39+
xs: States of the process.
4040
A: Drift matrix (Array or thermox.ProcessedDriftMatrix).
4141
Note: If a thermox.ProcessedDriftMatrix instance is used as input,
4242
must be the transformed drift matrix, A_y, given by thermox.preprocess,

0 commit comments

Comments
 (0)