diff --git a/examples/03_hierarchical.py b/examples/03_hierarchical.py index 68f0f12..532a7a9 100644 --- a/examples/03_hierarchical.py +++ b/examples/03_hierarchical.py @@ -62,9 +62,7 @@ mu_y = a[group_idx] + b * x y = pm.Normal("y", mu=mu_y, sigma=sigma_y, observed=y_obs) -print( - f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}" -) +print(f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}") print(f"Data: {n_groups} groups, {N} observations") print(f"Group sizes: {n_per_group}") print() diff --git a/examples/04_zerosumnormal.py b/examples/04_zerosumnormal.py index f713c29..65fc2ce 100644 --- a/examples/04_zerosumnormal.py +++ b/examples/04_zerosumnormal.py @@ -73,12 +73,7 @@ for d in range(n_days): for c in range(n_categories): n = np.random.poisson(n_obs_per_cell) + 1 - mu = ( - true_grand_mean - + true_store_effect[s] - + true_day_effect[d] - + true_interaction[s, d, c] - ) + mu = true_grand_mean + true_store_effect[s] + true_day_effect[d] + true_interaction[s, d, c] y_vals = np.random.normal(mu, true_sigma_y, n) for y in y_vals: records.append((s, d, c, y)) @@ -132,12 +127,7 @@ n_zerosum_axes=2, ) - mu = ( - grand_mean - + store_effect[store_idx] - + day_effect[day_idx] - + interaction[store_idx, day_idx, cat_idx] - ) + mu = grand_mean + store_effect[store_idx] + day_effect[day_idx] + interaction[store_idx, day_idx, cat_idx] sigma_y = pm.HalfNormal("sigma_y", sigma=5) pm.Normal("y", mu=mu, sigma=sigma_y, observed=y_obs) @@ -145,9 +135,7 @@ n_free = sum(v.size for v in model.initial_point().values()) print(f"Free RVs: {[rv.name for rv in model.free_RVs]}") print(f"Unconstrained parameters: {n_free}") -print( - f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}" -) +print(f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}") print() result = compile_model( diff --git a/examples/05_celeri_simplified.py b/examples/05_celeri_simplified.py index f81f07a..37b0559 100644 --- a/examples/05_celeri_simplified.py +++ b/examples/05_celeri_simplified.py @@ -138,9 +138,7 @@ slip_rate = pm.Normal("slip_rate", mu=0, sigma=slip_prior_sigma, shape=n_faults * 2) # Predicted GPS velocities via design matrices - predicted_velocity = pm.math.dot(G_rotation, rotation) + pm.math.dot( - G_slip, slip_rate - ) + predicted_velocity = pm.math.dot(G_rotation, rotation) + pm.math.dot(G_slip, slip_rate) # GPS station velocity likelihood (StudentT for heavy tails) pm.StudentT( @@ -208,8 +206,6 @@ print(f"\nTrue rotation: {true_rotation}") print(f"True slip rates: {true_slip}") else: - print( - f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls" - ) + print(f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls") for err in result.validation_errors[:5]: print(f" - {err}") diff --git a/examples/bench_logp.py b/examples/bench_logp.py index 1e20c25..efcdcc0 100644 --- a/examples/bench_logp.py +++ b/examples/bench_logp.py @@ -107,32 +107,24 @@ def main(): pt_result = benchmark_logp_pytensor(model, n_evals=n_evals, x0_model_order=x0) print(f" pytensor (python loop): {pt_result['us_per_eval']:.2f} us/eval") - cfunc_result = benchmark_logp_numba_cfunc( - model, n_evals=n_evals, x0_model_order=x0 - ) + cfunc_result = benchmark_logp_numba_cfunc(model, n_evals=n_evals, x0_model_order=x0) print(f" numba cfunc (rust loop): {cfunc_result['us_per_eval']:.2f} us/eval") - rs_result = 
benchmark_logp_rust( - build_dir, model, n_evals=n_evals, x0_model_order=x0 - ) + rs_result = benchmark_logp_rust(build_dir, model, n_evals=n_evals, x0_model_order=x0) if "error" in rs_result: print(f" rust-ai: ERROR - {rs_result['error'][:100]}") else: print(f" rust-ai: {rs_result['us_per_eval']:.2f} us/eval") print_logp_comparison(pt_result, rs_result, model_name=name) - print_logp_comparison( - cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]" - ) + print_logp_comparison(cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]") results.append((name, pt_result, cfunc_result, rs_result)) # Summary table print("\n" + "=" * 85) print("SUMMARY: logp+dlogp evaluation speed") print("=" * 85) - print( - f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}" - ) + print(f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}") print("-" * 79) for name, pt, cf, rs in results: pt_us = f"{pt['us_per_eval']:.2f} us" if "error" not in pt else "ERROR" diff --git a/examples/generate_blog_plots.py b/examples/generate_blog_plots.py index 1e73c29..40b0da9 100644 --- a/examples/generate_blog_plots.py +++ b/examples/generate_blog_plots.py @@ -14,10 +14,10 @@ import os import sys -import time from pathlib import Path import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np @@ -34,6 +34,7 @@ # Model factory functions # --------------------------------------------------------------------------- + def make_normal_model(): build_dir = Path("compiled_models/normal") y_obs = np.load(build_dir / "y_data.npy") @@ -94,13 +95,14 @@ def make_gp_model(): # Compile + Optimize # --------------------------------------------------------------------------- + def run_compile_and_optimize(name, make_model_fn, source_code, build_dir): """Compile a PyMC model to Rust, then optimize it.""" from transpailer import compile_model, optimize_model from transpailer.analysis import ( plot_optimization_progress, - plot_waterfall, plot_timeline, + plot_waterfall, print_summary, ) @@ -163,6 +165,7 @@ def run_compile_and_optimize(name, make_model_fn, source_code, build_dir): # Benchmark comparison plot # --------------------------------------------------------------------------- + def run_benchmarks_and_plot(): """Run logp benchmarks on compiled models and generate comparison bar chart.""" from transpailer.benchmark import ( @@ -174,7 +177,11 @@ def run_benchmarks_and_plot(): models_info = [ ("Normal\n(2 params)", make_normal_model, "compiled_models/normal"), ("LinReg\n(3 params)", make_linreg_model, "compiled_models/linreg"), - ("Hierarchical\n(12 params)", make_hierarchical_model, "compiled_models/hierarchical"), + ( + "Hierarchical\n(12 params)", + make_hierarchical_model, + "compiled_models/hierarchical", + ), ("GP\n(3 params)", make_gp_model, "compiled_models/gp"), ] @@ -226,10 +233,24 @@ def run_benchmarks_and_plot(): width = 0.35 # Bar chart: us/eval comparison - bars1 = ax1.bar(x - width / 2, pt_vals, width, label="PyTensor (Numba)", - color="#3498db", edgecolor="#2c3e50", linewidth=0.5) - bars2 = ax1.bar(x + width / 2, rs_vals, width, label="AI-compiled Rust", - color="#e74c3c", edgecolor="#2c3e50", linewidth=0.5) + bars1 = ax1.bar( + x - width / 2, + pt_vals, + width, + label="PyTensor (Numba)", + color="#3498db", + edgecolor="#2c3e50", + linewidth=0.5, + ) + bars2 = ax1.bar( + x + width / 2, + rs_vals, + width, + label="AI-compiled Rust", + color="#e74c3c", + edgecolor="#2c3e50", + linewidth=0.5, + ) 
ax1.set_ylabel("us/eval (lower is better)", fontsize=11) ax1.set_title("logp+gradient Evaluation Speed", fontsize=13, fontweight="bold") @@ -239,12 +260,24 @@ def run_benchmarks_and_plot(): ax1.grid(True, alpha=0.3, axis="y") for bar in bars1: - ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), - f"{bar.get_height():.1f}", ha="center", va="bottom", fontsize=8) + ax1.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height(), + f"{bar.get_height():.1f}", + ha="center", + va="bottom", + fontsize=8, + ) for bar in bars2: if bar.get_height() > 0: - ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), - f"{bar.get_height():.2f}", ha="center", va="bottom", fontsize=8) + ax1.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height(), + f"{bar.get_height():.2f}", + ha="center", + va="bottom", + fontsize=8, + ) # Speedup chart colors = ["#2ecc71" if s >= 3 else "#f39c12" if s >= 2 else "#e74c3c" for s in speedups] @@ -258,19 +291,27 @@ def run_benchmarks_and_plot(): for bar, s in zip(bars3, speedups): if s > 0: - ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), - f"{s:.1f}x", ha="center", va="bottom", fontsize=10, fontweight="bold") + ax2.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height(), + f"{s:.1f}x", + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + ) fig.tight_layout() fig.savefig(OUTPUT_DIR / "benchmark_comparison.png", dpi=150) plt.close(fig) - print(f"\nSaved: benchmark_comparison.png") + print("\nSaved: benchmark_comparison.png") # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- + def main(): os.chdir(Path(__file__).parent.parent) print(f"Working dir: {Path.cwd()}") diff --git a/examples/jax_to_pytorch_mlp.py b/examples/jax_to_pytorch_mlp.py index 842508b..fa2f571 100644 --- a/examples/jax_to_pytorch_mlp.py +++ b/examples/jax_to_pytorch_mlp.py @@ -57,9 +57,7 @@ def main(): model = result.get_model({k: np.asarray(v) for k, v in params.items()}) pt_out = model(torch.tensor(np.asarray(x))) print(f"\nPyTorch output:\n{pt_out.detach().numpy()}") - print( - f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}" - ) + print(f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}") else: print(f"\nTranspilation failed: {result.validation_errors}") diff --git a/examples/mingpt_enzyme/validate_pytorch.py b/examples/mingpt_enzyme/validate_pytorch.py index 2c305c6..99df373 100644 --- a/examples/mingpt_enzyme/validate_pytorch.py +++ b/examples/mingpt_enzyme/validate_pytorch.py @@ -22,25 +22,14 @@ def parse_data_rs(path): name = m.group(1) if name.endswith("_SHAPE"): continue - vals = [ - float(x.strip().rstrip(",")) for x in m.group(2).split(",") if x.strip() - ] + vals = [float(x.strip().rstrip(",")) for x in m.group(2).split(",") if x.strip()] weights[name] = np.array(vals, dtype=np.float32) return weights class NewGELU(nn.Module): def forward(self, x): - return ( - 0.5 - * x - * ( - 1.0 - + torch.tanh( - (2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)) - ) - ) - ) + return 0.5 * x * (1.0 + torch.tanh((2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)))) class CausalSelfAttention(nn.Module): @@ -50,9 +39,7 @@ def __init__(self, n_embd, n_head, block_size): self.c_proj = nn.Linear(n_embd, n_embd) self.register_buffer( "bias", - torch.tril(torch.ones(block_size, block_size)).view( - 1, 1, block_size, block_size - ), + torch.tril(torch.ones(block_size, 
block_size)).view(1, 1, block_size, block_size), ) self.n_head = n_head self.n_embd = n_embd @@ -93,9 +80,7 @@ def __init__(self, n_layer=3, n_head=3, n_embd=48, block_size=8, vocab_size=32): super().__init__() self.n_embd = n_embd self.seq_len = block_size - self.blocks = nn.ModuleList( - [TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)] - ) + self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) @@ -116,30 +101,20 @@ def load_weights_from_data_rs(model, weights): prefix = f"BLOCKS_{i}_" b.ln_1.weight.copy_(torch.tensor(weights[f"{prefix}LN_1_WEIGHT"])) b.ln_1.bias.copy_(torch.tensor(weights[f"{prefix}LN_1_BIAS"])) - b.attn.c_attn.weight.copy_( - torch.tensor(weights[f"{prefix}ATTN_C_ATTN_WEIGHT"]).reshape(144, 48) - ) + b.attn.c_attn.weight.copy_(torch.tensor(weights[f"{prefix}ATTN_C_ATTN_WEIGHT"]).reshape(144, 48)) b.attn.c_attn.bias.copy_(torch.tensor(weights[f"{prefix}ATTN_C_ATTN_BIAS"])) - b.attn.c_proj.weight.copy_( - torch.tensor(weights[f"{prefix}ATTN_C_PROJ_WEIGHT"]).reshape(48, 48) - ) + b.attn.c_proj.weight.copy_(torch.tensor(weights[f"{prefix}ATTN_C_PROJ_WEIGHT"]).reshape(48, 48)) b.attn.c_proj.bias.copy_(torch.tensor(weights[f"{prefix}ATTN_C_PROJ_BIAS"])) b.ln_2.weight.copy_(torch.tensor(weights[f"{prefix}LN_2_WEIGHT"])) b.ln_2.bias.copy_(torch.tensor(weights[f"{prefix}LN_2_BIAS"])) - b.c_fc.weight.copy_( - torch.tensor(weights[f"{prefix}C_FC_WEIGHT"]).reshape(192, 48) - ) + b.c_fc.weight.copy_(torch.tensor(weights[f"{prefix}C_FC_WEIGHT"]).reshape(192, 48)) b.c_fc.bias.copy_(torch.tensor(weights[f"{prefix}C_FC_BIAS"])) - b.c_proj.weight.copy_( - torch.tensor(weights[f"{prefix}C_PROJ_WEIGHT"]).reshape(48, 192) - ) + b.c_proj.weight.copy_(torch.tensor(weights[f"{prefix}C_PROJ_WEIGHT"]).reshape(48, 192)) b.c_proj.bias.copy_(torch.tensor(weights[f"{prefix}C_PROJ_BIAS"])) model.ln_f.weight.copy_(torch.tensor(weights["LN_F_WEIGHT"])) model.ln_f.bias.copy_(torch.tensor(weights["LN_F_BIAS"])) - model.lm_head.weight.copy_( - torch.tensor(weights["LM_HEAD_WEIGHT"]).reshape(32, 48) - ) + model.lm_head.weight.copy_(torch.tensor(weights["LM_HEAD_WEIGHT"]).reshape(32, 48)) def main(): @@ -194,9 +169,7 @@ def main(): ] print("\n--- Enzyme vs PyTorch comparison (first 10) ---") - print( - f"{'idx':>4} {'Enzyme':>14} {'PyTorch':>14} {'abs_diff':>12} {'rel_diff':>12}" - ) + print(f"{'idx':>4} {'Enzyme':>14} {'PyTorch':>14} {'abs_diff':>12} {'rel_diff':>12}") for i in range(10): e = enzyme_grads[i] p = grad[i] diff --git a/examples/mingpt_to_rust.py b/examples/mingpt_to_rust.py index 60561fc..b84d76c 100644 --- a/examples/mingpt_to_rust.py +++ b/examples/mingpt_to_rust.py @@ -19,16 +19,7 @@ class NewGELU(nn.Module): def forward(self, x): - return ( - 0.5 - * x - * ( - 1.0 - + torch.tanh( - (2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)) - ) - ) - ) + return 0.5 * x * (1.0 + torch.tanh((2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)))) class CausalSelfAttention(nn.Module): @@ -38,9 +29,7 @@ def __init__(self, n_embd, n_head, block_size): self.c_proj = nn.Linear(n_embd, n_embd) self.register_buffer( "bias", - torch.tril(torch.ones(block_size, block_size)).view( - 1, 1, block_size, block_size - ), + torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size), ) self.n_head = n_head self.n_embd = n_embd @@ -90,9 +79,7 @@ def __init__(self, n_layer, n_head, n_embd, block_size, vocab_size): 
super().__init__() self.n_embd = n_embd self.seq_len = block_size # fixed sequence length for transpilation - self.blocks = nn.ModuleList( - [TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)] - ) + self.blocks = nn.ModuleList([TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)]) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) @@ -344,9 +331,7 @@ def forward(self, x): print(f" PyTorch: {pytorch_us:.1f} µs/call") print(f" Rust: {rust_us:.1f} µs/call") speedup = pytorch_us / rust_us - print( - f" Speedup: {speedup:.1f}x {'faster' if speedup > 1 else 'slower'}" - ) + print(f" Speedup: {speedup:.1f}x {'faster' if speedup > 1 else 'slower'}") else: print(" Rust benchmark failed!") else: diff --git a/examples/pytorch_to_jax_mlp.py b/examples/pytorch_to_jax_mlp.py index a913476..b952058 100644 --- a/examples/pytorch_to_jax_mlp.py +++ b/examples/pytorch_to_jax_mlp.py @@ -57,9 +57,7 @@ def main(): # Test the generated model import jax.numpy as jnp - param_data = { - name: param.detach().numpy() for name, param in model.named_parameters() - } + param_data = {name: param.detach().numpy() for name, param in model.named_parameters()} jax_params, forward_fn = result.get_model(param_data) jax_out = forward_fn(jax_params, jnp.array(x.numpy())) print(f"\nJAX output:\n{np.asarray(jax_out)}") diff --git a/examples/run_benchmark.py b/examples/run_benchmark.py index a09cd73..73e55f7 100644 --- a/examples/run_benchmark.py +++ b/examples/run_benchmark.py @@ -106,9 +106,7 @@ def make_zerosumnormal_model(): raw_day = np.array([-0.2, -0.1, 0.0, 0.05, 0.15, 0.35, 0.25]) raw_day += np.random.normal(0, 0.05, n_days) true_day_effect = raw_day - raw_day.mean() - raw_interaction = np.random.normal( - 0, true_sigma_cat, (n_stores, n_days, n_categories) - ) + raw_interaction = np.random.normal(0, true_sigma_cat, (n_stores, n_days, n_categories)) raw_interaction -= raw_interaction.mean(axis=-1, keepdims=True) raw_interaction -= raw_interaction.mean(axis=-2, keepdims=True) @@ -118,12 +116,7 @@ def make_zerosumnormal_model(): for d in range(n_days): for c in range(n_categories): n = np.random.poisson(5) + 1 - mu = ( - true_grand_mean - + true_store_effect[s] - + true_day_effect[d] - + raw_interaction[s, d, c] - ) + mu = true_grand_mean + true_store_effect[s] + true_day_effect[d] + raw_interaction[s, d, c] y_vals = np.random.normal(mu, true_sigma_y, n) for y in y_vals: records.append((s, d, c, y)) @@ -140,9 +133,7 @@ def make_zerosumnormal_model(): sigma_store = pm.HalfNormal("sigma_store", sigma=2) sigma_day = pm.HalfNormal("sigma_day", sigma=2) sigma_cat = pm.HalfNormal("sigma_cat", sigma=2) - store_effect = pm.ZeroSumNormal( - "store_effect", sigma=sigma_store, shape=n_stores - ) + store_effect = pm.ZeroSumNormal("store_effect", sigma=sigma_store, shape=n_stores) day_effect = pm.ZeroSumNormal("day_effect", sigma=sigma_day, shape=n_days) interaction = pm.ZeroSumNormal( "interaction", @@ -150,12 +141,7 @@ def make_zerosumnormal_model(): shape=(n_stores, n_days, n_categories), n_zerosum_axes=2, ) - mu_y = ( - grand_mean - + store_effect[store_idx] - + day_effect[day_idx] - + interaction[store_idx, day_idx, cat_idx] - ) + mu_y = grand_mean + store_effect[store_idx] + day_effect[day_idx] + interaction[store_idx, day_idx, cat_idx] sigma_y = pm.HalfNormal("sigma_y", sigma=5) pm.Normal("y", mu=mu_y, sigma=sigma_y, observed=y_obs) return ( diff --git a/notebooks/overview.py b/notebooks/overview.py index 29f2aaf..33b80a9 100644 --- a/notebooks/overview.py 
+++ b/notebooks/overview.py @@ -89,14 +89,9 @@ def _(): mu = alpha + beta * x_data pm.Normal("y", mu=mu, sigma=sigma, observed=y_data) - print( - f"Model: {len(linreg_model.free_RVs)} free parameters, " - f"{len(linreg_model.observed_RVs)} observed" - ) + print(f"Model: {len(linreg_model.free_RVs)} free parameters, {len(linreg_model.observed_RVs)} observed") print(f"Parameters: {[rv.name for rv in linreg_model.free_RVs]}") - print( - f"Transforms: {[type(linreg_model.rvs_to_transforms.get(rv)).__name__ for rv in linreg_model.free_RVs]}" - ) + print(f"Transforms: {[type(linreg_model.rvs_to_transforms.get(rv)).__name__ for rv in linreg_model.free_RVs]}") return (linreg_model,) @@ -135,17 +130,13 @@ def _(linreg_model): print("\nObserved data:") for name2, info2 in ctx.observed_data.items(): - print( - f" {name2}: n={info2['n']}, range=[{info2['min']:.3f}, {info2['max']:.3f}]" - ) + print(f" {name2}: n={info2['n']}, range=[{info2['min']:.3f}, {info2['max']:.3f}]") print("\nCovariates:") for name3, info3 in ctx.covariate_data.items(): is_idx = info3.get("is_index_array", False) label = f" [INDEX ARRAY, {info3['n_groups']} groups]" if is_idx else "" - print( - f" {name3}: n={info3['n']}, range=[{info3['min']:.3f}, {info3['max']:.3f}]{label}" - ) + print(f" {name3}: n={info3['n']}, range=[{info3['min']:.3f}, {info3['max']:.3f}]{label}") print(f"\nValidation points: 1 initial + {len(ctx.extra_points)} extra") print(f" Initial logp = {ctx.initial_point.logp:.6f}") @@ -235,14 +226,10 @@ def _(ctx, mo): rows.append("| Point | PyMC logp | Gradient[0] | Gradient[1] | Gradient[2] |") rows.append("|-------|-----------|-------------|-------------|-------------|") - pts = [("initial", ctx.initial_point)] + [ - (f"extra_{i}", p) for i, p in enumerate(ctx.extra_points) - ] + pts = [("initial", ctx.initial_point)] + [(f"extra_{i}", p) for i, p in enumerate(ctx.extra_points)] for name, vp in pts: g = vp.dlogp - rows.append( - f"| {name} | {vp.logp:.4f} | {g[0]:.4f} | {g[1]:.4f} | {g[2]:.4f} |" - ) + rows.append(f"| {name} | {vp.logp:.4f} | {g[0]:.4f} | {g[1]:.4f} | {g[2]:.4f} |") mo.md("**Validation reference values:**\n\n" + "\n".join(rows)) return @@ -348,10 +335,7 @@ def _(): sigma_y = _pm.HalfNormal("sigma_y", sigma=5) _pm.Normal("y_obs", mu=a[group_idx] + b * x_h, sigma=sigma_y, observed=y_h) - print( - f"Hierarchical model: {len(hierarchical_model.free_RVs)} free params, " - f"{_N} observations, {n_groups} groups" - ) + print(f"Hierarchical model: {len(hierarchical_model.free_RVs)} free params, {_N} observations, {n_groups} groups") return (hierarchical_model,) diff --git a/pyproject.toml b/pyproject.toml index 86d6e8a..fc7d1fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,3 +37,10 @@ dev = [ "pytest>=9.0.2", "ruff>=0.15", ] + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I", "W"] +ignore = ["E501"] diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 287e6ee..61ac6b1 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -6,7 +6,6 @@ from pathlib import Path from unittest.mock import MagicMock - from transpailer.compiler import ( CompilationResult, _AgentState, @@ -20,7 +19,6 @@ RustModelExporter, ) - # --------------------------------------------------------------------------- # CompilationResult # --------------------------------------------------------------------------- @@ -379,9 +377,7 @@ def test_dispatches_write(self): messages=[], ) - result = _execute_tool( - "write_rust_code", {"code": "// test"}, state, verbose=False - ) + 
result = _execute_tool("write_rust_code", {"code": "// test"}, state, verbose=False) assert "Written" in result def test_dispatches_read(self): @@ -395,9 +391,7 @@ def test_dispatches_read(self): messages=[], ) - result = _execute_tool( - "read_file", {"path": "test.rs"}, state, verbose=False - ) + result = _execute_tool("read_file", {"path": "test.rs"}, state, verbose=False) assert "// hello" in result diff --git a/tests/test_exporter.py b/tests/test_exporter.py index 7a0c103..a5c32a3 100644 --- a/tests/test_exporter.py +++ b/tests/test_exporter.py @@ -15,7 +15,6 @@ export_model, ) - # --------------------------------------------------------------------------- # ParamInfo dataclass # --------------------------------------------------------------------------- @@ -198,9 +197,7 @@ def test_different_seeds_differ(self, normal_model): ctx2 = RustModelExporter(normal_model, seed=2).context # At least one extra point should differ - assert any( - p1.point != p2.point for p1, p2 in zip(ctx1.extra_points, ctx2.extra_points) - ) + assert any(p1.point != p2.point for p1, p2 in zip(ctx1.extra_points, ctx2.extra_points)) class TestExporterLinregModel: @@ -230,11 +227,7 @@ def test_group_index_detection(self, hierarchical_model): ctx = exporter.context # Should detect the group index array - index_covariates = { - name: info - for name, info in ctx.covariate_data.items() - if info.get("is_index_array") - } + index_covariates = {name: info for name, info in ctx.covariate_data.items() if info.get("is_index_array")} assert len(index_covariates) >= 1 # Check the index array properties diff --git a/tests/test_jax_pytorch.py b/tests/test_jax_pytorch.py index 730f208..441589e 100644 --- a/tests/test_jax_pytorch.py +++ b/tests/test_jax_pytorch.py @@ -116,9 +116,7 @@ def __init__(self): super().__init__() self.fc = nn.Linear(2, 3) with torch.no_grad(): - self.fc.weight.copy_( - torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - ) + self.fc.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) self.fc.bias.copy_(torch.tensor([0.1, 0.2, 0.3])) def forward(self, x): @@ -159,6 +157,7 @@ def test_extract_validation_points(self, simple_model): def test_forward_output_matches(self, simple_model): import torch + from transpailer.pytorch_exporter import PytorchModelExporter model, x = simple_model @@ -178,11 +177,11 @@ class TestTranspilerTools: """Test the transpiler's tool execution logic without API calls.""" def test_write_code_syntax_check(self): + from transpailer.jax_exporter import ModelContext from transpailer.jax_pytorch_transpiler import ( - _tool_write_code, _AgentState, + _tool_write_code, ) - from transpailer.jax_exporter import ModelContext state = _AgentState( direction="jax_to_pytorch", @@ -200,18 +199,18 @@ def test_write_code_syntax_check(self): # Valid code result = _tool_write_code({"code": "x = 1 + 2"}, state, verbose=False) assert "Written" in result - assert state.generated_code == "x = 1 + 2" + assert state.generated_code.strip() == "x = 1 + 2" # Invalid code result = _tool_write_code({"code": "def f(:"}, state, verbose=False) assert "Syntax error" in result def test_validate_no_code(self): + from transpailer.jax_exporter import ModelContext from transpailer.jax_pytorch_transpiler import ( - _tool_validate, _AgentState, + _tool_validate, ) - from transpailer.jax_exporter import ModelContext state = _AgentState( direction="jax_to_pytorch", @@ -233,11 +232,12 @@ def test_validate_no_code(self): def test_validate_pytorch_correct_model(self): """Test that validation passes for a 
correctly transpiled model.""" import jax.numpy as jnp + from transpailer.jax_exporter import JaxModelExporter from transpailer.jax_pytorch_transpiler import ( - _tool_write_code, - _tool_validate, _AgentState, + _tool_validate, + _tool_write_code, ) # Create a simple JAX model @@ -288,12 +288,13 @@ def test_validate_jax_correct_model(self): """Test that validation passes for a correctly transpiled JAX model.""" import torch import torch.nn as nn - from transpailer.pytorch_exporter import PytorchModelExporter + from transpailer.jax_pytorch_transpiler import ( - _tool_write_code, - _tool_validate, _AgentState, + _tool_validate, + _tool_write_code, ) + from transpailer.pytorch_exporter import PytorchModelExporter # Create a simple PyTorch model class Linear(nn.Module): diff --git a/tests/test_pytorch_rust.py b/tests/test_pytorch_rust.py index 358d138..dc14e8b 100644 --- a/tests/test_pytorch_rust.py +++ b/tests/test_pytorch_rust.py @@ -13,7 +13,6 @@ import numpy as np import pytest - # ── Rust Project Setup Tests ───────────────────────────────────────────────── @@ -24,6 +23,7 @@ class TestRustProjectSetup: def simple_context(self): import torch import torch.nn as nn + from transpailer.pytorch_exporter import PytorchModelExporter class Linear(nn.Module): @@ -31,9 +31,7 @@ def __init__(self): super().__init__() self.fc = nn.Linear(2, 3) with torch.no_grad(): - self.fc.weight.copy_( - torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - ) + self.fc.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) self.fc.bias.copy_(torch.tensor([0.1, 0.2, 0.3])) def forward(self, x): @@ -93,6 +91,7 @@ class TestTranspilerTools: def agent_state(self): import torch import torch.nn as nn + from transpailer.pytorch_exporter import PytorchModelExporter from transpailer.pytorch_rust_transpiler import ( _AgentState, @@ -102,9 +101,7 @@ def agent_state(self): class Linear(nn.Module): def __init__(self): super().__init__() - self.w = nn.Parameter( - torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) - ) + self.w = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)) self.b = nn.Parameter(torch.tensor([0.1, 0.2], dtype=torch.float32)) def forward(self, x): @@ -198,14 +195,13 @@ def simple_model_context(self): """Create a simple model context for testing validation.""" import torch import torch.nn as nn + from transpailer.pytorch_exporter import PytorchModelExporter class Simple(nn.Module): def __init__(self): super().__init__() - self.w = nn.Parameter( - torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) - ) + self.w = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)) self.b = nn.Parameter(torch.tensor([0.1, 0.2], dtype=torch.float32)) def forward(self, x): @@ -260,6 +256,7 @@ class TestFullPipeline: def model_and_state(self): import torch import torch.nn as nn + from transpailer.pytorch_exporter import PytorchModelExporter from transpailer.pytorch_rust_transpiler import ( _AgentState, @@ -269,9 +266,7 @@ def model_and_state(self): class Simple(nn.Module): def __init__(self): super().__init__() - self.w = nn.Parameter( - torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) - ) + self.w = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)) self.b = nn.Parameter(torch.tensor([0.1, 0.2], dtype=torch.float32)) def forward(self, x): @@ -295,9 +290,9 @@ def forward(self, x): def test_correct_rust_validates(self, model_and_state): """Write manually correct Rust code and verify it passes validation.""" from 
transpailer.pytorch_rust_transpiler import ( - _tool_write_code, _tool_cargo_build, _tool_validate, + _tool_write_code, ) try: @@ -394,11 +389,7 @@ def test_skill_has_backprop(self): from transpailer.pytorch_rust_transpiler import _load_skill skill = _load_skill("pytorch_to_rust") - assert ( - "backward" in skill.lower() - or "backprop" in skill.lower() - or "gradient" in skill.lower() - ) + assert "backward" in skill.lower() or "backprop" in skill.lower() or "gradient" in skill.lower() # ── Result Type Tests ──────────────────────────────────────────────────────── @@ -454,6 +445,7 @@ class TestPromptBuilding: def test_prompt_contains_model_info(self): import torch.nn as nn + from transpailer.pytorch_exporter import PytorchModelExporter from transpailer.pytorch_rust_transpiler import _build_user_prompt diff --git a/transpailer/__init__.py b/transpailer/__init__.py index d569bc7..d1d4cf2 100644 --- a/transpailer/__init__.py +++ b/transpailer/__init__.py @@ -8,53 +8,54 @@ import importlib from typing import TYPE_CHECKING +from transpailer.jax_exporter import ( + JaxModelExporter, + export_jax_model, +) + # JAX ↔ PyTorch (no heavy deps at import time) from transpailer.jax_pytorch_transpiler import ( + TranspileResult, transpile_jax_to_pytorch, transpile_pytorch_to_jax, - TranspileResult, -) -from transpailer.jax_exporter import ( - JaxModelExporter, - export_jax_model, ) from transpailer.pytorch_exporter import ( PytorchModelExporter, export_pytorch_model, ) from transpailer.pytorch_rust_transpiler import ( - transpile_pytorch_to_rust, RustTranspileResult, + transpile_pytorch_to_rust, ) # PyMC/Stan imports are lazy — they pull in heavy deps (pymc, bridgestan) if TYPE_CHECKING: - from transpailer.exporter import ( - ModelContext, - RustModelExporter, - export_model, + from transpailer.analysis import ( + plot_optimization_progress, + plot_timeline, + plot_waterfall, + print_summary, ) from transpailer.compiler import ( + OptimizationEvent, compile_model, optimize_model, - OptimizationEvent, ) - from transpailer.analysis import ( - plot_optimization_progress, - plot_waterfall, - plot_timeline, - print_summary, + from transpailer.exporter import ( + ModelContext, + RustModelExporter, + export_model, + ) + from transpailer.stan_compiler import ( + StanCompilationResult, + compile_stan_model, ) from transpailer.stan_exporter import ( StanModelContext, StanModelExporter, export_stan_model, ) - from transpailer.stan_compiler import ( - compile_stan_model, - StanCompilationResult, - ) - from transpailer.stan_to_pymc import transpile_stan_to_pymc, StanToPyMCResult + from transpailer.stan_to_pymc import StanToPyMCResult, transpile_stan_to_pymc def __getattr__(name: str): diff --git a/transpailer/analysis.py b/transpailer/analysis.py index 6a92923..6f81a6d 100644 --- a/transpailer/analysis.py +++ b/transpailer/analysis.py @@ -254,9 +254,7 @@ def plot_waterfall( colors = ["#2ecc71" if d > 0 else "#e74c3c" for d in deltas] - bars = ax.bar( - range(len(deltas)), deltas, color=colors, edgecolor="#2c3e50", linewidth=0.5 - ) + bars = ax.bar(range(len(deltas)), deltas, color=colors, edgecolor="#2c3e50", linewidth=0.5) # Value labels on bars for bar, delta in zip(bars, deltas): @@ -328,9 +326,7 @@ def plot_timeline( "timestamp_s": str(ev.timestamp), "event_type": ev.event_type, "status": ev.status, - "us_per_eval": str(ev.us_per_eval) - if ev.us_per_eval is not None - else "", + "us_per_eval": str(ev.us_per_eval) if ev.us_per_eval is not None else "", } for ev in source.optimization_log ] @@ -421,10 +417,7 @@ 
def print_summary(source: str | Path | CompilationResult) -> str: prev_us = baseline for rec in kept: delta = prev_us - rec.us_per_eval - lines.append( - f" [{rec.code_hash}] {rec.us_per_eval:.3f} us/eval " - f"(delta: {delta:+.3f}, {rec.description})" - ) + lines.append(f" [{rec.code_hash}] {rec.us_per_eval:.3f} us/eval (delta: {delta:+.3f}, {rec.description})") prev_us = rec.us_per_eval summary = "\n".join(lines) diff --git a/transpailer/benchmark.py b/transpailer/benchmark.py index 048d626..bcba13e 100644 --- a/transpailer/benchmark.py +++ b/transpailer/benchmark.py @@ -17,9 +17,7 @@ _BENCH_RUNNER_DIR = Path(__file__).resolve().parent.parent / "bench_runner" -def benchmark_nutpie( - model: pm.Model, draws: int = 2000, tune: int = 1000, chains: int = 4 -) -> dict: +def benchmark_nutpie(model: pm.Model, draws: int = 2000, tune: int = 1000, chains: int = 4) -> dict: """Benchmark PyMC sampling with nutpie backend.""" print(f" nutpie: {chains} chains x {draws} draws...") start = time.time() @@ -43,9 +41,7 @@ def benchmark_nutpie( } -def benchmark_rust( - build_dir: str | Path, draws: int = 2000, tune: int = 1000, chains: int = 4 -) -> dict: +def benchmark_rust(build_dir: str | Path, draws: int = 2000, tune: int = 1000, chains: int = 4) -> dict: """Benchmark the AI-compiled Rust sampler.""" build_dir = Path(build_dir) binary = build_dir / "target" / "release" / "sample" @@ -90,9 +86,7 @@ def benchmark_rust( # --------------------------------------------------------------------------- -def _make_test_point( - model: pm.Model, rng: np.random.Generator | None = None -) -> np.ndarray: +def _make_test_point(model: pm.Model, rng: np.random.Generator | None = None) -> np.ndarray: """Generate a random unconstrained test point in model variable order. Uses a random point (instead of the initial point) to avoid @@ -113,9 +107,7 @@ def _prepare_frozen_inputs(model, x0_model_order=None): Returns (jit_fn, x0, frozen_rv, model_fn). 
""" frozen_model = freeze_dims_and_data(model) - logp_dlogp_wrapper = frozen_model.logp_dlogp_function( - ravel_inputs=True, mode="NUMBA" - ) + logp_dlogp_wrapper = frozen_model.logp_dlogp_function(ravel_inputs=True, mode="NUMBA") logp_dlogp_fn = logp_dlogp_wrapper._pytensor_function logp_dlogp_fn.trust_input = True @@ -130,9 +122,7 @@ def _prepare_frozen_inputs(model, x0_model_order=None): # Reorder x0 from model order → frozen order for this function model_ip = {v.name: ip[v.name] for v in model_fn._grad_vars} model_rv = DictToArrayBijection.map(model_ip) - x0_dict = DictToArrayBijection.rmap( - RaveledVars(x0_model_order, model_rv.point_map_info) - ) + x0_dict = DictToArrayBijection.rmap(RaveledVars(x0_model_order, model_rv.point_map_info)) frozen_ip = {name: x0_dict[name] for name in frozen_vars} frozen_rv = DictToArrayBijection.map(frozen_ip) x0 = frozen_rv.data @@ -149,9 +139,7 @@ def _reorder_dlogp(dlogp_val, frozen_rv, model_fn): dlogp_dict = DictToArrayBijection.rmap( RaveledVars(np.asarray(dlogp_val, dtype=np.float64), frozen_rv.point_map_info) ) - return DictToArrayBijection.map( - {v.name: dlogp_dict[v.name] for v in model_fn._grad_vars} - ).data + return DictToArrayBijection.map({v.name: dlogp_dict[v.name] for v in model_fn._grad_vars}).data def benchmark_logp_pytensor( @@ -357,9 +345,7 @@ def benchmark_logp_rust( } -def print_logp_comparison( - pytensor_result: dict, rust_result: dict, model_name: str = "" -): +def print_logp_comparison(pytensor_result: dict, rust_result: dict, model_name: str = ""): """Print logp+dlogp evaluation benchmark comparison.""" header = f"LOGP+DLOGP BENCHMARK{f': {model_name}' if model_name else ''}" print("\n" + "=" * 65) @@ -374,9 +360,7 @@ def print_logp_comparison( print(f"{pt['backend']:<20} {'ERROR':<12}") else: evals_per_sec_pt = 1e6 / pt["us_per_eval"] - print( - f"{pt['backend']:<20} {pt['us_per_eval']:<12.2f} {evals_per_sec_pt:<14,.0f} {'1.00x':<10}" - ) + print(f"{pt['backend']:<20} {pt['us_per_eval']:<12.2f} {evals_per_sec_pt:<14,.0f} {'1.00x':<10}") rs = rust_result if "error" in rs: @@ -384,42 +368,28 @@ def print_logp_comparison( else: evals_per_sec_rs = 1e6 / rs["us_per_eval"] speedup = pt["us_per_eval"] / rs["us_per_eval"] if "error" not in pt else 0 - print( - f"{'rust-ai':<20} {rs['us_per_eval']:<12.2f} {evals_per_sec_rs:<14,.0f} {speedup:<10.2f}x" - ) + print(f"{'rust-ai':<20} {rs['us_per_eval']:<12.2f} {evals_per_sec_rs:<14,.0f} {speedup:<10.2f}x") # Check logp and dlogp agreement if "error" not in pt: logp_diff = abs(pt["logp"] - rs["logp"]) logp_rel_err = logp_diff / max(abs(pt["logp"]), 1e-10) logp_ok = logp_rel_err < 1e-4 - logp_status = ( - "MATCH" if logp_ok else f"MISMATCH (rel_err={logp_rel_err:.2e})" - ) - print( - f"\n logp check: pytensor={pt['logp']:.8f} rust={rs['logp']:.8f} [{logp_status}]" - ) + logp_status = "MATCH" if logp_ok else f"MISMATCH (rel_err={logp_rel_err:.2e})" + print(f"\n logp check: pytensor={pt['logp']:.8f} rust={rs['logp']:.8f} [{logp_status}]") pt_dlogp = pt["dlogp"] rs_dlogp = rs["dlogp"] dlogp_abs_err = np.max(np.abs(pt_dlogp - rs_dlogp)) dlogp_rel_err = dlogp_abs_err / max(np.max(np.abs(pt_dlogp)), 1e-10) dlogp_ok = dlogp_rel_err < 1e-4 - dlogp_status = ( - "MATCH" if dlogp_ok else f"MISMATCH (rel_err={dlogp_rel_err:.2e})" - ) - print( - f" dlogp check: pytensor={pt_dlogp} rust={rs_dlogp} [{dlogp_status}]" - ) + dlogp_status = "MATCH" if dlogp_ok else f"MISMATCH (rel_err={dlogp_rel_err:.2e})" + print(f" dlogp check: pytensor={pt_dlogp} rust={rs_dlogp} [{dlogp_status}]") assert logp_ok, ( - 
f"logp mismatch: pytensor={pt['logp']:.10f} rust={rs['logp']:.10f} " - f"rel_err={logp_rel_err:.2e}" - ) - assert dlogp_ok, ( - f"dlogp mismatch: pytensor={pt_dlogp} rust={rs_dlogp} " - f"rel_err={dlogp_rel_err:.2e}" + f"logp mismatch: pytensor={pt['logp']:.10f} rust={rs['logp']:.10f} rel_err={logp_rel_err:.2e}" ) + assert dlogp_ok, f"dlogp mismatch: pytensor={pt_dlogp} rust={rs_dlogp} rel_err={dlogp_rel_err:.2e}" print() @@ -434,16 +404,12 @@ def print_comparison(nutpie_result: dict, rust_result: dict): print("-" * 54) nt = nutpie_result["elapsed_s"] - print( - f"{'nutpie':<20} {nt:<12.2f} {nutpie_result['throughput']:<12.0f} {'1.00x':<10}" - ) + print(f"{'nutpie':<20} {nt:<12.2f} {nutpie_result['throughput']:<12.0f} {'1.00x':<10}") if "error" not in rust_result: rt = rust_result["elapsed_s"] speedup = nt / rt - print( - f"{'rust-ai':<20} {rt:<12.2f} {rust_result['throughput']:<12.0f} {speedup:<10.2f}x" - ) + print(f"{'rust-ai':<20} {rt:<12.2f} {rust_result['throughput']:<12.0f} {speedup:<10.2f}x") else: print(f"{'rust-ai':<20} {'FAILED':<12}") diff --git a/transpailer/cli.py b/transpailer/cli.py index 0121c9c..e9a7626 100644 --- a/transpailer/cli.py +++ b/transpailer/cli.py @@ -9,7 +9,6 @@ import click - _SKILLS_DIR = Path(__file__).parent / "skills" # Mapping from (source, target) pairs to relevant skill files @@ -69,8 +68,7 @@ def _detect_framework(code: str, filename: str) -> str: return framework raise click.UsageError( - f"Cannot auto-detect source framework for '{filename}'. " - "Use --from to specify it explicitly." + f"Cannot auto-detect source framework for '{filename}'. Use --from to specify it explicitly." ) @@ -127,16 +125,12 @@ def _transpile( if skills: system += f"\n\n# Domain knowledge\n\n{skills}" - user_msg = ( - f"Transpile the following {source} code to {target}.\n\n" - f"```\n{code}\n```" - ) + user_msg = f"Transpile the following {source} code to {target}.\n\n```\n{code}\n```" api_key = os.environ.get("ANTHROPIC_API_KEY") if not api_key: raise click.ClickException( - "ANTHROPIC_API_KEY environment variable is required. " - "Set it with: export ANTHROPIC_API_KEY=sk-..." + "ANTHROPIC_API_KEY environment variable is required. Set it with: export ANTHROPIC_API_KEY=sk-..." ) client = anthropic.Anthropic(api_key=api_key) @@ -230,8 +224,7 @@ def convert( filename = "stdin" else: raise click.UsageError( - "No input file provided and no data on stdin. " - "Usage: transpailer convert --to " + "No input file provided and no data on stdin. Usage: transpailer convert --to " ) target = _normalize_framework(target) diff --git a/transpailer/compiler.py b/transpailer/compiler.py index 1017ec1..e143a2a 100644 --- a/transpailer/compiler.py +++ b/transpailer/compiler.py @@ -260,8 +260,7 @@ { "name": "cargo_build", "description": ( - "Build the Rust project with `cargo build --release`. " - "Returns build output including any compiler errors." + "Build the Rust project with `cargo build --release`. Returns build output including any compiler errors." 
), "input_schema": { "type": "object", @@ -375,16 +374,12 @@ def _detect_skills( has_gp = True break # Also check the logp graph text for GP indicators - if not has_gp and any( - kw in ctx.logp_graph.lower() for kw in ["cholesky", "mvnormal", "gp_"] - ): + if not has_gp and any(kw in ctx.logp_graph.lower() for kw in ["cholesky", "mvnormal", "gp_"]): has_gp = True if has_gp: cuda = use_cuda if use_cuda is not None else _cuda_available() - accelerate = ( - use_accelerate if use_accelerate is not None else _accelerate_available() - ) + accelerate = use_accelerate if use_accelerate is not None else _accelerate_available() if cuda: skills.append("gp_cuda") elif accelerate: @@ -679,9 +674,7 @@ def compile_model( total_input_tokens += response.usage.input_tokens total_output_tokens += response.usage.output_tokens if verbose: - print( - f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens" - ) + print(f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens") # Check stop reason if response.stop_reason == "end_turn": @@ -730,9 +723,7 @@ def compile_model( validation_errors = [] if not state.validated: - validation_errors.append( - f"Agent did not achieve validation after {state.tool_calls} tool calls" - ) + validation_errors.append(f"Agent did not achieve validation after {state.tool_calls} tool calls") result = CompilationResult( rust_code=rust_code, @@ -760,9 +751,7 @@ def compile_model( return result -def _execute_tool( - name: str, input_data: dict, state: _AgentState, verbose: bool -) -> str: +def _execute_tool(name: str, input_data: dict, state: _AgentState, verbose: bool) -> str: """Execute a tool and return the result string.""" if name == "write_rust_code": return _tool_write_rust_code(input_data, state, verbose) @@ -792,14 +781,9 @@ def _tool_write_rust_code(input_data: dict, state: _AgentState, verbose: bool) - _setup_enzyme_toolchain(state.build_path) # Move feature flag to lib.rs (crate root) if agent put it in generated.rs if "#![feature(autodiff)]" in code: - code_clean = code.replace("#![feature(autodiff)]\n", "").replace( - "#![feature(autodiff)]", "" - ) + code_clean = code.replace("#![feature(autodiff)]\n", "").replace("#![feature(autodiff)]", "") gen_path.write_text(code_clean) - lib_rs = ( - "#![feature(autodiff)]\n" - "pub mod data;\npub mod generated;\npub use generated::*;\n" - ) + lib_rs = "#![feature(autodiff)]\npub mod data;\npub mod generated;\npub use generated::*;\n" (state.build_path / "src" / "lib.rs").write_text(lib_rs) if verbose: @@ -884,9 +868,7 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: if not binary.exists(): return "Error: validation binary not found. Run cargo_build first." 
- all_points = [("initial", ctx.initial_point)] + [ - (f"extra_{i}", p) for i, p in enumerate(ctx.extra_points) - ] + all_points = [("initial", ctx.initial_point)] + [(f"extra_{i}", p) for i, p in enumerate(ctx.extra_points)] input_lines = [] for name, vp in all_points: @@ -915,9 +897,7 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: output_lines = [line for line in result.stdout.strip().split("\n") if line] if len(output_lines) != len(all_points): - return ( - f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" - ) + return f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" # Parse results parsed = [] @@ -955,21 +935,17 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: # logp comparison if has_constant_offset: - adjusted_err = abs((rust_logp - mean_offset) - vp.logp) / max( - abs(vp.logp), 1.0 - ) + adjusted_err = abs((rust_logp - mean_offset) - vp.logp) / max(abs(vp.logp), 1.0) else: adjusted_err = abs(rust_logp - vp.logp) / max(abs(vp.logp), 1.0) status = "OK" if adjusted_err <= 1e-4 else "MISMATCH" report_lines.append( - f"{name}: logp PyMC={vp.logp:.10f} Rust={rust_logp:.10f} " - f"rel_err={adjusted_err:.2e} [{status}]" + f"{name}: logp PyMC={vp.logp:.10f} Rust={rust_logp:.10f} rel_err={adjusted_err:.2e} [{status}]" ) if adjusted_err > 1e-4: errors.append( - f"{name}: logp mismatch: PyMC={vp.logp:.10f}, Rust={rust_logp:.10f}, " - f"rel_err={adjusted_err:.2e}" + f"{name}: logp mismatch: PyMC={vp.logp:.10f}, Rust={rust_logp:.10f}, rel_err={adjusted_err:.2e}" ) # Per-RV logp decomposition (only for initial point) @@ -986,8 +962,7 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: grad_errors += 1 if grad_errors <= 5: # Show first 5 errors.append( - f"{name}: gradient[{j}] mismatch: PyMC={pymc_g:.6e}, " - f"Rust={rust_g:.6e}, rel_err={grad_err:.2e}" + f"{name}: gradient[{j}] mismatch: PyMC={pymc_g:.6e}, Rust={rust_g:.6e}, rel_err={grad_err:.2e}" ) if grad_errors > 0: report_lines.append(f" {grad_errors} gradient mismatches") @@ -1024,10 +999,7 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: code_hash=_code_hash(state.build_path), ) ) - return ( - f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" - + "\n".join(errors) - ) + return f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" + "\n".join(errors) def _tool_read_file(input_data: dict, state: _AgentState, verbose: bool) -> str: @@ -1061,9 +1033,7 @@ def _tool_read_file(input_data: dict, state: _AgentState, verbose: bool) -> str: return content -def _tool_add_cargo_dependency( - input_data: dict, state: _AgentState, verbose: bool -) -> str: +def _tool_add_cargo_dependency(input_data: dict, state: _AgentState, verbose: bool) -> str: """Add a crate dependency to Cargo.toml.""" name = input_data.get("name", "") version = input_data.get("version", "") @@ -1147,9 +1117,7 @@ def _generate_data_rs(ctx) -> str: # Format with full f64 precision formatted_values = ", ".join(f"{v:.17e}" for v in flat) - lines.append( - f"pub const {name.upper()}_DATA: &[f64] = &[{formatted_values}];\n" - ) + lines.append(f"pub const {name.upper()}_DATA: &[f64] = &[{formatted_values}];\n") return "\n".join(lines) @@ -1228,9 +1196,7 @@ def _setup_rust_project( (src_dir / "lib.rs").write_text(lib_rs) # Placeholder generated.rs so the project structure is valid - (src_dir / "generated.rs").write_text( - "// Placeholder — will be overwritten by the agent\n" - ) + (src_dir / "generated.rs").write_text("// 
Placeholder — will be overwritten by the agent\n") # Validation binary validate_rs = """ @@ -1552,9 +1518,7 @@ def optimize_model( total_input_tokens += response.usage.input_tokens total_output_tokens += response.usage.output_tokens if verbose: - print( - f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens" - ) + print(f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens") if response.stop_reason == "end_turn": if verbose: diff --git a/transpailer/exporter.py b/transpailer/exporter.py index e1ba221..9337d4e 100644 --- a/transpailer/exporter.py +++ b/transpailer/exporter.py @@ -86,10 +86,7 @@ def to_dict(self) -> dict: "logp": self.initial_point.logp, "dlogp": self.initial_point.dlogp, }, - "extra_points": [ - {"point": p.point, "logp": p.logp, "dlogp": p.dlogp} - for p in self.extra_points - ], + "extra_points": [{"point": p.point, "logp": p.logp, "dlogp": p.dlogp} for p in self.extra_points], }, } @@ -129,9 +126,7 @@ def _extract(self) -> ModelContext: transform = model.rvs_to_transforms.get(rv, None) # Use value_var shape for unconstrained size (transforms may change dims) rv_shape = list(rv.type.shape) if hasattr(rv.type, "shape") else [] - unc_shape = ( - list(value_var.type.shape) if hasattr(value_var.type, "shape") else [] - ) + unc_shape = list(value_var.type.shape) if hasattr(value_var.type, "shape") else [] size = int(np.prod(unc_shape)) if unc_shape else 1 zs_axes = None if transform and type(transform).__name__ == "ZeroSumTransform": @@ -203,10 +198,7 @@ def _per_rv_logp(point): return {name: float(fn(point)) for name, fn in per_rv_logp_fns.items()} initial = ValidationPoint( - point={ - k: v.tolist() if hasattr(v, "tolist") else v - for k, v in test_point.items() - }, + point={k: v.tolist() if hasattr(v, "tolist") else v for k, v in test_point.items()}, logp=float(logp_fn(test_point)), dlogp=dlogp_fn(test_point).tolist(), per_rv_logp=_per_rv_logp(test_point), @@ -246,9 +238,7 @@ def _per_rv_logp(point): ) @staticmethod - def _extract_covariates( - model: pm.Model, observed_data: dict[str, dict] - ) -> dict[str, dict]: + def _extract_covariates(model: pm.Model, observed_data: dict[str, dict]) -> dict[str, dict]: """Find non-scalar constant arrays in the logp graph (predictors/covariates). These are numpy arrays passed into the model (e.g. 
x in regression) @@ -297,11 +287,7 @@ def _extract_covariates( if np.all(flat == np.floor(flat)) and np.min(flat) == 0: unique_vals = np.unique(flat) max_val = int(np.max(flat)) - if ( - max_val >= 2 - and max_val < 200 - and np.array_equal(unique_vals, np.arange(max_val + 1)) - ): + if max_val >= 2 and max_val < 200 and np.array_equal(unique_vals, np.arange(max_val + 1)): is_index = True n_groups = max_val + 1 @@ -341,12 +327,8 @@ def _infer_data_mapping(self, ctx) -> list[str]: # Covariates: match index arrays to index variables, others to remaining vars covariate_items = list(ctx.covariate_data.items()) - index_covariates = [ - (n, i) for n, i in covariate_items if i.get("is_index_array") - ] - non_index_covariates = [ - (n, i) for n, i in covariate_items if not i.get("is_index_array") - ] + index_covariates = [(n, i) for n, i in covariate_items if i.get("is_index_array")] + non_index_covariates = [(n, i) for n, i in covariate_items if not i.get("is_index_array")] # Match index covariates to index variables from source # If only one index covariate and one index variable, match them directly @@ -359,9 +341,7 @@ def _infer_data_mapping(self, ctx) -> list[str]: f"(integer indices, {n_groups} groups, cast to `usize` for indexing)" ) else: - for (cov_name, cov_info), src_var in zip( - index_covariates, sorted(index_vars) - ): + for (cov_name, cov_info), src_var in zip(index_covariates, sorted(index_vars)): n_groups = cov_info.get("n_groups", 0) hints.append( f"`{src_var}` in source → `{cov_name.upper()}_DATA` " @@ -376,9 +356,7 @@ def _infer_data_mapping(self, ctx) -> list[str]: arith_vars -= {"observed", "shape"} remaining_src_vars = sorted(arith_vars) - for (cov_name, cov_info), src_var in zip( - non_index_covariates, remaining_src_vars - ): + for (cov_name, cov_info), src_var in zip(non_index_covariates, remaining_src_vars): hints.append(f"`{src_var}` in source → `{cov_name.upper()}_DATA`") # Observed data mapping @@ -387,9 +365,7 @@ def _infer_data_mapping(self, ctx) -> list[str]: obs_match = _re.search(r"(\w+)\s*~.*observed", source) if obs_match: src_obs = obs_match.group(1) - hints.append( - f"`{src_obs}` (observed) in source → `{obs_name.upper()}_DATA`" - ) + hints.append(f"`{src_obs}` (observed) in source → `{obs_name.upper()}_DATA`") return hints @@ -536,10 +512,7 @@ def to_prompt(self) -> str: label = f"INTEGER INDEX ARRAY (values 0..{n_groups - 1}, {n_groups} groups) — cast to usize for array indexing" else: label = "covariate/predictor" - parts.append( - f"- `{rust_name}_DATA: &[f64]` — {label}, n={n}, " - f"range={range_str}, mean={mean_str}" - ) + parts.append(f"- `{rust_name}_DATA: &[f64]` — {label}, n={n}, range={range_str}, mean={mean_str}") parts.append(f" `{rust_name}_N: usize = {n}`") # Try to add source-to-data mapping hints @@ -552,13 +525,9 @@ def to_prompt(self) -> str: parts.append(f"- {hint}") parts.append("") - parts.append( - f"## Optimized PyTensor Graph (logp)\n```\n{ctx.logp_graph}\n```\n" - ) + parts.append(f"## Optimized PyTensor Graph (logp)\n```\n{ctx.logp_graph}\n```\n") - parts.append( - f"## Optimized PyTensor Graph (dlogp/gradient)\n```\n{ctx.dlogp_graph}\n```\n" - ) + parts.append(f"## Optimized PyTensor Graph (dlogp/gradient)\n```\n{ctx.dlogp_graph}\n```\n") parts.append("## Individual logp terms (optimized, per RV)\n") for name, term in ctx.logp_terms.items(): @@ -566,9 +535,7 @@ def to_prompt(self) -> str: parts.append(f"### {name}\n```\n{display}\n```\n") parts.append("## Validation") - parts.append( - "Your generated code MUST produce these 
exact values (within float64 precision):\n" - ) + parts.append("Your generated code MUST produce these exact values (within float64 precision):\n") parts.append(f"At initial point: {json.dumps(ctx.initial_point.point)}") parts.append(f"- logp = {ctx.initial_point.logp:.10f}") @@ -639,9 +606,7 @@ def _flatten(v): assert_close(logp, {vp.logp:.10f}, 1e-6, "logp");""") for i, g in enumerate(vp.dlogp): - tests.append( - f' assert_close(gradient[{i}], {g:.10e}, 1e-4, "grad[{i}]");' - ) + tests.append(f' assert_close(gradient[{i}], {g:.10e}, 1e-4, "grad[{i}]");') tests.append(" }\n") tests.append("}\n") @@ -676,9 +641,7 @@ def export_model( exporter = export_model(model) prompt = exporter.to_prompt() """ - exporter = RustModelExporter( - model, source_code=source_code, n_extra_points=n_extra_points - ) + exporter = RustModelExporter(model, source_code=source_code, n_extra_points=n_extra_points) if output_dir: exporter.save_all(output_dir) return exporter diff --git a/transpailer/formatting.py b/transpailer/formatting.py new file mode 100644 index 0000000..5a9c1a3 --- /dev/null +++ b/transpailer/formatting.py @@ -0,0 +1,25 @@ +"""Auto-format generated code using ruff.""" + +from __future__ import annotations + +import subprocess + + +def format_python_code(code: str) -> str: + """Format Python code using ruff format. + + Falls back to the original code if ruff is not available or formatting fails. + """ + try: + result = subprocess.run( + ["ruff", "format", "--quiet", "-"], + input=code, + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + return result.stdout + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + return code diff --git a/transpailer/jax_exporter.py b/transpailer/jax_exporter.py index a860048..8921b91 100644 --- a/transpailer/jax_exporter.py +++ b/transpailer/jax_exporter.py @@ -57,18 +57,9 @@ def to_dict(self) -> dict: return { "source_framework": self.source_framework, "source_code": self.source_code, - "parameters": [ - {"name": p.name, "shape": p.shape, "dtype": p.dtype, "size": p.size} - for p in self.params - ], - "inputs": [ - {"name": i.name, "shape": i.shape, "dtype": i.dtype, "size": i.size} - for i in self.inputs - ], - "outputs": [ - {"name": o.name, "shape": o.shape, "dtype": o.dtype, "size": o.size} - for o in self.outputs - ], + "parameters": [{"name": p.name, "shape": p.shape, "dtype": p.dtype, "size": p.size} for p in self.params], + "inputs": [{"name": i.name, "shape": i.shape, "dtype": i.dtype, "size": i.size} for i in self.inputs], + "outputs": [{"name": o.name, "shape": o.shape, "dtype": o.dtype, "size": o.size} for o in self.outputs], "validation_points": [ { "params": vp.params, diff --git a/transpailer/jax_pytorch_transpiler.py b/transpailer/jax_pytorch_transpiler.py index e122dfd..eaf5ea4 100644 --- a/transpailer/jax_pytorch_transpiler.py +++ b/transpailer/jax_pytorch_transpiler.py @@ -18,9 +18,9 @@ import numpy as np +from transpailer.formatting import format_python_code as _format_python from transpailer.jax_exporter import ModelContext - _SKILLS_DIR = Path(__file__).parent / "skills" @@ -164,9 +164,7 @@ def forward(params: dict, x): }, { "name": "read_source", - "description": ( - "Re-read the original source code of the model being transpiled." 
-        ),
+        "description": ("Re-read the original source code of the model being transpiled."),
         "input_schema": {
             "type": "object",
             "properties": {},
@@ -325,8 +323,7 @@ def _build_user_prompt(ctx: ModelContext, direction: str) -> str:
             parts.append(f"  expected output = {vp.output}")
         else:
             parts.append(
-                f"  expected output: shape={list(out_arr.shape)}, "
-                f"mean={out_arr.mean():.6f}, std={out_arr.std():.6f}"
+                f"  expected output: shape={list(out_arr.shape)}, mean={out_arr.mean():.6f}, std={out_arr.std():.6f}"
             )
 
     # Show expected gradients (truncated)
@@ -341,10 +338,7 @@ def _build_user_prompt(ctx: ModelContext, direction: str) -> str:
         )
 
     parts.append("")
-    parts.append(
-        f"Generate the {target} code using `write_code`, "
-        f"then call `validate_model` to check correctness."
-    )
+    parts.append(f"Generate the {target} code using `write_code`, then call `validate_model` to check correctness.")
 
     return "\n".join(parts)
 
@@ -384,6 +378,7 @@ def _tool_write_code(
             print(f"  [write_code] Syntax error: {e}")
         return f"Syntax error in generated code: {e}"
 
+    code = _format_python(code)
    state.generated_code = code
    if verbose:
        print(f"  [write_code] Wrote {len(code)} chars")
@@ -444,10 +439,7 @@ def _validate_pytorch(namespace: dict, state: _AgentState, verbose: bool) -> str
                x = torch.tensor(np.array(inp["x"]), dtype=torch.float32)
                output = model(x)
            else:
-                tensors = {
-                    k: torch.tensor(np.array(v), dtype=torch.float32)
-                    for k, v in inp.items()
-                }
+                tensors = {k: torch.tensor(np.array(v), dtype=torch.float32) for k, v in inp.items()}
                output = model(**tensors)
 
            out_np = output.detach().cpu().numpy()
@@ -455,28 +447,19 @@
 
            # Check output
            if out_np.shape != ref_out.shape:
-                errors.append(
-                    f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}"
-                )
-                report_lines.append(
-                    f"{label}: SHAPE MISMATCH {out_np.shape} vs {ref_out.shape}"
-                )
+                errors.append(f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}")
+                report_lines.append(f"{label}: SHAPE MISMATCH {out_np.shape} vs {ref_out.shape}")
                continue
 
            max_diff = float(np.max(np.abs(out_np - ref_out)))
-            rel_err = float(
-                np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8))
-            )
+            rel_err = float(np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8)))
            out_ok = rel_err <= 1e-4
            report_lines.append(
-                f"{label}: output max_diff={max_diff:.2e} rel_err={rel_err:.2e} "
-                f"[{'OK' if out_ok else 'MISMATCH'}]"
+                f"{label}: output max_diff={max_diff:.2e} rel_err={rel_err:.2e} [{'OK' if out_ok else 'MISMATCH'}]"
            )
            if not out_ok:
-                errors.append(
-                    f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}"
-                )
+                errors.append(f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}")
 
            # Check gradients
            loss = output.sum()
@@ -489,20 +472,14 @@ def _validate_pytorch(namespace: dict, state: _AgentState, verbose: bool) -> str
                    if name == pname and param.grad is not None:
                        got_g = param.grad.detach().cpu().numpy()
                        grad_diff = float(np.max(np.abs(got_g - ref_g)))
-                        grad_rel = float(
-                            np.max(
-                                np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8)
-                            )
-                        )
+                        grad_rel = float(np.max(np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8)))
                        grad_ok = grad_rel <= 1e-3
                        report_lines.append(
                            f"  grad['{name}']: max_diff={grad_diff:.2e} rel_err={grad_rel:.2e} "
                            f"[{'OK' if grad_ok else 'MISMATCH'}]"
                        )
                        if not grad_ok:
-                            errors.append(
-                                f"{label}: grad['{name}'] mismatch: rel_err={grad_rel:.2e}"
-                            )
+                            errors.append(f"{label}: grad['{name}'] mismatch: rel_err={grad_rel:.2e}")
                        found = True
                        break
                if not found:
@@ -522,9 +499,8 @@ def _validate_pytorch(namespace: dict, state: _AgentState, verbose: bool) -> str
    else:
        if verbose:
            print(f"  [validate] FAILED ({len(errors)} errors)")
-        return (
-            f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\n"
-            f"Errors:\n" + "\n".join(f"- {e}" for e in errors)
+        return f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" + "\n".join(
+            f"- {e}" for e in errors
        )
 
@@ -536,9 +512,7 @@ def _validate_jax(namespace: dict, state: _AgentState, verbose: bool) -> str:
    if "forward" not in namespace:
        return "Error: generated code does not define `forward(params, x)` function."
    if "init_params" not in namespace:
-        return (
-            "Error: generated code does not define `init_params(param_data)` function."
-        )
+        return "Error: generated code does not define `init_params(param_data)` function."
 
    forward_fn = namespace["forward"]
    init_fn = namespace["init_params"]
@@ -575,28 +549,19 @@ def scalar_fn(params, x):
        ref_out = np.array(vp.output)
 
        if out_np.shape != ref_out.shape:
-            errors.append(
-                f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}"
-            )
-            report_lines.append(
-                f"{label}: SHAPE MISMATCH {out_np.shape} vs {ref_out.shape}"
-            )
+            errors.append(f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}")
+            report_lines.append(f"{label}: SHAPE MISMATCH {out_np.shape} vs {ref_out.shape}")
            continue
 
        max_diff = float(np.max(np.abs(out_np - ref_out)))
-        rel_err = float(
-            np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8))
-        )
+        rel_err = float(np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8)))
        out_ok = rel_err <= 1e-4
        report_lines.append(
-            f"{label}: output max_diff={max_diff:.2e} rel_err={rel_err:.2e} "
-            f"[{'OK' if out_ok else 'MISMATCH'}]"
+            f"{label}: output max_diff={max_diff:.2e} rel_err={rel_err:.2e} [{'OK' if out_ok else 'MISMATCH'}]"
        )
        if not out_ok:
-            errors.append(
-                f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}"
-            )
+            errors.append(f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}")
 
        # Check gradients
        grads = grad_fn(params, x)
@@ -606,18 +571,14 @@ def scalar_fn(params, x):
            if pname in grads:
                got_g = np.asarray(grads[pname])
                grad_diff = float(np.max(np.abs(got_g - ref_g)))
-                grad_rel = float(
-                    np.max(np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8))
-                )
+                grad_rel = float(np.max(np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8)))
                grad_ok = grad_rel <= 1e-3
                report_lines.append(
                    f"  grad['{pname}']: max_diff={grad_diff:.2e} rel_err={grad_rel:.2e} "
                    f"[{'OK' if grad_ok else 'MISMATCH'}]"
                )
                if not grad_ok:
-                    errors.append(
-                        f"{label}: grad['{pname}'] mismatch: rel_err={grad_rel:.2e}"
-                    )
+                    errors.append(f"{label}: grad['{pname}'] mismatch: rel_err={grad_rel:.2e}")
            else:
                errors.append(f"{label}: gradient for '{pname}' not found")
 
@@ -635,9 +596,8 @@ def scalar_fn(params, x):
    else:
        if verbose:
            print(f"  [validate] FAILED ({len(errors)} errors)")
-        return (
-            f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\n"
-            f"Errors:\n" + "\n".join(f"- {e}" for e in errors)
+        return f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" + "\n".join(
+            f"- {e}" for e in errors
        )
 
@@ -683,11 +643,7 @@ def _run_agent_loop(
        total_input_tokens += response.usage.input_tokens
        total_output_tokens += response.usage.output_tokens
        if verbose:
-            print(
-                f"  Turn {turn}: "
-                f"{response.usage.input_tokens} in / "
-                f"{response.usage.output_tokens} out tokens"
-            )
+            print(f"  Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens")
 
        if response.stop_reason == "end_turn":
            if verbose:
@@ -785,9 +741,7 @@ def transpile_jax_to_pytorch(
    if verbose:
        n_params = sum(p.size for p in ctx.params)
-        print(
-            f"  {n_params} parameters, {len(ctx.validation_points)} validation points"
-        )
+        print(f"  {n_params} parameters, {len(ctx.validation_points)} validation points")
 
    # Step 2: Build prompts and run agent
    system_prompt = _build_system_prompt("jax_to_pytorch")
@@ -815,9 +769,7 @@ def transpile_jax_to_pytorch(
 
    validation_errors = []
    if not state.validated:
-        validation_errors.append(
-            f"Agent did not achieve validation after {state.tool_calls} tool calls"
-        )
+        validation_errors.append(f"Agent did not achieve validation after {state.tool_calls} tool calls")
 
    return TranspileResult(
        source_framework="jax",
@@ -881,9 +833,7 @@ def transpile_pytorch_to_jax(
    if verbose:
        n_params = sum(p.size for p in ctx.params)
-        print(
-            f"  {n_params} parameters, {len(ctx.validation_points)} validation points"
-        )
+        print(f"  {n_params} parameters, {len(ctx.validation_points)} validation points")
 
    # Step 2: Build prompts and run agent
    system_prompt = _build_system_prompt("pytorch_to_jax")
@@ -911,9 +861,7 @@ def transpile_pytorch_to_jax(
 
    validation_errors = []
    if not state.validated:
-        validation_errors.append(
-            f"Agent did not achieve validation after {state.tool_calls} tool calls"
-        )
+        validation_errors.append(f"Agent did not achieve validation after {state.tool_calls} tool calls")
 
    return TranspileResult(
        source_framework="pytorch",
diff --git a/transpailer/nutpie_bridge.py b/transpailer/nutpie_bridge.py
index 8a61a44..16739e7 100644
--- a/transpailer/nutpie_bridge.py
+++ b/transpailer/nutpie_bridge.py
@@ -49,8 +49,7 @@ def _build_shared_lib(build_dir: Path) -> Path:
 
    if not so_path.exists():
        raise RuntimeError(
-            f"Shared library not found at {so_path}. "
-            'Ensure Cargo.toml has [lib] crate-type = ["cdylib"]'
+            f'Shared library not found at {so_path}. Ensure Cargo.toml has [lib] crate-type = ["cdylib"]'
        )
    return so_path
 
@@ -124,9 +123,7 @@ def to_nutpie(
    ip = model.initial_point()
    from pymc.blocking import DictToArrayBijection
 
-    x0 = DictToArrayBijection.map(
-        {v.name: ip[v.name] for v in model_fn._grad_vars}
-    ).data
+    x0 = DictToArrayBijection.map({v.name: ip[v.name] for v in model_fn._grad_vars}).data
    n_dim = len(x0)
 
    # Get variable names, shapes, dtypes from the model
@@ -165,9 +162,7 @@ def expand_fn(x):
 
    def make_initial_point(seed):
        ip_ = model.initial_point()
-        return DictToArrayBijection.map(
-            {v.name: ip_[v.name] for v in model_fn._grad_vars}
-        ).data.astype(np.float64)
+        return DictToArrayBijection.map({v.name: ip_[v.name] for v in model_fn._grad_vars}).data.astype(np.float64)
 
    return from_pyfunc(
        ndim=n_dim,
diff --git a/transpailer/pytorch_exporter.py b/transpailer/pytorch_exporter.py
index 42cd16b..a4c1761 100644
--- a/transpailer/pytorch_exporter.py
+++ b/transpailer/pytorch_exporter.py
@@ -82,11 +82,7 @@ def _extract(self) -> ModelContext:
        elif isinstance(self.sample_input, (tuple, list)):
            for i, val in enumerate(self.sample_input):
                t = torch.as_tensor(val)
-                name = (
-                    self.input_names[i]
-                    if self.input_names and i < len(self.input_names)
-                    else f"x_{i}"
-                )
+                name = self.input_names[i] if self.input_names and i < len(self.input_names) else f"x_{i}"
                input_infos.append(
                    TensorInfo(
                        name=name,
@@ -162,9 +158,7 @@ def _forward(self, module, inp):
        import torch
 
        if isinstance(inp, dict):
-            return module(
-                **{k: torch.as_tensor(v, dtype=torch.float32) for k, v in inp.items()}
-            )
+            return module(**{k: torch.as_tensor(v, dtype=torch.float32) for k, v in inp.items()})
        elif isinstance(inp, (tuple, list)):
            return module(*[torch.as_tensor(v, dtype=torch.float32) for v in inp])
        else:
@@ -205,9 +199,7 @@ def _input_to_dict(self, inp) -> dict:
 
        if isinstance(inp, dict):
            return {
-                k: np.asarray(v).tolist()
-                if not isinstance(v, torch.Tensor)
-                else v.detach().cpu().numpy().tolist()
+                k: np.asarray(v).tolist() if not isinstance(v, torch.Tensor) else v.detach().cpu().numpy().tolist()
                for k, v in inp.items()
            }
        elif isinstance(inp, (tuple, list)):
diff --git a/transpailer/pytorch_rust_transpiler.py b/transpailer/pytorch_rust_transpiler.py
index b2719b8..72a74b8 100644
--- a/transpailer/pytorch_rust_transpiler.py
+++ b/transpailer/pytorch_rust_transpiler.py
@@ -25,7 +25,6 @@
 
 from transpailer.jax_exporter import ModelContext, ValidationPoint
 
-
 _SKILLS_DIR = Path(__file__).parent / "skills"
 
 
@@ -311,8 +310,7 @@
    {
        "name": "cargo_build",
        "description": (
-            "Build the Rust project with `cargo build --release`. "
-            "Returns build output including any compiler errors."
+            "Build the Rust project with `cargo build --release`. Returns build output including any compiler errors."
        ),
        "input_schema": {
            "type": "object",
@@ -333,9 +331,7 @@
    },
    {
        "name": "read_source",
-        "description": (
-            "Re-read the original PyTorch source code of the model being transpiled."
-        ),
+        "description": ("Re-read the original PyTorch source code of the model being transpiled."),
        "input_schema": {
            "type": "object",
            "properties": {},
@@ -497,9 +493,7 @@ def _setup_rust_project(build_path: Path, ctx: ModelContext, backend: str = "pur
        data_lines.append(f"    {vals},")
        data_lines.append("];")
-        data_lines.append(
-            f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};"
-        )
+        data_lines.append(f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};")
        data_lines.append("")
 
    (src / "data.rs").write_text("\n".join(data_lines))
@@ -532,23 +526,17 @@ def _build_user_prompt(ctx: ModelContext) -> str:
 
    # Data.rs mapping
    parts.append("## Parameter Constants in data.rs\n")
-    parts.append(
-        "Each parameter is available as a flat `&[f32]` constant plus a shape constant:\n"
-    )
+    parts.append("Each parameter is available as a flat `&[f32]` constant plus a shape constant:\n")
    for p in ctx.params:
        safe_name = p.name.replace(".", "_").upper()
-        parts.append(
-            f"- `{safe_name}`: &[f32] (len={p.size}), `{safe_name}_SHAPE`: &[usize] = {p.shape}"
-        )
+        parts.append(f"- `{safe_name}`: &[f32] (len={p.size}), `{safe_name}_SHAPE`: &[usize] = {p.shape}")
    parts.append("")
 
    # Input info
    parts.append("## Inputs\n")
    for i in ctx.inputs:
        parts.append(f"- `{i.name}`: shape={i.shape}, dtype={i.dtype}")
-    parts.append(
-        "The `forward()` function receives the input as a flat &[f32] array.\n"
-    )
+    parts.append("The `forward()` function receives the input as a flat &[f32] array.\n")
 
    # Output info
    parts.append("## Outputs\n")
@@ -573,8 +561,7 @@ def _build_user_prompt(ctx: ModelContext) -> str:
            parts.append(f"  expected output = {vp.output}")
        else:
            parts.append(
-                f"  expected output: shape={list(out_arr.shape)}, "
-                f"mean={out_arr.mean():.6f}, std={out_arr.std():.6f}"
+                f"  expected output: shape={list(out_arr.shape)}, mean={out_arr.mean():.6f}, std={out_arr.std():.6f}"
            )
 
    # Show expected gradients (truncated)
@@ -637,9 +624,7 @@ def _tool_write_code(
    if verbose:
        print(f"  [write_code] Wrote {len(code)} chars to generated.rs")
 
-    return (
-        f"Written {len(code)} chars to src/generated.rs. Use `cargo_build` to compile."
-    )
+    return f"Written {len(code)} chars to src/generated.rs. Use `cargo_build` to compile."
 
 
 def _tool_cargo_build(state: _AgentState, verbose: bool) -> str:
@@ -707,9 +692,7 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str:
        if "x" in inp:
            flat_input = np.array(inp["x"], dtype=np.float32).ravel()
        else:
-            flat_input = np.concatenate(
-                [np.array(v, dtype=np.float32).ravel() for v in inp.values()]
-            )
+            flat_input = np.concatenate([np.array(v, dtype=np.float32).ravel() for v in inp.values()])
 
        input_str = ",".join(f"{v:.9e}" for v in flat_input)
 
@@ -746,30 +729,19 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str:
        ref_output = np.array(vp.output, dtype=np.float32).ravel()
 
        if rust_output.shape != ref_output.shape:
-            errors.append(
-                f"{label}: output shape mismatch: got {rust_output.shape}, expected {ref_output.shape}"
-            )
-            report_lines.append(
-                f"{label}: SHAPE MISMATCH {rust_output.shape} vs {ref_output.shape}"
-            )
+            errors.append(f"{label}: output shape mismatch: got {rust_output.shape}, expected {ref_output.shape}")
+            report_lines.append(f"{label}: SHAPE MISMATCH {rust_output.shape} vs {ref_output.shape}")
            continue
 
        max_diff = float(np.max(np.abs(rust_output - ref_output)))
-        rel_err = float(
-            np.max(
-                np.abs(rust_output - ref_output) / np.maximum(np.abs(ref_output), 1e-8)
-            )
-        )
+        rel_err = float(np.max(np.abs(rust_output - ref_output) / np.maximum(np.abs(ref_output), 1e-8)))
        out_ok = rel_err <= 1e-4
        report_lines.append(
-            f"{label}: output max_diff={max_diff:.2e} rel_err={rel_err:.2e} "
-            f"[{'OK' if out_ok else 'MISMATCH'}]"
+            f"{label}: output max_diff={max_diff:.2e} rel_err={rel_err:.2e} [{'OK' if out_ok else 'MISMATCH'}]"
        )
        if not out_ok:
-            errors.append(
-                f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}"
-            )
+            errors.append(f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}")
 
        # Test gradients for each parameter
        for pname, ref_grad in vp.grad_params.items():
@@ -806,15 +778,11 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str:
            ref_g = np.array(ref_grad, dtype=np.float32).ravel()
 
            if rust_grad.shape != ref_g.shape:
-                errors.append(
-                    f"{label}: grad['{pname}'] shape mismatch: {rust_grad.shape} vs {ref_g.shape}"
-                )
+                errors.append(f"{label}: grad['{pname}'] shape mismatch: {rust_grad.shape} vs {ref_g.shape}")
                continue
 
            grad_diff = float(np.max(np.abs(rust_grad - ref_g)))
-            grad_rel = float(
-                np.max(np.abs(rust_grad - ref_g) / np.maximum(np.abs(ref_g), 1e-8))
-            )
+            grad_rel = float(np.max(np.abs(rust_grad - ref_g) / np.maximum(np.abs(ref_g), 1e-8)))
            grad_ok = grad_rel <= 1e-3
 
            report_lines.append(
@@ -822,9 +790,7 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str:
                f"[{'OK' if grad_ok else 'MISMATCH'}]"
            )
            if not grad_ok:
-                errors.append(
-                    f"{label}: grad['{pname}'] mismatch: rel_err={grad_rel:.2e}"
-                )
+                errors.append(f"{label}: grad['{pname}'] mismatch: rel_err={grad_rel:.2e}")
 
    # Restore original data.rs if we modified it
    if len(ctx.validation_points) > 1:
@@ -840,9 +806,8 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str:
    else:
        if verbose:
            print(f"  [validate] FAILED ({len(errors)} errors)")
-        return (
-            f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\n"
-            f"Errors:\n" + "\n".join(f"- {e}" for e in errors)
+        return f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" + "\n".join(
+            f"- {e}" for e in errors
        )
 
@@ -870,20 +835,14 @@ def _update_data_rs(build_path: Path, ctx: ModelContext, vp: ValidationPoint):
        data_lines.append(f"    {vals},")
        data_lines.append("];")
-        data_lines.append(
-            f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};"
-        )
+        data_lines.append(f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};")
        data_lines.append("")
 
    (build_path / "src" / "data.rs").write_text("\n".join(data_lines))
 
 
 def _tool_read_source(state: _AgentState, verbose: bool) -> str:
-    source = (
-        state.source_code
-        or state.source_context.source_code
-        or "(no source code available)"
-    )
+    source = state.source_code or state.source_context.source_code or "(no source code available)"
    if verbose:
        print(f"  [read_source] {len(source)} chars")
    return source
@@ -1008,11 +967,7 @@ def _run_agent_loop(
        total_input_tokens += response.usage.input_tokens
        total_output_tokens += response.usage.output_tokens
        if verbose:
-            print(
-                f"  Turn {turn}: "
-                f"{response.usage.input_tokens} in / "
-                f"{response.usage.output_tokens} out tokens"
-            )
+            print(f"  Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens")
 
        if response.stop_reason == "end_turn":
            if verbose:
@@ -1118,9 +1073,7 @@ def transpile_pytorch_to_rust(
    if verbose:
        n_params = sum(p.size for p in ctx.params)
-        print(
-            f"  {n_params} parameters, {len(ctx.validation_points)} validation points"
-        )
+        print(f"  {n_params} parameters, {len(ctx.validation_points)} validation points")
 
    # Step 2: Set up Rust build directory
    if build_dir:
@@ -1173,9 +1126,7 @@ def transpile_pytorch_to_rust(
 
    validation_errors = []
    if not state.validated:
-        validation_errors.append(
-            f"Agent did not achieve validation after {state.tool_calls} tool calls"
-        )
+        validation_errors.append(f"Agent did not achieve validation after {state.tool_calls} tool calls")
 
    return RustTranspileResult(
        generated_code=rust_code,
diff --git a/transpailer/stan_compiler.py b/transpailer/stan_compiler.py
index c3f9087..74ffdae 100644
--- a/transpailer/stan_compiler.py
+++ b/transpailer/stan_compiler.py
@@ -173,8 +173,7 @@
    {
        "name": "cargo_build",
        "description": (
-            "Build the Rust project with `cargo build --release`. "
-            "Returns build output including any compiler errors."
+            "Build the Rust project with `cargo build --release`. Returns build output including any compiler errors."
        ),
        "input_schema": {
            "type": "object",
@@ -431,9 +430,7 @@ def compile_stan_model(
        total_input_tokens += response.usage.input_tokens
        total_output_tokens += response.usage.output_tokens
        if verbose:
-            print(
-                f"  Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens"
-            )
+            print(f"  Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens")
 
        if response.stop_reason == "end_turn":
            if verbose:
@@ -479,9 +476,7 @@ def compile_stan_model(
 
    validation_errors = []
    if not state.validated:
-        validation_errors.append(
-            f"Agent did not achieve validation after {state.tool_calls} tool calls"
-        )
+        validation_errors.append(f"Agent did not achieve validation after {state.tool_calls} tool calls")
 
    return StanCompilationResult(
        rust_code=rust_code,
@@ -570,9 +565,7 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str:
    if not binary.exists():
        return "Error: validation binary not found. Run cargo_build first."
- all_points = [("initial", ctx.initial_point)] + [ - (f"extra_{i}", p) for i, p in enumerate(ctx.extra_points) - ] + all_points = [("initial", ctx.initial_point)] + [(f"extra_{i}", p) for i, p in enumerate(ctx.extra_points)] input_lines = [] for name, vp in all_points: @@ -597,9 +590,7 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: output_lines = [line for line in result.stdout.strip().split("\n") if line] if len(output_lines) != len(all_points): - return ( - f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" - ) + return f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" parsed = [] for output_line in output_lines: @@ -623,13 +614,11 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: rel_err = abs(rust_logp - vp.logp) / max(abs(vp.logp), 1.0) status = "OK" if rel_err <= 1e-4 else "MISMATCH" report_lines.append( - f"{name}: logp BridgeStan={vp.logp:.10f} Rust={rust_logp:.10f} " - f"rel_err={rel_err:.2e} [{status}]" + f"{name}: logp BridgeStan={vp.logp:.10f} Rust={rust_logp:.10f} rel_err={rel_err:.2e} [{status}]" ) if rel_err > 1e-4: errors.append( - f"{name}: logp mismatch: BridgeStan={vp.logp:.10f}, " - f"Rust={rust_logp:.10f}, rel_err={rel_err:.2e}" + f"{name}: logp mismatch: BridgeStan={vp.logp:.10f}, Rust={rust_logp:.10f}, rel_err={rel_err:.2e}" ) # Gradient comparison @@ -658,10 +647,7 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: else: if verbose: print(f" [validate_logp] FAILED ({len(errors)} errors)") - return ( - f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\n" - f"Errors:\n" + "\n".join(errors) - ) + return f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" + "\n".join(errors) def _tool_read_file( @@ -780,9 +766,7 @@ def _generate_data_rs(data: dict | None) -> str: # Also export shape info for multi-dimensional arrays if arr.ndim > 1: for dim_i, dim_size in enumerate(arr.shape): - lines.append( - f"pub const {name.upper()}_DIM{dim_i}: usize = {dim_size};" - ) + lines.append(f"pub const {name.upper()}_DIM{dim_i}: usize = {dim_size};") lines.append("") return "\n".join(lines) @@ -840,9 +824,7 @@ def _setup_rust_project( (src_dir / "lib.rs").write_text(lib_rs) # Placeholder generated.rs - (src_dir / "generated.rs").write_text( - "// Placeholder — will be overwritten by the agent\n" - ) + (src_dir / "generated.rs").write_text("// Placeholder — will be overwritten by the agent\n") # Validation binary validate_rs = """ diff --git a/transpailer/stan_exporter.py b/transpailer/stan_exporter.py index 6da4615..9d15c7d 100644 --- a/transpailer/stan_exporter.py +++ b/transpailer/stan_exporter.py @@ -68,10 +68,7 @@ def to_dict(self) -> dict: "logp": self.initial_point.logp, "dlogp": self.initial_point.dlogp, }, - "extra_points": [ - {"point": p.point, "logp": p.logp, "dlogp": p.dlogp} - for p in self.extra_points - ], + "extra_points": [{"point": p.point, "logp": p.logp, "dlogp": p.dlogp} for p in self.extra_points], }, } @@ -177,9 +174,7 @@ def _extract(self) -> StanModelContext: unc_param_names=list(unc_param_names), n_params=n_unconstrained, n_params_constrained=n_constrained, - data_json=json.dumps(self._data) - if isinstance(self._data, dict) - else self._data, + data_json=json.dumps(self._data) if isinstance(self._data, dict) else self._data, data_summary=data_summary, initial_point=initial, extra_points=extra_points, @@ -243,8 +238,7 @@ def to_prompt(self) -> str: range_str = f"[{mn:.3f}, {mx:.3f}]" if mn is not None else "unknown" mean_str = 
f"{mean:.3f}" if mean is not None else "unknown" parts.append( - f"- `{name.upper()}_DATA: &[f64]` — shape={shape}, n={n}, " - f"range={range_str}, mean={mean_str}" + f"- `{name.upper()}_DATA: &[f64]` — shape={shape}, n={n}, range={range_str}, mean={mean_str}" ) parts.append(f" `{name.upper()}_N: usize = {n}`") @@ -253,17 +247,13 @@ def to_prompt(self) -> str: for name, value in self._data.items(): arr = np.asarray(value) if arr.ndim == 0 and np.issubdtype(arr.dtype, np.integer): - parts.append( - f"- `{name.upper()}: usize = {int(arr)}` (scalar integer)" - ) + parts.append(f"- `{name.upper()}: usize = {int(arr)}` (scalar integer)") parts.append("") # Validation parts.append("## Validation") - parts.append( - "Your generated code MUST produce these exact values (within float64 precision):\n" - ) + parts.append("Your generated code MUST produce these exact values (within float64 precision):\n") parts.append( "NOTE: BridgeStan computes log_density with jacobian=True, propto=True.\n" "This means the log density INCLUDES Jacobian adjustments for constrained parameters\n" @@ -324,9 +314,7 @@ def to_rust_tests(self, struct_name: str = "GeneratedLogp") -> str: assert_close(logp, {vp.logp:.10f}, 1e-6, "logp");""") for i, g in enumerate(vp.dlogp): - tests.append( - f' assert_close(gradient[{i}], {g:.10e}, 1e-4, "grad[{i}]");' - ) + tests.append(f' assert_close(gradient[{i}], {g:.10e}, 1e-4, "grad[{i}]");') tests.append(" }\n") tests.append("}\n") diff --git a/transpailer/stan_to_pymc.py b/transpailer/stan_to_pymc.py index 377468b..d2447fe 100644 --- a/transpailer/stan_to_pymc.py +++ b/transpailer/stan_to_pymc.py @@ -16,6 +16,7 @@ import numpy as np +from transpailer.formatting import format_python_code as _format_python _SKILLS_DIR = Path(__file__).parent / "skills" @@ -95,8 +96,7 @@ def make_model(data: dict) -> pm.Model: "code": { "type": "string", "description": ( - "Complete Python source code defining make_model(data). " - "Must include all necessary imports." + "Complete Python source code defining make_model(data). Must include all necessary imports." ), }, }, @@ -118,8 +118,7 @@ def make_model(data: dict) -> pm.Model: { "name": "read_stan_code", "description": ( - "Re-read the original Stan source code. Useful if you need to " - "double-check details of the Stan model." + "Re-read the original Stan source code. Useful if you need to double-check details of the Stan model." 
), "input_schema": { "type": "object", @@ -313,10 +312,7 @@ def transpile_stan_to_pymc( reference_points.append({"point": pt.point, "logp": pt.logp, "dlogp": pt.dlogp}) if verbose: - print( - f" {ctx.n_params} unconstrained params, " - f"{len(reference_points)} validation points" - ) + print(f" {ctx.n_params} unconstrained params, {len(reference_points)} validation points") # Step 2: Build prompts system_prompt = _build_system_prompt() @@ -361,11 +357,7 @@ def transpile_stan_to_pymc( total_input_tokens += response.usage.input_tokens total_output_tokens += response.usage.output_tokens if verbose: - print( - f" Turn {turn}: " - f"{response.usage.input_tokens} in / " - f"{response.usage.output_tokens} out tokens" - ) + print(f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens") if response.stop_reason == "end_turn": if verbose: @@ -407,9 +399,7 @@ def transpile_stan_to_pymc( validation_errors = [] if not state.validated: - validation_errors.append( - f"Agent did not achieve validation after {state.tool_calls} tool calls" - ) + validation_errors.append(f"Agent did not achieve validation after {state.tool_calls} tool calls") return StanToPyMCResult( pymc_code=state.pymc_code, @@ -461,6 +451,7 @@ def _tool_write_pymc_code( print(f" [write_pymc_code] Syntax error: {e}") return f"Syntax error in generated code: {e}" + code = _format_python(code) state.pymc_code = code if verbose: print(f" [write_pymc_code] Wrote {len(code)} chars") @@ -507,9 +498,7 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: # per parameter for proper normalization. This constant offset doesn't affect # sampling but causes logp comparison mismatches. _HALF_RV_OPS = {"HalfNormalRV", "HalfCauchyRV", "HalfStudentTRV", "HalfFlatRV"} - n_half = sum( - 1 for rv in model.free_RVs if type(rv.owner.op).__name__ in _HALF_RV_OPS - ) + n_half = sum(1 for rv in model.free_RVs if type(rv.owner.op).__name__ in _HALF_RV_OPS) half_logp_correction = n_half * np.log(2) if verbose and n_half > 0: print( @@ -693,16 +682,10 @@ def _map_unc_point_to_pymc( for var in pymc_vars: var_name = var.name # Strip PyMC transform suffixes - base_name = re.sub( - r"_(log|logodds|interval|circular|ordered|simplex)__$", "", var_name - ) + base_name = re.sub(r"_(log|logodds|interval|circular|ordered|simplex)__$", "", var_name) # Determine size of this variable - var_size = int( - np.prod(var.type.shape) - if hasattr(var.type, "shape") and var.type.shape - else 1 - ) + var_size = int(np.prod(var.type.shape) if hasattr(var.type, "shape") and var.type.shape else 1) # Find matching Stan param group matched = False