4 changes: 1 addition & 3 deletions examples/03_hierarchical.py
@@ -62,9 +62,7 @@
mu_y = a[group_idx] + b * x
y = pm.Normal("y", mu=mu_y, sigma=sigma_y, observed=y_obs)

print(
f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}"
)
print(f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}")
print(f"Data: {n_groups} groups, {N} observations")
print(f"Group sizes: {n_per_group}")
print()
18 changes: 3 additions & 15 deletions examples/04_zerosumnormal.py
@@ -73,12 +73,7 @@
for d in range(n_days):
for c in range(n_categories):
n = np.random.poisson(n_obs_per_cell) + 1
mu = (
true_grand_mean
+ true_store_effect[s]
+ true_day_effect[d]
+ true_interaction[s, d, c]
)
mu = true_grand_mean + true_store_effect[s] + true_day_effect[d] + true_interaction[s, d, c]
y_vals = np.random.normal(mu, true_sigma_y, n)
for y in y_vals:
records.append((s, d, c, y))
@@ -132,22 +127,15 @@
n_zerosum_axes=2,
)

mu = (
grand_mean
+ store_effect[store_idx]
+ day_effect[day_idx]
+ interaction[store_idx, day_idx, cat_idx]
)
mu = grand_mean + store_effect[store_idx] + day_effect[day_idx] + interaction[store_idx, day_idx, cat_idx]

sigma_y = pm.HalfNormal("sigma_y", sigma=5)
pm.Normal("y", mu=mu, sigma=sigma_y, observed=y_obs)

n_free = sum(v.size for v in model.initial_point().values())
print(f"Free RVs: {[rv.name for rv in model.free_RVs]}")
print(f"Unconstrained parameters: {n_free}")
print(
f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}"
)
print(f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}")
print()

