Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/01_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
if result.success:
print(f"\nCompilation successful in {result.n_attempts} attempt(s)!")
print(f"Timings: {result.timings}")
print(f"\nGenerated Rust code saved to: compiled_models/normal/src/generated.rs")
print("\nGenerated Rust code saved to: compiled_models/normal/src/generated.rs")
else:
print(f"\nCompilation FAILED after {result.n_attempts} attempts")
for err in result.validation_errors:
Expand Down
8 changes: 5 additions & 3 deletions examples/03_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@
mu_y = a[group_idx] + b * x
y = pm.Normal("y", mu=mu_y, sigma=sigma_y, observed=y_obs)

print(f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}")
print(
f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}"
)
print(f"Data: {n_groups} groups, {N} observations")
print(f"Group sizes: {n_per_group}")
print()
Expand All @@ -76,8 +78,8 @@

if result.success:
print(f"\nCompilation successful in {result.n_attempts} attempt(s)!")
print(f"\nNow you can benchmark:")
print(f" python -c 'from pymc_rust_compiler.benchmark import *; ...'")
print("\nNow you can benchmark:")
print(" python -c 'from pymc_rust_compiler.benchmark import *; ...'")
else:
print(f"\nCompilation FAILED after {result.n_attempts} attempts")
for err in result.validation_errors[:5]:
Expand Down
47 changes: 33 additions & 14 deletions examples/04_zerosumnormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
np.random.seed(314)

n_stores = 6
n_days = 7 # Mon-Sun
n_days = 7 # Mon-Sun
n_categories = 4 # e.g., Electronics, Clothing, Food, Home

store_names = [f"store_{i}" for i in range(n_stores)]
Expand Down Expand Up @@ -73,8 +73,12 @@
for d in range(n_days):
for c in range(n_categories):
n = np.random.poisson(n_obs_per_cell) + 1
mu = (true_grand_mean + true_store_effect[s]
+ true_day_effect[d] + true_interaction[s, d, c])
mu = (
true_grand_mean
+ true_store_effect[s]
+ true_day_effect[d]
+ true_interaction[s, d, c]
)
y_vals = np.random.normal(mu, true_sigma_y, n)
for y in y_vals:
records.append((s, d, c, y))
Expand Down Expand Up @@ -128,18 +132,22 @@
n_zerosum_axes=2,
)

mu = (grand_mean
+ store_effect[store_idx]
+ day_effect[day_idx]
+ interaction[store_idx, day_idx, cat_idx])
mu = (
grand_mean
+ store_effect[store_idx]
+ day_effect[day_idx]
+ interaction[store_idx, day_idx, cat_idx]
)

sigma_y = pm.HalfNormal("sigma_y", sigma=5)
pm.Normal("y", mu=mu, sigma=sigma_y, observed=y_obs)

n_free = sum(v.size for v in model.initial_point().values())
print(f"Free RVs: {[rv.name for rv in model.free_RVs]}")
print(f"Unconstrained parameters: {n_free}")
print(f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}")
print(
f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}"
)
print()

result = compile_model(
Expand Down Expand Up @@ -167,13 +175,24 @@
import arviz as az

print("\n--- Posterior summary (hyperparameters) ---")
print(az.summary(idata, var_names=[
"grand_mean", "sigma_store", "sigma_day", "sigma_cat", "sigma_y",
]))
print(
az.summary(
idata,
var_names=[
"grand_mean",
"sigma_store",
"sigma_day",
"sigma_cat",
"sigma_y",
],
)
)

print(f"\nTrue values: grand_mean={true_grand_mean}, "
f"sigma_store={true_sigma_store}, sigma_day={true_sigma_day}, "
f"sigma_cat={true_sigma_cat}, sigma_y={true_sigma_y}")
print(
f"\nTrue values: grand_mean={true_grand_mean}, "
f"sigma_store={true_sigma_store}, sigma_day={true_sigma_day}, "
f"sigma_cat={true_sigma_cat}, sigma_y={true_sigma_y}"
)

# Posterior plots
axes = az.plot_posterior(
Expand Down
57 changes: 36 additions & 21 deletions examples/05_celeri_simplified.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,39 @@
# --- Synthetic tectonic data ---
np.random.seed(42)

n_blocks = 3 # tectonic blocks
n_faults = 4 # fault segments
n_stations = 25 # GPS stations
n_bounded = 3 # faults with geologic slip rate bounds
n_blocks = 3 # tectonic blocks
n_faults = 4 # fault segments
n_stations = 25 # GPS stations
n_bounded = 3 # faults with geologic slip rate bounds

# True parameters
true_rotation = np.array([
0.5, -0.3, 0.1, # Block 1: wx, wy, wz (rad/Gyr)
-0.2, 0.4, -0.1, # Block 2
0.1, -0.1, 0.3, # Block 3
])
true_rotation = np.array(
[
0.5,
-0.3,
0.1, # Block 1: wx, wy, wz (rad/Gyr)
-0.2,
0.4,
-0.1, # Block 2
0.1,
-0.1,
0.3, # Block 3
]
)
rotation_scale = np.array([1.0, 1.0, 0.5] * n_blocks) # prior scales

true_slip = np.array([
2.0, 0.5, # Fault 1: strike-slip, dip-slip (mm/yr)
-1.5, 1.0, # Fault 2
0.8, -0.3, # Fault 3
-0.5, 0.2, # Fault 4
])
true_slip = np.array(
[
2.0,
0.5, # Fault 1: strike-slip, dip-slip (mm/yr)
-1.5,
1.0, # Fault 2
0.8,
-0.3, # Fault 3
-0.5,
0.2, # Fault 4
]
)
slip_prior_sigma = 5.0

# Design matrices (Green's functions)
Expand All @@ -84,7 +98,7 @@
# Regularization
gamma = 2.0 # regularization strength

print(f"Tectonic block model:")
print("Tectonic block model:")
print(f" {n_blocks} blocks ({n_blocks * 3} rotation params)")
print(f" {n_faults} faults ({n_faults * 2} slip rate params)")
print(f" {n_stations} GPS stations ({n_stations * 2} velocity observations)")
Expand Down Expand Up @@ -124,9 +138,8 @@
slip_rate = pm.Normal("slip_rate", mu=0, sigma=slip_prior_sigma, shape=n_faults * 2)

# Predicted GPS velocities via design matrices
predicted_velocity = (
pm.math.dot(G_rotation, rotation)
+ pm.math.dot(G_slip, slip_rate)
predicted_velocity = pm.math.dot(G_rotation, rotation) + pm.math.dot(
G_slip, slip_rate
)

# GPS station velocity likelihood (StudentT for heavy tails)
Expand Down Expand Up @@ -167,7 +180,7 @@
)

if result.success:
print(f"\nCompilation successful!")
print("\nCompilation successful!")
print(f" Builds: {result.n_attempts}")
print(f" Tool calls: {result.n_tool_calls}")
print(f" Turns: {result.conversation_turns}")
Expand Down Expand Up @@ -195,6 +208,8 @@
print(f"\nTrue rotation: {true_rotation}")
print(f"True slip rates: {true_slip}")
else:
print(f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls")
print(
f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls"
)
for err in result.validation_errors[:5]:
print(f" - {err}")
22 changes: 15 additions & 7 deletions examples/bench_logp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def make_hierarchical_model():
"""Hierarchical model, 12 unconstrained parameters."""
build_dir = Path("compiled_models/hierarchical")
y_obs = np.load(build_dir / "y_data.npy")
x = np.load(build_dir / "x_0_data.npy") # binary covariate
x = np.load(build_dir / "x_0_data.npy") # binary covariate
group_idx = np.load(build_dir / "x_1_data.npy").astype(int) # group indices
n_groups = int(group_idx.max()) + 1
with pm.Model() as model:
Expand Down Expand Up @@ -94,9 +94,9 @@ def main():

results = []
for name, make_fn in models:
print(f"\n{'='*65}")
print(f"\n{'=' * 65}")
print(f" {name}")
print(f"{'='*65}")
print(f"{'=' * 65}")

model, build_dir = make_fn()
n_evals = N_EVALS
Expand All @@ -107,24 +107,32 @@ def main():
pt_result = benchmark_logp_pytensor(model, n_evals=n_evals, x0_model_order=x0)
print(f" pytensor (python loop): {pt_result['us_per_eval']:.2f} us/eval")

cfunc_result = benchmark_logp_numba_cfunc(model, n_evals=n_evals, x0_model_order=x0)
cfunc_result = benchmark_logp_numba_cfunc(
model, n_evals=n_evals, x0_model_order=x0
)
print(f" numba cfunc (rust loop): {cfunc_result['us_per_eval']:.2f} us/eval")

rs_result = benchmark_logp_rust(build_dir, model, n_evals=n_evals, x0_model_order=x0)
rs_result = benchmark_logp_rust(
build_dir, model, n_evals=n_evals, x0_model_order=x0
)
if "error" in rs_result:
print(f" rust-ai: ERROR - {rs_result['error'][:100]}")
else:
print(f" rust-ai: {rs_result['us_per_eval']:.2f} us/eval")

print_logp_comparison(pt_result, rs_result, model_name=name)
print_logp_comparison(cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]")
print_logp_comparison(
cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]"
)
results.append((name, pt_result, cfunc_result, rs_result))

# Summary table
print("\n" + "=" * 85)
print("SUMMARY: logp+dlogp evaluation speed")
print("=" * 85)
print(f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}")
print(
f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}"
)
print("-" * 79)
for name, pt, cf, rs in results:
pt_us = f"{pt['us_per_eval']:.2f} us" if "error" not in pt else "ERROR"
Expand Down
9 changes: 6 additions & 3 deletions examples/jax_to_pytorch_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,21 @@ def main():
)

if result.success:
print(f"\nTranspilation successful!")
print("\nTranspilation successful!")
print(f" Tool calls: {result.n_tool_calls}")
print(f" Tokens: {result.token_usage['total_tokens']}")
print(f"\nGenerated PyTorch code:")
print("\nGenerated PyTorch code:")
print(result.generated_code)

# Test the generated model
import torch

model = result.get_model({k: np.asarray(v) for k, v in params.items()})
pt_out = model(torch.tensor(np.asarray(x)))
print(f"\nPyTorch output:\n{pt_out.detach().numpy()}")
print(f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}")
print(
f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}"
)
else:
print(f"\nTranspilation failed: {result.validation_errors}")

Expand Down
Loading
Loading