result = compile_model(
8 changes: 2 additions & 6 deletions examples/05_celeri_simplified.py
@@ -138,9 +138,7 @@
slip_rate = pm.Normal("slip_rate", mu=0, sigma=slip_prior_sigma, shape=n_faults * 2)

# Predicted GPS velocities via design matrices
predicted_velocity = pm.math.dot(G_rotation, rotation) + pm.math.dot(
G_slip, slip_rate
)
predicted_velocity = pm.math.dot(G_rotation, rotation) + pm.math.dot(G_slip, slip_rate)

# GPS station velocity likelihood (StudentT for heavy tails)
pm.StudentT(
@@ -208,8 +206,6 @@
print(f"\nTrue rotation: {true_rotation}")
print(f"True slip rates: {true_slip}")
else:
print(
f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls"
)
print(f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls")
for err in result.validation_errors[:5]:
print(f" - {err}")
16 changes: 4 additions & 12 deletions examples/bench_logp.py
@@ -107,32 +107,24 @@ def main():
pt_result = benchmark_logp_pytensor(model, n_evals=n_evals, x0_model_order=x0)
print(f" pytensor (python loop): {pt_result['us_per_eval']:.2f} us/eval")

cfunc_result = benchmark_logp_numba_cfunc(
model, n_evals=n_evals, x0_model_order=x0
)
cfunc_result = benchmark_logp_numba_cfunc(model, n_evals=n_evals, x0_model_order=x0)
print(f" numba cfunc (rust loop): {cfunc_result['us_per_eval']:.2f} us/eval")

rs_result = benchmark_logp_rust(
build_dir, model, n_evals=n_evals, x0_model_order=x0
)
rs_result = benchmark_logp_rust(build_dir, model, n_evals=n_evals, x0_model_order=x0)
if "error" in rs_result:
print(f" rust-ai: ERROR - {rs_result['error'][:100]}")
else:
print(f" rust-ai: {rs_result['us_per_eval']:.2f} us/eval")

print_logp_comparison(pt_result, rs_result, model_name=name)
print_logp_comparison(
cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]"
)
print_logp_comparison(cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]")
results.append((name, pt_result, cfunc_result, rs_result))

# Summary table
print("\n" + "=" * 85)
print("SUMMARY: logp+dlogp evaluation speed")
print("=" * 85)
print(
f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}"
)
print(f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}")
print("-" * 79)
for name, pt, cf, rs in results:
pt_us = f"{pt['us_per_eval']:.2f} us" if "error" not in pt else "ERROR"
69 changes: 55 additions & 14 deletions examples/generate_blog_plots.py
@@ -14,10 +14,10 @@

import os
import sys
import time
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
@@ -34,6 +34,7 @@
# Model factory functions
# ---------------------------------------------------------------------------


def make_normal_model():
build_dir = Path("compiled_models/normal")
y_obs = np.load(build_dir / "y_data.npy")
@@ -94,13 +94,14 @@ def make_gp_model():
# Compile + Optimize
# ---------------------------------------------------------------------------


def run_compile_and_optimize(name, make_model_fn, source_code, build_dir):
"""Compile a PyMC model to Rust, then optimize it."""
from transpailer import compile_model, optimize_model
from transpailer.analysis import (
plot_optimization_progress,
plot_waterfall,
plot_timeline,
plot_waterfall,
print_summary,
)

@@ -163,6 +165,7 @@ def run_compile_and_optimize(name, make_model_fn, source_code, build_dir):
# Benchmark comparison plot
# ---------------------------------------------------------------------------


def run_benchmarks_and_plot():
"""Run logp benchmarks on compiled models and generate comparison bar chart."""
from transpailer.benchmark import (
@@ -174,7 +177,11 @@
models_info = [
("Normal\n(2 params)", make_normal_model, "compiled_models/normal"),
("LinReg\n(3 params)", make_linreg_model, "compiled_models/linreg"),
("Hierarchical\n(12 params)", make_hierarchical_model, "compiled_models/hierarchical"),
(
"Hierarchical\n(12 params)",
make_hierarchical_model,
"compiled_models/hierarchical",
),
("GP\n(3 params)", make_gp_model, "compiled_models/gp"),
]

@@ -226,10 +233,24 @@ def run_benchmarks_and_plot():
width = 0.35

# Bar chart: us/eval comparison
bars1 = ax1.bar(x - width / 2, pt_vals, width, label="PyTensor (Numba)",
color="#3498db", edgecolor="#2c3e50", linewidth=0.5)
bars2 = ax1.bar(x + width / 2, rs_vals, width, label="AI-compiled Rust",
color="#e74c3c", edgecolor="#2c3e50", linewidth=0.5)
bars1 = ax1.bar(
x - width / 2,
pt_vals,
width,
label="PyTensor (Numba)",
color="#3498db",
edgecolor="#2c3e50",
linewidth=0.5,
)
bars2 = ax1.bar(
x + width / 2,
rs_vals,
width,
label="AI-compiled Rust",
color="#e74c3c",
edgecolor="#2c3e50",
linewidth=0.5,
)

ax1.set_ylabel("us/eval (lower is better)", fontsize=11)
ax1.set_title("logp+gradient Evaluation Speed", fontsize=13, fontweight="bold")
@@ -239,12 +260,24 @@
ax1.grid(True, alpha=0.3, axis="y")

for bar in bars1:
ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
f"{bar.get_height():.1f}", ha="center", va="bottom", fontsize=8)
ax1.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height(),
f"{bar.get_height():.1f}",
ha="center",
va="bottom",
fontsize=8,
)
for bar in bars2:
if bar.get_height() > 0:
ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
f"{bar.get_height():.2f}", ha="center", va="bottom", fontsize=8)
ax1.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height(),
f"{bar.get_height():.2f}",
ha="center",
va="bottom",
fontsize=8,
)

# Speedup chart
colors = ["#2ecc71" if s >= 3 else "#f39c12" if s >= 2 else "#e74c3c" for s in speedups]
@@ -258,19 +291,27 @@

for bar, s in zip(bars3, speedups):
if s > 0:
ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
f"{s:.1f}x", ha="center", va="bottom", fontsize=10, fontweight="bold")
ax2.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height(),
f"{s:.1f}x",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)

fig.tight_layout()
fig.savefig(OUTPUT_DIR / "benchmark_comparison.png", dpi=150)
plt.close(fig)
print(f"\nSaved: benchmark_comparison.png")
print("\nSaved: benchmark_comparison.png")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
os.chdir(Path(__file__).parent.parent)
print(f"Working dir: {Path.cwd()}")
4 changes: 1 addition & 3 deletions examples/jax_to_pytorch_mlp.py
@@ -57,9 +57,7 @@ def main():
model = result.get_model({k: np.asarray(v) for k, v in params.items()})
pt_out = model(torch.tensor(np.asarray(x)))
print(f"\nPyTorch output:\n{pt_out.detach().numpy()}")
print(
f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}"
)
print(f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}")
else:
print(f"\nTranspilation failed: {result.validation_errors}")

47 changes: 10 additions & 37 deletions examples/mingpt_enzyme/validate_pytorch.py
@@ -22,25 +22,14 @@ def parse_data_rs(path):
name = m.group(1)
if name.endswith("_SHAPE"):
continue
vals = [
float(x.strip().rstrip(",")) for x in m.group(2).split(",") if x.strip()
]
vals = [float(x.strip().rstrip(",")) for x in m.group(2).split(",") if x.strip()]
weights[name] = np.array(vals, dtype=np.float32)
return weights


class NewGELU(nn.Module):
def forward(self, x):
return (
0.5
* x
* (
1.0
+ torch.tanh(
(2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0))
)
)
)
return 0.5 * x * (1.0 + torch.tanh((2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0))))


class CausalSelfAttention(nn.Module):
@@ -50,9 +39,7 @@ def __init__(self, n_embd, n_head, block_size):
self.c_proj = nn.Linear(n_embd, n_embd)
self.register_buffer(
"bias",
torch.tril(torch.ones(block_size, block_size)).view(
1, 1, block_size, block_size
),
torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
)
self.n_head = n_head
self.n_embd = n_embd
@@ -93,9 +80,7 @@ def __init__(self, n_layer=3, n_head=3, n_embd=48, block_size=8, vocab_size=32):
super().__init__()
self.n_embd = n_embd
self.seq_len = block_size
self.blocks = nn.ModuleList(
[TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)]
)
self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

@@ -116,30 +101,20 @@ def load_weights_from_data_rs(model, weights):
prefix = f"BLOCKS_{i}_"
b.ln_1.weight.copy_(torch.tensor(weights[f"{prefix}LN_1_WEIGHT"]))
b.ln_1.bias.copy_(torch.tensor(weights[f"{prefix}LN_1_BIAS"]))
b.attn.c_attn.weight.copy_(
torch.tensor(weights[f"{prefix}ATTN_C_ATTN_WEIGHT"]).reshape(144, 48)
)
b.attn.c_attn.weight.copy_(torch.tensor(weights[f"{prefix}ATTN_C_ATTN_WEIGHT"]).reshape(144, 48))
b.attn.c_attn.bias.copy_(torch.tensor(weights[f"{prefix}ATTN_C_ATTN_BIAS"]))
b.attn.c_proj.weight.copy_(
torch.tensor(weights[f"{prefix}ATTN_C_PROJ_WEIGHT"]).reshape(48, 48)
)
b.attn.c_proj.weight.copy_(torch.tensor(weights[f"{prefix}ATTN_C_PROJ_WEIGHT"]).reshape(48, 48))
b.attn.c_proj.bias.copy_(torch.tensor(weights[f"{prefix}ATTN_C_PROJ_BIAS"]))
b.ln_2.weight.copy_(torch.tensor(weights[f"{prefix}LN_2_WEIGHT"]))
b.ln_2.bias.copy_(torch.tensor(weights[f"{prefix}LN_2_BIAS"]))
b.c_fc.weight.copy_(
torch.tensor(weights[f"{prefix}C_FC_WEIGHT"]).reshape(192, 48)
)
b.c_fc.weight.copy_(torch.tensor(weights[f"{prefix}C_FC_WEIGHT"]).reshape(192, 48))
b.c_fc.bias.copy_(torch.tensor(weights[f"{prefix}C_FC_BIAS"]))
b.c_proj.weight.copy_(
torch.tensor(weights[f"{prefix}C_PROJ_WEIGHT"]).reshape(48, 192)
)
b.c_proj.weight.copy_(torch.tensor(weights[f"{prefix}C_PROJ_WEIGHT"]).reshape(48, 192))
b.c_proj.bias.copy_(torch.tensor(weights[f"{prefix}C_PROJ_BIAS"]))

model.ln_f.weight.copy_(torch.tensor(weights["LN_F_WEIGHT"]))
model.ln_f.bias.copy_(torch.tensor(weights["LN_F_BIAS"]))
model.lm_head.weight.copy_(
torch.tensor(weights["LM_HEAD_WEIGHT"]).reshape(32, 48)
)
model.lm_head.weight.copy_(torch.tensor(weights["LM_HEAD_WEIGHT"]).reshape(32, 48))


def main():
@@ -194,9 +169,7 @@ def main():
]

print("\n--- Enzyme vs PyTorch comparison (first 10) ---")
print(
f"{'idx':>4} {'Enzyme':>14} {'PyTorch':>14} {'abs_diff':>12} {'rel_diff':>12}"
)
print(f"{'idx':>4} {'Enzyme':>14} {'PyTorch':>14} {'abs_diff':>12} {'rel_diff':>12}")
for i in range(10):
e = enzyme_grads[i]
p = grad[i]