diff --git a/examples/01_normal.py b/examples/01_normal.py index 5d8852c..3e58f97 100644 --- a/examples/01_normal.py +++ b/examples/01_normal.py @@ -46,7 +46,7 @@ if result.success: print(f"\nCompilation successful in {result.n_attempts} attempt(s)!") print(f"Timings: {result.timings}") - print(f"\nGenerated Rust code saved to: compiled_models/normal/src/generated.rs") + print("\nGenerated Rust code saved to: compiled_models/normal/src/generated.rs") else: print(f"\nCompilation FAILED after {result.n_attempts} attempts") for err in result.validation_errors: diff --git a/examples/03_hierarchical.py b/examples/03_hierarchical.py index 87089ec..3336ebd 100644 --- a/examples/03_hierarchical.py +++ b/examples/03_hierarchical.py @@ -62,7 +62,9 @@ mu_y = a[group_idx] + b * x y = pm.Normal("y", mu=mu_y, sigma=sigma_y, observed=y_obs) -print(f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}") +print( + f"True: mu_a={true_mu_a}, sigma_a={true_sigma_a}, b={true_b}, sigma_y={true_sigma_y}" +) print(f"Data: {n_groups} groups, {N} observations") print(f"Group sizes: {n_per_group}") print() @@ -76,8 +78,8 @@ if result.success: print(f"\nCompilation successful in {result.n_attempts} attempt(s)!") - print(f"\nNow you can benchmark:") - print(f" python -c 'from pymc_rust_compiler.benchmark import *; ...'") + print("\nNow you can benchmark:") + print(" python -c 'from pymc_rust_compiler.benchmark import *; ...'") else: print(f"\nCompilation FAILED after {result.n_attempts} attempts") for err in result.validation_errors[:5]: diff --git a/examples/04_zerosumnormal.py b/examples/04_zerosumnormal.py index a7dc719..d31a50e 100644 --- a/examples/04_zerosumnormal.py +++ b/examples/04_zerosumnormal.py @@ -36,7 +36,7 @@ np.random.seed(314) n_stores = 6 -n_days = 7 # Mon-Sun +n_days = 7 # Mon-Sun n_categories = 4 # e.g., Electronics, Clothing, Food, Home store_names = [f"store_{i}" for i in range(n_stores)] @@ -73,8 +73,12 @@ for d in range(n_days): for c in 
range(n_categories): n = np.random.poisson(n_obs_per_cell) + 1 - mu = (true_grand_mean + true_store_effect[s] - + true_day_effect[d] + true_interaction[s, d, c]) + mu = ( + true_grand_mean + + true_store_effect[s] + + true_day_effect[d] + + true_interaction[s, d, c] + ) y_vals = np.random.normal(mu, true_sigma_y, n) for y in y_vals: records.append((s, d, c, y)) @@ -128,10 +132,12 @@ n_zerosum_axes=2, ) - mu = (grand_mean - + store_effect[store_idx] - + day_effect[day_idx] - + interaction[store_idx, day_idx, cat_idx]) + mu = ( + grand_mean + + store_effect[store_idx] + + day_effect[day_idx] + + interaction[store_idx, day_idx, cat_idx] + ) sigma_y = pm.HalfNormal("sigma_y", sigma=5) pm.Normal("y", mu=mu, sigma=sigma_y, observed=y_obs) @@ -139,7 +145,9 @@ n_free = sum(v.size for v in model.initial_point().values()) print(f"Free RVs: {[rv.name for rv in model.free_RVs]}") print(f"Unconstrained parameters: {n_free}") -print(f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}") +print( + f"Transforms: {[(rv.name, type(model.rvs_to_transforms.get(rv)).__name__) for rv in model.free_RVs]}" +) print() result = compile_model( @@ -167,13 +175,24 @@ import arviz as az print("\n--- Posterior summary (hyperparameters) ---") - print(az.summary(idata, var_names=[ - "grand_mean", "sigma_store", "sigma_day", "sigma_cat", "sigma_y", - ])) + print( + az.summary( + idata, + var_names=[ + "grand_mean", + "sigma_store", + "sigma_day", + "sigma_cat", + "sigma_y", + ], + ) + ) - print(f"\nTrue values: grand_mean={true_grand_mean}, " - f"sigma_store={true_sigma_store}, sigma_day={true_sigma_day}, " - f"sigma_cat={true_sigma_cat}, sigma_y={true_sigma_y}") + print( + f"\nTrue values: grand_mean={true_grand_mean}, " + f"sigma_store={true_sigma_store}, sigma_day={true_sigma_day}, " + f"sigma_cat={true_sigma_cat}, sigma_y={true_sigma_y}" + ) # Posterior plots axes = az.plot_posterior( diff --git a/examples/05_celeri_simplified.py 
b/examples/05_celeri_simplified.py index 2bc3953..4daa4c9 100644 --- a/examples/05_celeri_simplified.py +++ b/examples/05_celeri_simplified.py @@ -39,25 +39,39 @@ # --- Synthetic tectonic data --- np.random.seed(42) -n_blocks = 3 # tectonic blocks -n_faults = 4 # fault segments -n_stations = 25 # GPS stations -n_bounded = 3 # faults with geologic slip rate bounds +n_blocks = 3 # tectonic blocks +n_faults = 4 # fault segments +n_stations = 25 # GPS stations +n_bounded = 3 # faults with geologic slip rate bounds # True parameters -true_rotation = np.array([ - 0.5, -0.3, 0.1, # Block 1: wx, wy, wz (rad/Gyr) - -0.2, 0.4, -0.1, # Block 2 - 0.1, -0.1, 0.3, # Block 3 -]) +true_rotation = np.array( + [ + 0.5, + -0.3, + 0.1, # Block 1: wx, wy, wz (rad/Gyr) + -0.2, + 0.4, + -0.1, # Block 2 + 0.1, + -0.1, + 0.3, # Block 3 + ] +) rotation_scale = np.array([1.0, 1.0, 0.5] * n_blocks) # prior scales -true_slip = np.array([ - 2.0, 0.5, # Fault 1: strike-slip, dip-slip (mm/yr) - -1.5, 1.0, # Fault 2 - 0.8, -0.3, # Fault 3 - -0.5, 0.2, # Fault 4 -]) +true_slip = np.array( + [ + 2.0, + 0.5, # Fault 1: strike-slip, dip-slip (mm/yr) + -1.5, + 1.0, # Fault 2 + 0.8, + -0.3, # Fault 3 + -0.5, + 0.2, # Fault 4 + ] +) slip_prior_sigma = 5.0 # Design matrices (Green's functions) @@ -84,7 +98,7 @@ # Regularization gamma = 2.0 # regularization strength -print(f"Tectonic block model:") +print("Tectonic block model:") print(f" {n_blocks} blocks ({n_blocks * 3} rotation params)") print(f" {n_faults} faults ({n_faults * 2} slip rate params)") print(f" {n_stations} GPS stations ({n_stations * 2} velocity observations)") @@ -124,9 +138,8 @@ slip_rate = pm.Normal("slip_rate", mu=0, sigma=slip_prior_sigma, shape=n_faults * 2) # Predicted GPS velocities via design matrices - predicted_velocity = ( - pm.math.dot(G_rotation, rotation) - + pm.math.dot(G_slip, slip_rate) + predicted_velocity = pm.math.dot(G_rotation, rotation) + pm.math.dot( + G_slip, slip_rate ) # GPS station velocity likelihood 
(StudentT for heavy tails) @@ -167,7 +180,7 @@ ) if result.success: - print(f"\nCompilation successful!") + print("\nCompilation successful!") print(f" Builds: {result.n_attempts}") print(f" Tool calls: {result.n_tool_calls}") print(f" Turns: {result.conversation_turns}") @@ -195,6 +208,8 @@ print(f"\nTrue rotation: {true_rotation}") print(f"True slip rates: {true_slip}") else: - print(f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls") + print( + f"\nCompilation FAILED after {result.n_attempts} builds, {result.n_tool_calls} tool calls" + ) for err in result.validation_errors[:5]: print(f" - {err}") diff --git a/examples/bench_logp.py b/examples/bench_logp.py index c96e306..46980d3 100644 --- a/examples/bench_logp.py +++ b/examples/bench_logp.py @@ -54,7 +54,7 @@ def make_hierarchical_model(): """Hierarchical model, 12 unconstrained parameters.""" build_dir = Path("compiled_models/hierarchical") y_obs = np.load(build_dir / "y_data.npy") - x = np.load(build_dir / "x_0_data.npy") # binary covariate + x = np.load(build_dir / "x_0_data.npy") # binary covariate group_idx = np.load(build_dir / "x_1_data.npy").astype(int) # group indices n_groups = int(group_idx.max()) + 1 with pm.Model() as model: @@ -94,9 +94,9 @@ def main(): results = [] for name, make_fn in models: - print(f"\n{'='*65}") + print(f"\n{'=' * 65}") print(f" {name}") - print(f"{'='*65}") + print(f"{'=' * 65}") model, build_dir = make_fn() n_evals = N_EVALS @@ -107,24 +107,32 @@ def main(): pt_result = benchmark_logp_pytensor(model, n_evals=n_evals, x0_model_order=x0) print(f" pytensor (python loop): {pt_result['us_per_eval']:.2f} us/eval") - cfunc_result = benchmark_logp_numba_cfunc(model, n_evals=n_evals, x0_model_order=x0) + cfunc_result = benchmark_logp_numba_cfunc( + model, n_evals=n_evals, x0_model_order=x0 + ) print(f" numba cfunc (rust loop): {cfunc_result['us_per_eval']:.2f} us/eval") - rs_result = benchmark_logp_rust(build_dir, model, n_evals=n_evals, 
x0_model_order=x0) + rs_result = benchmark_logp_rust( + build_dir, model, n_evals=n_evals, x0_model_order=x0 + ) if "error" in rs_result: print(f" rust-ai: ERROR - {rs_result['error'][:100]}") else: print(f" rust-ai: {rs_result['us_per_eval']:.2f} us/eval") print_logp_comparison(pt_result, rs_result, model_name=name) - print_logp_comparison(cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]") + print_logp_comparison( + cfunc_result, rs_result, model_name=f"{name} [cfunc vs rust]" + ) results.append((name, pt_result, cfunc_result, rs_result)) # Summary table print("\n" + "=" * 85) print("SUMMARY: logp+dlogp evaluation speed") print("=" * 85) - print(f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}") + print( + f"\n{'Model':<25} {'pytensor':<14} {'cfunc+rust':<14} {'rust-ai':<14} {'cfunc/rust':<12}" + ) print("-" * 79) for name, pt, cf, rs in results: pt_us = f"{pt['us_per_eval']:.2f} us" if "error" not in pt else "ERROR" diff --git a/examples/jax_to_pytorch_mlp.py b/examples/jax_to_pytorch_mlp.py index 40f147f..91f36a9 100644 --- a/examples/jax_to_pytorch_mlp.py +++ b/examples/jax_to_pytorch_mlp.py @@ -45,18 +45,21 @@ def main(): ) if result.success: - print(f"\nTranspilation successful!") + print("\nTranspilation successful!") print(f" Tool calls: {result.n_tool_calls}") print(f" Tokens: {result.token_usage['total_tokens']}") - print(f"\nGenerated PyTorch code:") + print("\nGenerated PyTorch code:") print(result.generated_code) # Test the generated model import torch + model = result.get_model({k: np.asarray(v) for k, v in params.items()}) pt_out = model(torch.tensor(np.asarray(x))) print(f"\nPyTorch output:\n{pt_out.detach().numpy()}") - print(f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}") + print( + f"Max diff: {np.max(np.abs(pt_out.detach().numpy() - np.asarray(out))):.2e}" + ) else: print(f"\nTranspilation failed: {result.validation_errors}") diff --git 
a/examples/mingpt_enzyme/validate_pytorch.py b/examples/mingpt_enzyme/validate_pytorch.py index a1e1f23..2c305c6 100644 --- a/examples/mingpt_enzyme/validate_pytorch.py +++ b/examples/mingpt_enzyme/validate_pytorch.py @@ -17,21 +17,30 @@ def parse_data_rs(path): weights = {} # Match: pub const NAME: &[f32] = &[ ... ]; - pattern = r'pub const (\w+): &\[f32\] = &\[([\s\S]*?)\];' + pattern = r"pub const (\w+): &\[f32\] = &\[([\s\S]*?)\];" for m in re.finditer(pattern, content): name = m.group(1) if name.endswith("_SHAPE"): continue - vals = [float(x.strip().rstrip(',')) for x in m.group(2).split(',') if x.strip()] + vals = [ + float(x.strip().rstrip(",")) for x in m.group(2).split(",") if x.strip() + ] weights[name] = np.array(vals, dtype=np.float32) return weights class NewGELU(nn.Module): def forward(self, x): - return 0.5 * x * (1.0 + torch.tanh( - (2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)) - )) + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + (2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)) + ) + ) + ) class CausalSelfAttention(nn.Module): @@ -41,7 +50,9 @@ def __init__(self, n_embd, n_head, block_size): self.c_proj = nn.Linear(n_embd, n_embd) self.register_buffer( "bias", - torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size), + torch.tril(torch.ones(block_size, block_size)).view( + 1, 1, block_size, block_size + ), ) self.n_head = n_head self.n_embd = n_embd @@ -82,9 +93,9 @@ def __init__(self, n_layer=3, n_head=3, n_embd=48, block_size=8, vocab_size=32): super().__init__() self.n_embd = n_embd self.seq_len = block_size - self.blocks = nn.ModuleList([ - TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer) - ]) + self.blocks = nn.ModuleList( + [TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)] + ) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) @@ -105,20 +116,30 @@ def load_weights_from_data_rs(model, weights): prefix = 
f"BLOCKS_{i}_" b.ln_1.weight.copy_(torch.tensor(weights[f"{prefix}LN_1_WEIGHT"])) b.ln_1.bias.copy_(torch.tensor(weights[f"{prefix}LN_1_BIAS"])) - b.attn.c_attn.weight.copy_(torch.tensor(weights[f"{prefix}ATTN_C_ATTN_WEIGHT"]).reshape(144, 48)) + b.attn.c_attn.weight.copy_( + torch.tensor(weights[f"{prefix}ATTN_C_ATTN_WEIGHT"]).reshape(144, 48) + ) b.attn.c_attn.bias.copy_(torch.tensor(weights[f"{prefix}ATTN_C_ATTN_BIAS"])) - b.attn.c_proj.weight.copy_(torch.tensor(weights[f"{prefix}ATTN_C_PROJ_WEIGHT"]).reshape(48, 48)) + b.attn.c_proj.weight.copy_( + torch.tensor(weights[f"{prefix}ATTN_C_PROJ_WEIGHT"]).reshape(48, 48) + ) b.attn.c_proj.bias.copy_(torch.tensor(weights[f"{prefix}ATTN_C_PROJ_BIAS"])) b.ln_2.weight.copy_(torch.tensor(weights[f"{prefix}LN_2_WEIGHT"])) b.ln_2.bias.copy_(torch.tensor(weights[f"{prefix}LN_2_BIAS"])) - b.c_fc.weight.copy_(torch.tensor(weights[f"{prefix}C_FC_WEIGHT"]).reshape(192, 48)) + b.c_fc.weight.copy_( + torch.tensor(weights[f"{prefix}C_FC_WEIGHT"]).reshape(192, 48) + ) b.c_fc.bias.copy_(torch.tensor(weights[f"{prefix}C_FC_BIAS"])) - b.c_proj.weight.copy_(torch.tensor(weights[f"{prefix}C_PROJ_WEIGHT"]).reshape(48, 192)) + b.c_proj.weight.copy_( + torch.tensor(weights[f"{prefix}C_PROJ_WEIGHT"]).reshape(48, 192) + ) b.c_proj.bias.copy_(torch.tensor(weights[f"{prefix}C_PROJ_BIAS"])) model.ln_f.weight.copy_(torch.tensor(weights["LN_F_WEIGHT"])) model.ln_f.bias.copy_(torch.tensor(weights["LN_F_BIAS"])) - model.lm_head.weight.copy_(torch.tensor(weights["LM_HEAD_WEIGHT"]).reshape(32, 48)) + model.lm_head.weight.copy_( + torch.tensor(weights["LM_HEAD_WEIGHT"]).reshape(32, 48) + ) def main(): @@ -149,23 +170,33 @@ def main(): grad = x.grad.numpy() print(f"Gradient shape: {grad.shape}") - print(f"Gradient stats:") + print("Gradient stats:") print(f" sum = {grad.sum():.6f}") print(f" abs_sum = {np.abs(grad).sum():.6f}") print(f" max = {grad.max():.6f}") print(f" min = {grad.min():.6f}") - print(f"First 10 gradients:") + print("First 10 
gradients:") for i in range(10): print(f" grad[{i}] = {grad[i]:.8f}") # Compare with Enzyme results enzyme_grads = [ - -0.06038946, 0.19937083, -0.17638184, -0.23096114, 0.14284474, - 0.30951929, -0.02494755, -0.27090901, -0.05807663, 0.07012614, + -0.06038946, + 0.19937083, + -0.17638184, + -0.23096114, + 0.14284474, + 0.30951929, + -0.02494755, + -0.27090901, + -0.05807663, + 0.07012614, ] - print(f"\n--- Enzyme vs PyTorch comparison (first 10) ---") - print(f"{'idx':>4} {'Enzyme':>14} {'PyTorch':>14} {'abs_diff':>12} {'rel_diff':>12}") + print("\n--- Enzyme vs PyTorch comparison (first 10) ---") + print( + f"{'idx':>4} {'Enzyme':>14} {'PyTorch':>14} {'abs_diff':>12} {'rel_diff':>12}" + ) for i in range(10): e = enzyme_grads[i] p = grad[i] diff --git a/examples/mingpt_to_rust.py b/examples/mingpt_to_rust.py index 03cb45e..26218d2 100644 --- a/examples/mingpt_to_rust.py +++ b/examples/mingpt_to_rust.py @@ -10,7 +10,6 @@ import sys import time -import numpy as np import torch import torch.nn as nn from torch.nn import functional as F @@ -20,9 +19,16 @@ class NewGELU(nn.Module): def forward(self, x): - return 0.5 * x * (1.0 + torch.tanh( - (2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)) - )) + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + (2.0 / 3.141592653589793) ** 0.5 * (x + 0.044715 * x.pow(3.0)) + ) + ) + ) class CausalSelfAttention(nn.Module): @@ -32,7 +38,9 @@ def __init__(self, n_embd, n_head, block_size): self.c_proj = nn.Linear(n_embd, n_embd) self.register_buffer( "bias", - torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size), + torch.tril(torch.ones(block_size, block_size)).view( + 1, 1, block_size, block_size + ), ) self.n_head = n_head self.n_embd = n_embd @@ -82,9 +90,9 @@ def __init__(self, n_layer, n_head, n_embd, block_size, vocab_size): super().__init__() self.n_embd = n_embd self.seq_len = block_size # fixed sequence length for transpilation - self.blocks = nn.ModuleList([ - 
TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer) - ]) + self.blocks = nn.ModuleList( + [TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)] + ) self.ln_f = nn.LayerNorm(n_embd) self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) @@ -117,7 +125,7 @@ def create_mingpt_nano(): n_params = sum(p.numel() for p in model.parameters()) print(f"MinGPT-nano: {n_layer} layers, {n_head} heads, {n_embd} embd") print(f" vocab_size={vocab_size}, block_size={block_size}") - print(f" Parameters: {n_params:,} ({n_params/1e3:.1f}K)") + print(f" Parameters: {n_params:,} ({n_params / 1e3:.1f}K)") sample_input = torch.randn(block_size * n_embd) return model, sample_input @@ -169,9 +177,9 @@ def benchmark_rust(binary_path, sample_input, n_warmup=100, n_runs=10000): print(f"Rust benchmark error: {result.stderr[:500]}") return None - lines = [l for l in result.stdout.strip().split("\n") if l.strip()] + lines = [line for line in result.stdout.strip().split("\n") if line.strip()] # Skip warmup lines - run_lines = lines[n_warmup:] + run_lines = lines[n_warmup:] # The total elapsed includes warmup, so estimate per-call from total runs us_per_call = (elapsed / (n_warmup + n_runs)) * 1e6 @@ -318,7 +326,7 @@ def forward(self, x): if result.success: # Save the generated Rust code result.save("mingpt_generated.rs") - print(f"\n Generated Rust code saved to mingpt_generated.rs") + print("\n Generated Rust code saved to mingpt_generated.rs") print(f" Build dir: {result.build_dir}") # Step 4: Benchmark Rust @@ -336,7 +344,9 @@ def forward(self, x): print(f" PyTorch: {pytorch_us:.1f} µs/call") print(f" Rust: {rust_us:.1f} µs/call") speedup = pytorch_us / rust_us - print(f" Speedup: {speedup:.1f}x {'faster' if speedup > 1 else 'slower'}") + print( + f" Speedup: {speedup:.1f}x {'faster' if speedup > 1 else 'slower'}" + ) else: print(" Rust benchmark failed!") else: diff --git a/examples/pytorch_to_jax_mlp.py b/examples/pytorch_to_jax_mlp.py index 5659f14..9a04a3b 100644 ---
a/examples/pytorch_to_jax_mlp.py +++ b/examples/pytorch_to_jax_mlp.py @@ -48,15 +48,18 @@ def main(): ) if result.success: - print(f"\nTranspilation successful!") + print("\nTranspilation successful!") print(f" Tool calls: {result.n_tool_calls}") print(f" Tokens: {result.token_usage['total_tokens']}") - print(f"\nGenerated JAX code:") + print("\nGenerated JAX code:") print(result.generated_code) # Test the generated model import jax.numpy as jnp - param_data = {name: param.detach().numpy() for name, param in model.named_parameters()} + + param_data = { + name: param.detach().numpy() for name, param in model.named_parameters() + } jax_params, forward_fn = result.get_model(param_data) jax_out = forward_fn(jax_params, jnp.array(x.numpy())) print(f"\nJAX output:\n{np.asarray(jax_out)}") diff --git a/examples/pytorch_to_rust_mlp.py b/examples/pytorch_to_rust_mlp.py index 851d281..e8db002 100644 --- a/examples/pytorch_to_rust_mlp.py +++ b/examples/pytorch_to_rust_mlp.py @@ -4,10 +4,8 @@ The generated Rust code has zero dependencies — just raw f32 math. 
""" -import os import torch import torch.nn as nn -import numpy as np from pymc_rust_compiler import transpile_pytorch_to_rust @@ -38,7 +36,7 @@ def main(): print(f"PyTorch output: {pytorch_output}") # Source code for context - source = ''' + source = """ class MLP(nn.Module): def __init__(self, in_dim=4, hidden=8, out_dim=2): super().__init__() @@ -48,7 +46,7 @@ def __init__(self, in_dim=4, hidden=8, out_dim=2): def forward(self, x): x = torch.relu(self.fc1(x)) return self.fc2(x) -''' +""" # Transpile to Rust result = transpile_pytorch_to_rust( @@ -59,7 +57,7 @@ def forward(self, x): ) if result.success: - print(f"\nTranspilation successful!") + print("\nTranspilation successful!") print(f" Build dir: {result.build_dir}") print(f" Tool calls: {result.n_tool_calls}") print(f" Builds: {result.n_attempts}") @@ -69,7 +67,7 @@ def forward(self, x): if len(result.generated_code) > 500: print("...") else: - print(f"\nTranspilation failed:") + print("\nTranspilation failed:") for err in result.validation_errors: print(f" - {err}") diff --git a/examples/run_benchmark.py b/examples/run_benchmark.py index dd0faf7..f7dc1de 100644 --- a/examples/run_benchmark.py +++ b/examples/run_benchmark.py @@ -11,13 +11,15 @@ 4. 
Print comparison table """ -import time - import numpy as np import pymc as pm from pymc_rust_compiler import compile_model -from pymc_rust_compiler.benchmark import benchmark_nutpie, benchmark_rust, print_comparison +from pymc_rust_compiler.benchmark import ( + benchmark_nutpie, + benchmark_rust, + print_comparison, +) def make_normal_model(): @@ -96,7 +98,6 @@ def make_zerosumnormal_model(): # True effects true_grand_mean = 8.0 true_sigma_store = 0.4 - true_sigma_day = 0.3 true_sigma_cat = 0.5 true_sigma_y = 0.6 @@ -105,7 +106,9 @@ def make_zerosumnormal_model(): raw_day = np.array([-0.2, -0.1, 0.0, 0.05, 0.15, 0.35, 0.25]) raw_day += np.random.normal(0, 0.05, n_days) true_day_effect = raw_day - raw_day.mean() - raw_interaction = np.random.normal(0, true_sigma_cat, (n_stores, n_days, n_categories)) + raw_interaction = np.random.normal( + 0, true_sigma_cat, (n_stores, n_days, n_categories) + ) raw_interaction -= raw_interaction.mean(axis=-1, keepdims=True) raw_interaction -= raw_interaction.mean(axis=-2, keepdims=True) @@ -115,8 +118,12 @@ def make_zerosumnormal_model(): for d in range(n_days): for c in range(n_categories): n = np.random.poisson(5) + 1 - mu = (true_grand_mean + true_store_effect[s] - + true_day_effect[d] + raw_interaction[s, d, c]) + mu = ( + true_grand_mean + + true_store_effect[s] + + true_day_effect[d] + + raw_interaction[s, d, c] + ) y_vals = np.random.normal(mu, true_sigma_y, n) for y in y_vals: records.append((s, d, c, y)) @@ -133,17 +140,28 @@ def make_zerosumnormal_model(): sigma_store = pm.HalfNormal("sigma_store", sigma=2) sigma_day = pm.HalfNormal("sigma_day", sigma=2) sigma_cat = pm.HalfNormal("sigma_cat", sigma=2) - store_effect = pm.ZeroSumNormal("store_effect", sigma=sigma_store, shape=n_stores) + store_effect = pm.ZeroSumNormal( + "store_effect", sigma=sigma_store, shape=n_stores + ) day_effect = pm.ZeroSumNormal("day_effect", sigma=sigma_day, shape=n_days) interaction = pm.ZeroSumNormal( - "interaction", sigma=sigma_cat, - 
shape=(n_stores, n_days, n_categories), n_zerosum_axes=2, + "interaction", + sigma=sigma_cat, + shape=(n_stores, n_days, n_categories), + n_zerosum_axes=2, + ) + mu_y = ( + grand_mean + + store_effect[store_idx] + + day_effect[day_idx] + + interaction[store_idx, day_idx, cat_idx] ) - mu_y = (grand_mean + store_effect[store_idx] + day_effect[day_idx] - + interaction[store_idx, day_idx, cat_idx]) sigma_y = pm.HalfNormal("sigma_y", sigma=5) pm.Normal("y", mu=mu_y, sigma=sigma_y, observed=y_obs) - return model, f"ZeroSumNormal ANOVA: {n_stores}×{n_days}×{n_categories}, {N} obs, 124 params" + return ( + model, + f"ZeroSumNormal ANOVA: {n_stores}×{n_days}×{n_categories}, {N} obs, 124 params", + ) MODELS = [ @@ -181,7 +199,7 @@ def main(): ) if not result.success: - print(f" FAILED — skipping benchmark") + print(" FAILED — skipping benchmark") results.append((name, None, None)) continue @@ -209,7 +227,7 @@ def main(): else: nt = nutpie_r["elapsed_s"] rt = rust_r["elapsed_s"] - print(f"{name:<20} {nt:<12.2f} {rt:<12.2f} {nt/rt:<10.2f}x") + print(f"{name:<20} {nt:<12.2f} {rt:<12.2f} {nt / rt:<10.2f}x") if __name__ == "__main__": diff --git a/examples/stan_01_normal.py b/examples/stan_01_normal.py index ee4d8c2..0b11e5e 100644 --- a/examples/stan_01_normal.py +++ b/examples/stan_01_normal.py @@ -36,7 +36,7 @@ verbose=True, ) - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Success: {result.success}") print(f"Tool calls: {result.n_tool_calls}") print(f"Build attempts: {result.n_attempts}") diff --git a/examples/stan_02_hierarchical.py b/examples/stan_02_hierarchical.py index 09c108f..c41c0fb 100644 --- a/examples/stan_02_hierarchical.py +++ b/examples/stan_02_hierarchical.py @@ -3,8 +3,6 @@ Classic Eight Schools example — non-centered parameterization. 
""" -import numpy as np - from pymc_rust_compiler import compile_stan_model STAN_CODE = """ @@ -45,7 +43,7 @@ verbose=True, ) - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Success: {result.success}") print(f"Tool calls: {result.n_tool_calls}") print(f"Build attempts: {result.n_attempts}") diff --git a/examples/stan_pymc_01_normal.py b/examples/stan_pymc_01_normal.py index 3839b43..d94b007 100644 --- a/examples/stan_pymc_01_normal.py +++ b/examples/stan_pymc_01_normal.py @@ -37,14 +37,14 @@ verbose=True, ) - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Success: {result.success}") print(f"Tool calls: {result.n_tool_calls}") print(f"Validation attempts: {result.n_attempts}") print(f"Tokens: {result.token_usage}") if result.success: - print(f"\n--- Generated PyMC Code ---") + print("\n--- Generated PyMC Code ---") print(result.pymc_code) # Test that the model actually works @@ -54,5 +54,5 @@ else: print(f"\nValidation errors: {result.validation_errors}") if result.pymc_code: - print(f"\n--- Last Generated Code ---") + print("\n--- Last Generated Code ---") print(result.pymc_code) diff --git a/examples/stan_pymc_02_hierarchical.py b/examples/stan_pymc_02_hierarchical.py index ef6c27a..5349ec1 100644 --- a/examples/stan_pymc_02_hierarchical.py +++ b/examples/stan_pymc_02_hierarchical.py @@ -56,14 +56,14 @@ verbose=True, ) - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Success: {result.success}") print(f"Tool calls: {result.n_tool_calls}") print(f"Validation attempts: {result.n_attempts}") print(f"Tokens: {result.token_usage}") if result.success: - print(f"\n--- Generated PyMC Code ---") + print("\n--- Generated PyMC Code ---") print(result.pymc_code) model = result.get_model(data) diff --git a/notebooks/overview.py b/notebooks/overview.py index 60b3ab1..87c2587 100644 --- a/notebooks/overview.py +++ b/notebooks/overview.py @@ -87,12 +87,16 @@ def _(): beta = pm.Normal("beta", mu=0, sigma=10) sigma = pm.HalfNormal("sigma", sigma=5) mu = alpha + 
beta * x_data - y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_data) + pm.Normal("y", mu=mu, sigma=sigma, observed=y_data) - print(f"Model: {len(linreg_model.free_RVs)} free parameters, " - f"{len(linreg_model.observed_RVs)} observed") + print( + f"Model: {len(linreg_model.free_RVs)} free parameters, " + f"{len(linreg_model.observed_RVs)} observed" + ) print(f"Parameters: {[rv.name for rv in linreg_model.free_RVs]}") - print(f"Transforms: {[type(linreg_model.rvs_to_transforms.get(rv)).__name__ for rv in linreg_model.free_RVs]}") + print( + f"Transforms: {[type(linreg_model.rvs_to_transforms.get(rv)).__name__ for rv in linreg_model.free_RVs]}" + ) return (linreg_model,) @@ -129,15 +133,19 @@ def _(linreg_model): transform_str = f" [{p.transform}]" if p.transform else "" print(f" {p.name}: shape={p.shape}, size={p.size}{transform_str}") - print(f"\nObserved data:") + print("\nObserved data:") for name2, info2 in ctx.observed_data.items(): - print(f" {name2}: n={info2['n']}, range=[{info2['min']:.3f}, {info2['max']:.3f}]") + print( + f" {name2}: n={info2['n']}, range=[{info2['min']:.3f}, {info2['max']:.3f}]" + ) - print(f"\nCovariates:") + print("\nCovariates:") for name3, info3 in ctx.covariate_data.items(): is_idx = info3.get("is_index_array", False) label = f" [INDEX ARRAY, {info3['n_groups']} groups]" if is_idx else "" - print(f" {name3}: n={info3['n']}, range=[{info3['min']:.3f}, {info3['max']:.3f}]{label}") + print( + f" {name3}: n={info3['n']}, range=[{info3['min']:.3f}, {info3['max']:.3f}]{label}" + ) print(f"\nValidation points: 1 initial + {len(ctx.extra_points)} extra") print(f" Initial logp = {ctx.initial_point.logp:.6f}") @@ -227,7 +235,9 @@ def _(ctx, mo): rows.append("| Point | PyMC logp | Gradient[0] | Gradient[1] | Gradient[2] |") rows.append("|-------|-----------|-------------|-------------|-------------|") - pts = [("initial", ctx.initial_point)] + [(f"extra_{i}", p) for i, p in enumerate(ctx.extra_points)] + pts = [("initial", 
ctx.initial_point)] + [ + (f"extra_{i}", p) for i, p in enumerate(ctx.extra_points) + ] for name, vp in pts: g = vp.dlogp rows.append( @@ -338,8 +348,10 @@ def _(): sigma_y = _pm.HalfNormal("sigma_y", sigma=5) _pm.Normal("y_obs", mu=a[group_idx] + b * x_h, sigma=sigma_y, observed=y_h) - print(f"Hierarchical model: {len(hierarchical_model.free_RVs)} free params, " - f"{_N} observations, {n_groups} groups") + print( + f"Hierarchical model: {len(hierarchical_model.free_RVs)} free params, " + f"{_N} observations, {n_groups} groups" + ) return (hierarchical_model,) diff --git a/pymc_rust_compiler/__init__.py b/pymc_rust_compiler/__init__.py index d73e2f0..f7cbdc7 100644 --- a/pymc_rust_compiler/__init__.py +++ b/pymc_rust_compiler/__init__.py @@ -29,8 +29,16 @@ # PyMC/Stan imports are lazy — they pull in heavy deps (pymc, bridgestan) if TYPE_CHECKING: - from pymc_rust_compiler.exporter import ModelContext, RustModelExporter, export_model - from pymc_rust_compiler.compiler import compile_model, optimize_model, OptimizationEvent + from pymc_rust_compiler.exporter import ( + ModelContext, + RustModelExporter, + export_model, + ) + from pymc_rust_compiler.compiler import ( + compile_model, + optimize_model, + OptimizationEvent, + ) from pymc_rust_compiler.analysis import ( plot_optimization_progress, plot_waterfall, @@ -42,7 +50,10 @@ StanModelExporter, export_stan_model, ) - from pymc_rust_compiler.stan_compiler import compile_stan_model, StanCompilationResult + from pymc_rust_compiler.stan_compiler import ( + compile_stan_model, + StanCompilationResult, + ) from pymc_rust_compiler.stan_to_pymc import transpile_stan_to_pymc, StanToPyMCResult @@ -55,16 +66,28 @@ def __getattr__(name: str): "compile_model": ("pymc_rust_compiler.compiler", "compile_model"), "optimize_model": ("pymc_rust_compiler.compiler", "optimize_model"), "OptimizationEvent": ("pymc_rust_compiler.compiler", "OptimizationEvent"), - "plot_optimization_progress": ("pymc_rust_compiler.analysis", 
"plot_optimization_progress"), + "plot_optimization_progress": ( + "pymc_rust_compiler.analysis", + "plot_optimization_progress", + ), "plot_waterfall": ("pymc_rust_compiler.analysis", "plot_waterfall"), "plot_timeline": ("pymc_rust_compiler.analysis", "plot_timeline"), "print_summary": ("pymc_rust_compiler.analysis", "print_summary"), "StanModelContext": ("pymc_rust_compiler.stan_exporter", "StanModelContext"), "StanModelExporter": ("pymc_rust_compiler.stan_exporter", "StanModelExporter"), "export_stan_model": ("pymc_rust_compiler.stan_exporter", "export_stan_model"), - "compile_stan_model": ("pymc_rust_compiler.stan_compiler", "compile_stan_model"), - "StanCompilationResult": ("pymc_rust_compiler.stan_compiler", "StanCompilationResult"), - "transpile_stan_to_pymc": ("pymc_rust_compiler.stan_to_pymc", "transpile_stan_to_pymc"), + "compile_stan_model": ( + "pymc_rust_compiler.stan_compiler", + "compile_stan_model", + ), + "StanCompilationResult": ( + "pymc_rust_compiler.stan_compiler", + "StanCompilationResult", + ), + "transpile_stan_to_pymc": ( + "pymc_rust_compiler.stan_to_pymc", + "transpile_stan_to_pymc", + ), "StanToPyMCResult": ("pymc_rust_compiler.stan_to_pymc", "StanToPyMCResult"), } if name in _lazy_imports: @@ -114,4 +137,5 @@ def __getattr__(name: str): def to_nutpie(compile_result, model): """Convert a CompilationResult to a nutpie-compatible model. 
Lazy import.""" from pymc_rust_compiler.nutpie_bridge import to_nutpie as _to_nutpie + return _to_nutpie(compile_result, model) diff --git a/pymc_rust_compiler/analysis.py b/pymc_rust_compiler/analysis.py index 845e772..f079a2e 100644 --- a/pymc_rust_compiler/analysis.py +++ b/pymc_rust_compiler/analysis.py @@ -44,14 +44,16 @@ def _load_from_tsv(path: str | Path) -> list[_BenchmarkRecord]: for row in reader: if row["event_type"] != "benchmark" or not row["us_per_eval"]: continue - records.append(_BenchmarkRecord( - turn=int(row["turn"]), - timestamp_s=float(row["timestamp_s"]), - us_per_eval=float(row["us_per_eval"]), - status=row["status"], - description=row["description"], - code_hash=row["code_hash"], - )) + records.append( + _BenchmarkRecord( + turn=int(row["turn"]), + timestamp_s=float(row["timestamp_s"]), + us_per_eval=float(row["us_per_eval"]), + status=row["status"], + description=row["description"], + code_hash=row["code_hash"], + ) + ) return records @@ -61,14 +63,16 @@ def _load_from_result(result: CompilationResult) -> list[_BenchmarkRecord]: for ev in result.optimization_log: if ev.event_type != "benchmark" or ev.us_per_eval is None: continue - records.append(_BenchmarkRecord( - turn=ev.turn, - timestamp_s=ev.timestamp, - us_per_eval=ev.us_per_eval, - status=ev.status, - description=ev.description, - code_hash=ev.code_hash, - )) + records.append( + _BenchmarkRecord( + turn=ev.turn, + timestamp_s=ev.timestamp, + us_per_eval=ev.us_per_eval, + status=ev.status, + description=ev.description, + code_hash=ev.code_hash, + ) + ) return records @@ -126,17 +130,27 @@ def plot_optimization_progress( # Plot discarded as faint gray if discard_idx: ax.scatter( - discard_idx, discard_us, - c="#cccccc", s=40, zorder=2, label="Discarded", - edgecolors="#aaaaaa", linewidths=0.5, + discard_idx, + discard_us, + c="#cccccc", + s=40, + zorder=2, + label="Discarded", + edgecolors="#aaaaaa", + linewidths=0.5, ) # Plot kept as green if keep_idx: ax.scatter( - keep_idx, 
keep_us, - c="#2ecc71", s=80, zorder=3, label="Kept", - edgecolors="#27ae60", linewidths=1, + keep_idx, + keep_us, + c="#2ecc71", + s=80, + zorder=3, + label="Kept", + edgecolors="#27ae60", + linewidths=1, ) # Running minimum (frontier) as step line @@ -147,8 +161,12 @@ def plot_optimization_progress( running_min.append(current_min) ax.step( - range(len(records)), running_min, - where="post", color="#2c3e50", linewidth=2, zorder=4, + range(len(records)), + running_min, + where="post", + color="#2c3e50", + linewidth=2, + zorder=4, label="Best so far", ) @@ -182,10 +200,12 @@ def plot_optimization_progress( f"Baseline: {baseline:.2f} us/eval\n" f"Best: {best:.2f} us/eval\n" f"Improvement: {improvement_pct:.1f}%\n" - f"Kept: {n_keep}/{n_total} ({100*n_keep/n_total:.0f}%)" + f"Kept: {n_keep}/{n_total} ({100 * n_keep / n_total:.0f}%)" ) ax.text( - 0.02, 0.02, stats_text, + 0.02, + 0.02, + stats_text, transform=ax.transAxes, fontsize=9, verticalalignment="bottom", @@ -233,16 +253,21 @@ def plot_waterfall( colors = ["#2ecc71" if d > 0 else "#e74c3c" for d in deltas] - bars = ax.bar(range(len(deltas)), deltas, color=colors, edgecolor="#2c3e50", linewidth=0.5) + bars = ax.bar( + range(len(deltas)), deltas, color=colors, edgecolor="#2c3e50", linewidth=0.5 + ) # Value labels on bars for bar, delta in zip(bars, deltas): y = bar.get_height() ax.text( - bar.get_x() + bar.get_width() / 2, y, + bar.get_x() + bar.get_width() / 2, + y, f"{delta:+.2f}", - ha="center", va="bottom" if y >= 0 else "top", - fontsize=8, fontweight="bold", + ha="center", + va="bottom" if y >= 0 else "top", + fontsize=8, + fontweight="bold", ) ax.set_xticks(range(len(labels))) @@ -254,11 +279,13 @@ def plot_waterfall( total = sum(deltas) ax.text( - 0.98, 0.98, + 0.98, + 0.98, f"Total improvement: {total:.2f} us/eval", transform=ax.transAxes, fontsize=10, - ha="right", va="top", + ha="right", + va="top", fontfamily="monospace", bbox=dict(boxstyle="round,pad=0.4", facecolor="wheat", alpha=0.8), ) @@ -300,7 
+327,9 @@ def plot_timeline( "timestamp_s": str(ev.timestamp), "event_type": ev.event_type, "status": ev.status, - "us_per_eval": str(ev.us_per_eval) if ev.us_per_eval is not None else "", + "us_per_eval": str(ev.us_per_eval) + if ev.us_per_eval is not None + else "", } for ev in source.optimization_log ] @@ -373,19 +402,19 @@ def print_summary(source: str | Path | CompilationResult) -> str: improvement_pct = (1 - best / baseline) * 100 if baseline > 0 else 0 lines = [ - f"Optimization Summary", + "Optimization Summary", f"{'=' * 40}", f"Total experiments: {len(records)}", f" Kept: {len(kept)}", f" Discarded: {len(discarded)}", f" Keep rate: {100 * len(kept) / len(records):.0f}%", - f"", + "", f"Baseline: {baseline:.3f} us/eval", f"Best: {best:.3f} us/eval", f"Improvement: {improvement_pct:.1f}%", f"Speedup: {baseline / best:.2f}x" if best > 0 else "", - f"", - f"Kept experiments (chronological):", + "", + "Kept experiments (chronological):", ] prev_us = baseline diff --git a/pymc_rust_compiler/benchmark.py b/pymc_rust_compiler/benchmark.py index 2472d11..048d626 100644 --- a/pymc_rust_compiler/benchmark.py +++ b/pymc_rust_compiler/benchmark.py @@ -17,7 +17,9 @@ _BENCH_RUNNER_DIR = Path(__file__).resolve().parent.parent / "bench_runner" -def benchmark_nutpie(model: pm.Model, draws: int = 2000, tune: int = 1000, chains: int = 4) -> dict: +def benchmark_nutpie( + model: pm.Model, draws: int = 2000, tune: int = 1000, chains: int = 4 +) -> dict: """Benchmark PyMC sampling with nutpie backend.""" print(f" nutpie: {chains} chains x {draws} draws...") start = time.time() @@ -41,7 +43,9 @@ def benchmark_nutpie(model: pm.Model, draws: int = 2000, tune: int = 1000, chain } -def benchmark_rust(build_dir: str | Path, draws: int = 2000, tune: int = 1000, chains: int = 4) -> dict: +def benchmark_rust( + build_dir: str | Path, draws: int = 2000, tune: int = 1000, chains: int = 4 +) -> dict: """Benchmark the AI-compiled Rust sampler.""" build_dir = Path(build_dir) binary = 
build_dir / "target" / "release" / "sample" @@ -85,7 +89,10 @@ def benchmark_rust(build_dir: str | Path, draws: int = 2000, tune: int = 1000, c # logp+dlogp evaluation benchmark (no sampling overhead) # --------------------------------------------------------------------------- -def _make_test_point(model: pm.Model, rng: np.random.Generator | None = None) -> np.ndarray: + +def _make_test_point( + model: pm.Model, rng: np.random.Generator | None = None +) -> np.ndarray: """Generate a random unconstrained test point in model variable order. Uses a random point (instead of the initial point) to avoid @@ -106,7 +113,9 @@ def _prepare_frozen_inputs(model, x0_model_order=None): Returns (jit_fn, x0, frozen_rv, model_fn). """ frozen_model = freeze_dims_and_data(model) - logp_dlogp_wrapper = frozen_model.logp_dlogp_function(ravel_inputs=True, mode="NUMBA") + logp_dlogp_wrapper = frozen_model.logp_dlogp_function( + ravel_inputs=True, mode="NUMBA" + ) logp_dlogp_fn = logp_dlogp_wrapper._pytensor_function logp_dlogp_fn.trust_input = True @@ -121,7 +130,9 @@ def _prepare_frozen_inputs(model, x0_model_order=None): # Reorder x0 from model order → frozen order for this function model_ip = {v.name: ip[v.name] for v in model_fn._grad_vars} model_rv = DictToArrayBijection.map(model_ip) - x0_dict = DictToArrayBijection.rmap(RaveledVars(x0_model_order, model_rv.point_map_info)) + x0_dict = DictToArrayBijection.rmap( + RaveledVars(x0_model_order, model_rv.point_map_info) + ) frozen_ip = {name: x0_dict[name] for name in frozen_vars} frozen_rv = DictToArrayBijection.map(frozen_ip) x0 = frozen_rv.data @@ -144,7 +155,9 @@ def _reorder_dlogp(dlogp_val, frozen_rv, model_fn): def benchmark_logp_pytensor( - model: pm.Model, n_evals: int = 10_000, x0_model_order: np.ndarray | None = None, + model: pm.Model, + n_evals: int = 10_000, + x0_model_order: np.ndarray | None = None, ) -> dict: """Benchmark PyTensor's compiled logp+dlogp function (what nutpie calls). 
@@ -181,6 +194,7 @@ def benchmark_logp_pytensor( # Numba cfunc benchmark: Rust calling Numba via C function pointer (like nutpie) # --------------------------------------------------------------------------- + def _build_bench_runner() -> ctypes.CDLL: """Build and load the bench_runner shared library.""" so_path = _BENCH_RUNNER_DIR / "target" / "release" / "libbench_runner.so" @@ -195,13 +209,13 @@ def _build_bench_runner() -> ctypes.CDLL: lib = ctypes.CDLL(str(so_path)) lib.bench_logp_cfunc.restype = ctypes.c_double lib.bench_logp_cfunc.argtypes = [ - ctypes.c_size_t, # func_ptr (usize) - ctypes.c_uint64, # dim - ctypes.c_void_p, # x_ptr - ctypes.c_uint64, # n_warmup - ctypes.c_uint64, # n_iters - ctypes.c_void_p, # logp_out - ctypes.c_void_p, # grad_out + ctypes.c_size_t, # func_ptr (usize) + ctypes.c_uint64, # dim + ctypes.c_void_p, # x_ptr + ctypes.c_uint64, # n_warmup + ctypes.c_uint64, # n_iters + ctypes.c_void_p, # logp_out + ctypes.c_void_p, # grad_out ] return lib @@ -213,10 +227,10 @@ def _make_numba_cfunc(jit_fn, n_dim: int): can call directly with zero Python overhead. """ c_sig = numba.types.int64( - numba.types.uint64, # dim - numba.types.CPointer(numba.types.float64), # x (input) - numba.types.CPointer(numba.types.float64), # grad (output) - numba.types.CPointer(numba.types.float64), # logp (output) + numba.types.uint64, # dim + numba.types.CPointer(numba.types.float64), # x (input) + numba.types.CPointer(numba.types.float64), # grad (output) + numba.types.CPointer(numba.types.float64), # logp (output) ) @numba.cfunc(c_sig) @@ -233,7 +247,9 @@ def logp_cfunc(dim, x_ptr, grad_ptr, logp_ptr): def benchmark_logp_numba_cfunc( - model: pm.Model, n_evals: int = 10_000, x0_model_order: np.ndarray | None = None, + model: pm.Model, + n_evals: int = 10_000, + x0_model_order: np.ndarray | None = None, ) -> dict: """Benchmark Numba logp+dlogp called from Rust via C function pointer. 
@@ -261,7 +277,7 @@ def benchmark_logp_numba_cfunc( cfunc.address, n_dim, x_arr.ctypes.data, - 200, # warmup + 200, # warmup n_evals, logp_arr.ctypes.data, grad_arr.ctypes.data, @@ -282,7 +298,9 @@ def benchmark_logp_numba_cfunc( def benchmark_logp_rust( - build_dir: str | Path, model: pm.Model, n_evals: int = 10_000, + build_dir: str | Path, + model: pm.Model, + n_evals: int = 10_000, x0_model_order: np.ndarray | None = None, ) -> dict: """Benchmark the AI-compiled Rust logp+dlogp function.""" @@ -339,7 +357,9 @@ def benchmark_logp_rust( } -def print_logp_comparison(pytensor_result: dict, rust_result: dict, model_name: str = ""): +def print_logp_comparison( + pytensor_result: dict, rust_result: dict, model_name: str = "" +): """Print logp+dlogp evaluation benchmark comparison.""" header = f"LOGP+DLOGP BENCHMARK{f': {model_name}' if model_name else ''}" print("\n" + "=" * 65) @@ -354,7 +374,9 @@ def print_logp_comparison(pytensor_result: dict, rust_result: dict, model_name: print(f"{pt['backend']:<20} {'ERROR':<12}") else: evals_per_sec_pt = 1e6 / pt["us_per_eval"] - print(f"{pt['backend']:<20} {pt['us_per_eval']:<12.2f} {evals_per_sec_pt:<14,.0f} {'1.00x':<10}") + print( + f"{pt['backend']:<20} {pt['us_per_eval']:<12.2f} {evals_per_sec_pt:<14,.0f} {'1.00x':<10}" + ) rs = rust_result if "error" in rs: @@ -362,23 +384,33 @@ def print_logp_comparison(pytensor_result: dict, rust_result: dict, model_name: else: evals_per_sec_rs = 1e6 / rs["us_per_eval"] speedup = pt["us_per_eval"] / rs["us_per_eval"] if "error" not in pt else 0 - print(f"{'rust-ai':<20} {rs['us_per_eval']:<12.2f} {evals_per_sec_rs:<14,.0f} {speedup:<10.2f}x") + print( + f"{'rust-ai':<20} {rs['us_per_eval']:<12.2f} {evals_per_sec_rs:<14,.0f} {speedup:<10.2f}x" + ) # Check logp and dlogp agreement if "error" not in pt: logp_diff = abs(pt["logp"] - rs["logp"]) logp_rel_err = logp_diff / max(abs(pt["logp"]), 1e-10) logp_ok = logp_rel_err < 1e-4 - logp_status = "MATCH" if logp_ok else f"MISMATCH 
(rel_err={logp_rel_err:.2e})" - print(f"\n logp check: pytensor={pt['logp']:.8f} rust={rs['logp']:.8f} [{logp_status}]") + logp_status = ( + "MATCH" if logp_ok else f"MISMATCH (rel_err={logp_rel_err:.2e})" + ) + print( + f"\n logp check: pytensor={pt['logp']:.8f} rust={rs['logp']:.8f} [{logp_status}]" + ) pt_dlogp = pt["dlogp"] rs_dlogp = rs["dlogp"] dlogp_abs_err = np.max(np.abs(pt_dlogp - rs_dlogp)) dlogp_rel_err = dlogp_abs_err / max(np.max(np.abs(pt_dlogp)), 1e-10) dlogp_ok = dlogp_rel_err < 1e-4 - dlogp_status = "MATCH" if dlogp_ok else f"MISMATCH (rel_err={dlogp_rel_err:.2e})" - print(f" dlogp check: pytensor={pt_dlogp} rust={rs_dlogp} [{dlogp_status}]") + dlogp_status = ( + "MATCH" if dlogp_ok else f"MISMATCH (rel_err={dlogp_rel_err:.2e})" + ) + print( + f" dlogp check: pytensor={pt_dlogp} rust={rs_dlogp} [{dlogp_status}]" + ) assert logp_ok, ( f"logp mismatch: pytensor={pt['logp']:.10f} rust={rs['logp']:.10f} " @@ -402,12 +434,16 @@ def print_comparison(nutpie_result: dict, rust_result: dict): print("-" * 54) nt = nutpie_result["elapsed_s"] - print(f"{'nutpie':<20} {nt:<12.2f} {nutpie_result['throughput']:<12.0f} {'1.00x':<10}") + print( + f"{'nutpie':<20} {nt:<12.2f} {nutpie_result['throughput']:<12.0f} {'1.00x':<10}" + ) if "error" not in rust_result: rt = rust_result["elapsed_s"] speedup = nt / rt - print(f"{'rust-ai':<20} {rt:<12.2f} {rust_result['throughput']:<12.0f} {speedup:<10.2f}x") + print( + f"{'rust-ai':<20} {rt:<12.2f} {rust_result['throughput']:<12.0f} {speedup:<10.2f}x" + ) else: print(f"{'rust-ai':<20} {'FAILED':<12}") diff --git a/pymc_rust_compiler/compiler.py b/pymc_rust_compiler/compiler.py index ee0c2b2..9807c92 100644 --- a/pymc_rust_compiler/compiler.py +++ b/pymc_rust_compiler/compiler.py @@ -16,9 +16,7 @@ import csv import functools import hashlib -import json import os -import re import subprocess import tempfile import time @@ -335,7 +333,9 @@ def _cuda_available() -> bool: """Check if CUDA is available at runtime.""" try: result 
= subprocess.run( - ["nvidia-smi"], capture_output=True, timeout=5, + ["nvidia-smi"], + capture_output=True, + timeout=5, ) return result.returncode == 0 except (FileNotFoundError, subprocess.TimeoutExpired): @@ -346,11 +346,15 @@ def _cuda_available() -> bool: def _accelerate_available() -> bool: """Check if Apple Silicon with Accelerate framework is available.""" import platform + return platform.system() == "Darwin" and platform.machine() == "arm64" def _detect_skills( - model: pm.Model, ctx, use_cuda: bool | None = None, use_accelerate: bool | None = None, + model: pm.Model, + ctx, + use_cuda: bool | None = None, + use_accelerate: bool | None = None, use_enzyme: bool | None = None, ) -> list[str]: """Detect which skills are needed based on model structure. @@ -372,14 +376,15 @@ def _detect_skills( break # Also check the logp graph text for GP indicators if not has_gp and any( - kw in ctx.logp_graph.lower() - for kw in ["cholesky", "mvnormal", "gp_"] + kw in ctx.logp_graph.lower() for kw in ["cholesky", "mvnormal", "gp_"] ): has_gp = True if has_gp: cuda = use_cuda if use_cuda is not None else _cuda_available() - accelerate = use_accelerate if use_accelerate is not None else _accelerate_available() + accelerate = ( + use_accelerate if use_accelerate is not None else _accelerate_available() + ) if cuda: skills.append("gp_cuda") elif accelerate: @@ -417,7 +422,7 @@ def _build_system_prompt(skills: list[str]) -> str: for skill_name in skills: content = _load_skill(skill_name) if content: - prompt += f"\n\n{'='*60}\n{content}" + prompt += f"\n\n{'=' * 60}\n{content}" return prompt @@ -465,9 +470,13 @@ class CompilationResult: timings: dict[str, float] n_tool_calls: int = 0 conversation_turns: int = 0 - token_usage: dict[str, int] = field(default_factory=lambda: { - "input_tokens": 0, "output_tokens": 0, "total_tokens": 0, - }) + token_usage: dict[str, int] = field( + default_factory=lambda: { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ) 
us_per_eval: float | None = None # benchmark result if available optimization_log: list[OptimizationEvent] = field(default_factory=list) @@ -493,20 +502,29 @@ def write_results_tsv(self, path: str | Path | None = None) -> Path: with open(path, "w", newline="") as f: writer = csv.writer(f, delimiter="\t") - writer.writerow([ - "turn", "timestamp_s", "event_type", "status", - "us_per_eval", "code_hash", "description", - ]) + writer.writerow( + [ + "turn", + "timestamp_s", + "event_type", + "status", + "us_per_eval", + "code_hash", + "description", + ] + ) for ev in self.optimization_log: - writer.writerow([ - ev.turn, - f"{ev.timestamp:.2f}", - ev.event_type, - ev.status, - f"{ev.us_per_eval:.3f}" if ev.us_per_eval is not None else "", - ev.code_hash, - ev.description, - ]) + writer.writerow( + [ + ev.turn, + f"{ev.timestamp:.2f}", + ev.event_type, + ev.status, + f"{ev.us_per_eval:.3f}" if ev.us_per_eval is not None else "", + ev.code_hash, + ev.description, + ] + ) return path @@ -588,7 +606,13 @@ def compile_model( print(f" {ctx.n_params} parameters, {len(prompt)} char prompt") # Detect model-specific skills and build augmented system prompt - skills = _detect_skills(model, ctx, use_cuda=use_cuda, use_accelerate=use_accelerate, use_enzyme=use_enzyme) + skills = _detect_skills( + model, + ctx, + use_cuda=use_cuda, + use_accelerate=use_accelerate, + use_enzyme=use_enzyme, + ) system_prompt = _build_system_prompt(skills) if verbose and skills: print(f" Skills loaded: {', '.join(skills)}") @@ -617,15 +641,17 @@ def compile_model( state = _AgentState( build_path=build_path, ctx=ctx, - messages=[{ - "role": "user", - "content": ( - "Generate a Rust logp+gradient implementation for this PyMC model.\n\n" - "Use your tools to write the code, build it, and validate it. 
" - "Iterate until validation passes.\n\n" - f"{prompt}" - ), - }], + messages=[ + { + "role": "user", + "content": ( + "Generate a Rust logp+gradient implementation for this PyMC model.\n\n" + "Use your tools to write the code, build it, and validate it. " + "Iterate until validation passes.\n\n" + f"{prompt}" + ), + } + ], timings=timings, ) @@ -653,7 +679,9 @@ def compile_model( total_input_tokens += response.usage.input_tokens total_output_tokens += response.usage.output_tokens if verbose: - print(f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens") + print( + f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens" + ) # Check stop reason if response.stop_reason == "end_turn": @@ -679,11 +707,13 @@ def compile_model( elif block.type == "tool_use": state.tool_calls += 1 result = _execute_tool(block.name, block.input, state, verbose) - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) # Check if validation passed if state.validated: @@ -748,9 +778,7 @@ def _execute_tool( return f"Unknown tool: {name}" -def _tool_write_rust_code( - input_data: dict, state: _AgentState, verbose: bool -) -> str: +def _tool_write_rust_code(input_data: dict, state: _AgentState, verbose: bool) -> str: """Write the generated.rs file.""" code = input_data.get("code", "") if not code: @@ -777,14 +805,16 @@ def _tool_write_rust_code( if verbose: print(f" [write_rust_code] Wrote {len(code)} chars to generated.rs") - state.optimization_log.append(OptimizationEvent( - turn=state.tool_calls, - timestamp=time.time() - state.t0_loop, - event_type="write_code", - status="OK", - description=f"Wrote {len(code)} chars", - code_hash=_code_hash(state.build_path), - )) + state.optimization_log.append( + OptimizationEvent( + turn=state.tool_calls, + 
timestamp=time.time() - state.t0_loop, + event_type="write_code", + status="OK", + description=f"Wrote {len(code)} chars", + code_hash=_code_hash(state.build_path), + ) + ) return f"Written {len(code)} chars to src/generated.rs" @@ -814,26 +844,30 @@ def _tool_cargo_build(state: _AgentState, verbose: bool) -> str: if result.returncode == 0: if verbose: print(f" [cargo_build] OK ({elapsed:.1f}s)") - state.optimization_log.append(OptimizationEvent( - turn=state.tool_calls, - timestamp=time.time() - state.t0_loop, - event_type="build", - status="PASS", - description=f"Build OK ({elapsed:.1f}s)", - code_hash=_code_hash(state.build_path), - )) + state.optimization_log.append( + OptimizationEvent( + turn=state.tool_calls, + timestamp=time.time() - state.t0_loop, + event_type="build", + status="PASS", + description=f"Build OK ({elapsed:.1f}s)", + code_hash=_code_hash(state.build_path), + ) + ) return f"Build successful ({elapsed:.1f}s)" else: if verbose: print(f" [cargo_build] FAILED ({elapsed:.1f}s)") - state.optimization_log.append(OptimizationEvent( - turn=state.tool_calls, - timestamp=time.time() - state.t0_loop, - event_type="build", - status="CRASH", - description=f"Build FAILED ({elapsed:.1f}s)", - code_hash=_code_hash(state.build_path), - )) + state.optimization_log.append( + OptimizationEvent( + turn=state.tool_calls, + timestamp=time.time() - state.t0_loop, + event_type="build", + status="CRASH", + description=f"Build FAILED ({elapsed:.1f}s)", + code_hash=_code_hash(state.build_path), + ) + ) # Return compiler errors (truncated to avoid token explosion) errors = result.stderr if len(errors) > 4000: @@ -878,10 +912,12 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: if result.returncode != 0: return f"Error: validator crashed: {result.stderr[:500]}" - output_lines = [l for l in result.stdout.strip().split("\n") if l] + output_lines = [line for line in result.stdout.strip().split("\n") if line] if len(output_lines) != len(all_points): - return 
f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" + return ( + f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" + ) # Parse results parsed = [] @@ -919,7 +955,9 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: # logp comparison if has_constant_offset: - adjusted_err = abs((rust_logp - mean_offset) - vp.logp) / max(abs(vp.logp), 1.0) + adjusted_err = abs((rust_logp - mean_offset) - vp.logp) / max( + abs(vp.logp), 1.0 + ) else: adjusted_err = abs(rust_logp - vp.logp) / max(abs(vp.logp), 1.0) @@ -961,33 +999,38 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: if not errors: state.validated = True if verbose: - print(f" [validate_logp] PASSED!") - state.optimization_log.append(OptimizationEvent( - turn=state.tool_calls, - timestamp=time.time() - state.t0_loop, - event_type="validation", - status="PASS", - description="Validation passed", - code_hash=_code_hash(state.build_path), - )) + print(" [validate_logp] PASSED!") + state.optimization_log.append( + OptimizationEvent( + turn=state.tool_calls, + timestamp=time.time() - state.t0_loop, + event_type="validation", + status="PASS", + description="Validation passed", + code_hash=_code_hash(state.build_path), + ) + ) return f"VALIDATION PASSED!\n\n{report}" else: if verbose: print(f" [validate_logp] FAILED ({len(errors)} errors)") - state.optimization_log.append(OptimizationEvent( - turn=state.tool_calls, - timestamp=time.time() - state.t0_loop, - event_type="validation", - status="FAIL", - description=f"Validation failed ({len(errors)} errors)", - code_hash=_code_hash(state.build_path), - )) - return f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" + "\n".join(errors) + state.optimization_log.append( + OptimizationEvent( + turn=state.tool_calls, + timestamp=time.time() - state.t0_loop, + event_type="validation", + status="FAIL", + description=f"Validation failed ({len(errors)} errors)", + 
code_hash=_code_hash(state.build_path), + ) + ) + return ( + f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\nErrors:\n" + + "\n".join(errors) + ) -def _tool_read_file( - input_data: dict, state: _AgentState, verbose: bool -) -> str: +def _tool_read_file(input_data: dict, state: _AgentState, verbose: bool) -> str: """Read a file from the build directory.""" rel_path = input_data.get("path", "") if not rel_path: @@ -1116,13 +1159,13 @@ def _setup_enzyme_toolchain(build_path: Path) -> None: toolchain_file = build_path / "rust-toolchain.toml" if toolchain_file.exists(): return - toolchain_file.write_text( - '[toolchain]\nchannel = "nightly"\n' - ) + toolchain_file.write_text('[toolchain]\nchannel = "nightly"\n') def _setup_rust_project( - build_path: Path, ctx, extra_cargo_deps: dict[str, str] | None = None, + build_path: Path, + ctx, + extra_cargo_deps: dict[str, str] | None = None, build_rs: str | None = None, ): """Create the Rust project structure with pre-generated data.""" @@ -1138,7 +1181,7 @@ def _setup_rust_project( ] for dep_name, dep_version in (extra_cargo_deps or {}).items(): if dep_version.startswith("{"): - deps_lines.append(f'{dep_name} = {dep_version}') + deps_lines.append(f"{dep_name} = {dep_version}") else: deps_lines.append(f'{dep_name} = "{dep_version}"') @@ -1371,17 +1414,19 @@ def _bench_logp_tool(state: _AgentState, verbose: bool) -> str: if verbose: print(f" [bench_logp] {us_per_eval:.3f} us/eval ({n_evals:,} evaluations)") - state.optimization_log.append(OptimizationEvent( - turn=state.tool_calls, - timestamp=time.time() - state.t0_loop, - event_type="benchmark", - status="OK", - us_per_eval=us_per_eval, - description=f"{us_per_eval:.3f} us/eval", - code_hash=_code_hash(state.build_path), - )) + state.optimization_log.append( + OptimizationEvent( + turn=state.tool_calls, + timestamp=time.time() - state.t0_loop, + event_type="benchmark", + status="OK", + us_per_eval=us_per_eval, + description=f"{us_per_eval:.3f} us/eval", + 
code_hash=_code_hash(state.build_path), + ) + ) - return f"Benchmark: {us_per_eval:.3f} us/eval ({n_evals:,} evaluations, {1e6/us_per_eval:,.0f} evals/sec)" + return f"Benchmark: {us_per_eval:.3f} us/eval ({n_evals:,} evaluations, {1e6 / us_per_eval:,.0f} evals/sec)" def optimize_model( @@ -1450,35 +1495,43 @@ def optimize_model( state = _AgentState( build_path=build_path, ctx=ctx, - messages=[{ - "role": "user", - "content": ( - "Optimize this Rust logp+gradient implementation for maximum speed.\n\n" - "The code is CORRECT and passes validation. Your goal is to make it faster " - "while keeping output numerically identical.\n\n" - "Steps:\n" - "1. Run `bench_logp` to get the baseline speed\n" - "2. Read the current code with `read_file`\n" - "3. Apply optimizations and write the new code\n" - "4. Build and validate (correctness must be preserved!)\n" - "5. Benchmark again to measure improvement\n" - "6. Iterate if there's more to gain\n\n" - f"The model has {ctx.n_params} parameters and the code is {len(current_code)} chars.\n" - f"Build directory: {build_path}\n" - ), - }], + messages=[ + { + "role": "user", + "content": ( + "Optimize this Rust logp+gradient implementation for maximum speed.\n\n" + "The code is CORRECT and passes validation. Your goal is to make it faster " + "while keeping output numerically identical.\n\n" + "Steps:\n" + "1. Run `bench_logp` to get the baseline speed\n" + "2. Read the current code with `read_file`\n" + "3. Apply optimizations and write the new code\n" + "4. Build and validate (correctness must be preserved!)\n" + "5. Benchmark again to measure improvement\n" + "6. 
Iterate if there's more to gain\n\n" + f"The model has {ctx.n_params} parameters and the code is {len(current_code)} chars.\n" + f"Build directory: {build_path}\n" + ), + } + ], ) if verbose: print(f"\nStarting optimization loop (build_dir={build_path})...") # Detect skills for system prompt augmentation - skills = _detect_skills(model, ctx, use_cuda=use_cuda, use_accelerate=use_accelerate, use_enzyme=use_enzyme) + skills = _detect_skills( + model, + ctx, + use_cuda=use_cuda, + use_accelerate=use_accelerate, + use_enzyme=use_enzyme, + ) system_prompt = OPTIMIZE_SYSTEM_PROMPT for skill_name in skills: content = _load_skill(skill_name) if content: - system_prompt += f"\n\n{'='*60}\n{content}" + system_prompt += f"\n\n{'=' * 60}\n{content}" total_input_tokens = 0 total_output_tokens = 0 @@ -1486,7 +1539,7 @@ def optimize_model( best_us = None for turn in range(max_turns): - t0 = time.time() + time.time() response = client.messages.create( model=model_name, max_tokens=16384, @@ -1499,7 +1552,9 @@ def optimize_model( total_input_tokens += response.usage.input_tokens total_output_tokens += response.usage.output_tokens if verbose: - print(f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens") + print( + f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens" + ) if response.stop_reason == "end_turn": if verbose: @@ -1539,11 +1594,13 @@ def optimize_model( break else: result = _execute_tool(block.name, block.input, state, verbose) - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) state.messages.append({"role": "user", "content": tool_results}) diff --git a/pymc_rust_compiler/exporter.py b/pymc_rust_compiler/exporter.py index b1f7fc0..e1ba221 100644 --- a/pymc_rust_compiler/exporter.py +++ b/pymc_rust_compiler/exporter.py @@ -12,6 +12,7 @@ 
import numpy as np import pymc as pm + try: from pytensor.graph.traversal import graph_inputs except ImportError: @@ -24,9 +25,9 @@ class ParamInfo: name: str value_var: str transform: str | None - shape: list[int] # original RV shape - unc_shape: list[int] # unconstrained (value_var) shape - size: int # unconstrained size + shape: list[int] # original RV shape + unc_shape: list[int] # unconstrained (value_var) shape + size: int # unconstrained size zerosum_axes: list[int] | None = None # axes for ZeroSumTransform @property @@ -128,7 +129,9 @@ def _extract(self) -> ModelContext: transform = model.rvs_to_transforms.get(rv, None) # Use value_var shape for unconstrained size (transforms may change dims) rv_shape = list(rv.type.shape) if hasattr(rv.type, "shape") else [] - unc_shape = list(value_var.type.shape) if hasattr(value_var.type, "shape") else [] + unc_shape = ( + list(value_var.type.shape) if hasattr(value_var.type, "shape") else [] + ) size = int(np.prod(unc_shape)) if unc_shape else 1 zs_axes = None if transform and type(transform).__name__ == "ZeroSumTransform": @@ -294,7 +297,11 @@ def _extract_covariates( if np.all(flat == np.floor(flat)) and np.min(flat) == 0: unique_vals = np.unique(flat) max_val = int(np.max(flat)) - if max_val >= 2 and max_val < 200 and np.array_equal(unique_vals, np.arange(max_val + 1)): + if ( + max_val >= 2 + and max_val < 200 + and np.array_equal(unique_vals, np.arange(max_val + 1)) + ): is_index = True n_groups = max_val + 1 @@ -326,16 +333,20 @@ def _infer_data_mapping(self, ctx) -> list[str]: hints = [] # Find variable names used in indexing: a[group_idx] → group_idx - index_vars = set(_re.findall(r'\w+\[(\w+)\]', source)) + index_vars = set(_re.findall(r"\w+\[(\w+)\]", source)) # Find other variable names that are likely data arrays (not parameters) # Parameters are defined with ~ or = expressions - param_names = set(_re.findall(r'(\w+)\s*~', source)) - param_names.update(_re.findall(r'(\w+)\s*=\s*\w+', source)) + param_names 
= set(_re.findall(r"(\w+)\s*~", source)) + param_names.update(_re.findall(r"(\w+)\s*=\s*\w+", source)) # Covariates: match index arrays to index variables, others to remaining vars covariate_items = list(ctx.covariate_data.items()) - index_covariates = [(n, i) for n, i in covariate_items if i.get("is_index_array")] - non_index_covariates = [(n, i) for n, i in covariate_items if not i.get("is_index_array")] + index_covariates = [ + (n, i) for n, i in covariate_items if i.get("is_index_array") + ] + non_index_covariates = [ + (n, i) for n, i in covariate_items if not i.get("is_index_array") + ] # Match index covariates to index variables from source # If only one index covariate and one index variable, match them directly @@ -348,7 +359,9 @@ def _infer_data_mapping(self, ctx) -> list[str]: f"(integer indices, {n_groups} groups, cast to `usize` for indexing)" ) else: - for (cov_name, cov_info), src_var in zip(index_covariates, sorted(index_vars)): + for (cov_name, cov_info), src_var in zip( + index_covariates, sorted(index_vars) + ): n_groups = cov_info.get("n_groups", 0) hints.append( f"`{src_var}` in source → `{cov_name.upper()}_DATA` " @@ -357,21 +370,21 @@ def _infer_data_mapping(self, ctx) -> list[str]: # Match remaining covariates to non-index, non-parameter variables # Look for variables used in arithmetic but not defined as params - arith_vars = set(_re.findall(r'[\+\-\*]\s*(\w+)', source)) + arith_vars = set(_re.findall(r"[\+\-\*]\s*(\w+)", source)) arith_vars -= param_names arith_vars -= index_vars - arith_vars -= {'observed', 'shape'} + arith_vars -= {"observed", "shape"} remaining_src_vars = sorted(arith_vars) - for (cov_name, cov_info), src_var in zip(non_index_covariates, remaining_src_vars): - hints.append( - f"`{src_var}` in source → `{cov_name.upper()}_DATA`" - ) + for (cov_name, cov_info), src_var in zip( + non_index_covariates, remaining_src_vars + ): + hints.append(f"`{src_var}` in source → `{cov_name.upper()}_DATA`") # Observed data mapping for 
obs_name in ctx.observed_data: # Find the observed variable in source (pattern: var ~ ..., observed) - obs_match = _re.search(r'(\w+)\s*~.*observed', source) + obs_match = _re.search(r"(\w+)\s*~.*observed", source) if obs_match: src_obs = obs_match.group(1) hints.append( @@ -520,7 +533,7 @@ def to_prompt(self) -> str: if is_obs: label = "observed" elif is_idx: - label = f"INTEGER INDEX ARRAY (values 0..{n_groups-1}, {n_groups} groups) — cast to usize for array indexing" + label = f"INTEGER INDEX ARRAY (values 0..{n_groups - 1}, {n_groups} groups) — cast to usize for array indexing" else: label = "covariate/predictor" parts.append( @@ -539,9 +552,13 @@ def to_prompt(self) -> str: parts.append(f"- {hint}") parts.append("") - parts.append(f"## Optimized PyTensor Graph (logp)\n```\n{ctx.logp_graph}\n```\n") + parts.append( + f"## Optimized PyTensor Graph (logp)\n```\n{ctx.logp_graph}\n```\n" + ) - parts.append(f"## Optimized PyTensor Graph (dlogp/gradient)\n```\n{ctx.dlogp_graph}\n```\n") + parts.append( + f"## Optimized PyTensor Graph (dlogp/gradient)\n```\n{ctx.dlogp_graph}\n```\n" + ) parts.append("## Individual logp terms (optimized, per RV)\n") for name, term in ctx.logp_terms.items(): @@ -549,7 +566,9 @@ def to_prompt(self) -> str: parts.append(f"### {name}\n```\n{display}\n```\n") parts.append("## Validation") - parts.append("Your generated code MUST produce these exact values (within float64 precision):\n") + parts.append( + "Your generated code MUST produce these exact values (within float64 precision):\n" + ) parts.append(f"At initial point: {json.dumps(ctx.initial_point.point)}") parts.append(f"- logp = {ctx.initial_point.logp:.10f}") diff --git a/pymc_rust_compiler/jax_exporter.py b/pymc_rust_compiler/jax_exporter.py index 23723cc..a860048 100644 --- a/pymc_rust_compiler/jax_exporter.py +++ b/pymc_rust_compiler/jax_exporter.py @@ -20,6 +20,7 @@ @dataclass class TensorInfo: """Metadata about a single parameter tensor or output.""" + name: str shape: 
list[int] dtype: str @@ -33,6 +34,7 @@ def is_scalar(self) -> bool: @dataclass class ValidationPoint: """A set of inputs/params and expected outputs + gradients.""" + params: dict[str, list | float] inputs: dict[str, list | float] output: list | float @@ -42,6 +44,7 @@ class ValidationPoint: @dataclass class ModelContext: """All information extracted from a JAX model.""" + source_framework: str # "jax" or "pytorch" source_code: str | None params: list[TensorInfo] @@ -121,31 +124,39 @@ def _extract(self) -> ModelContext: param_infos = [] for name, val in self.params.items(): arr = np.asarray(val) - param_infos.append(TensorInfo( - name=name, - shape=list(arr.shape), - dtype=str(arr.dtype), - size=int(np.prod(arr.shape)) if arr.shape else 1, - )) + param_infos.append( + TensorInfo( + name=name, + shape=list(arr.shape), + dtype=str(arr.dtype), + size=int(np.prod(arr.shape)) if arr.shape else 1, + ) + ) # Extract input info input_infos = [] if isinstance(self.sample_input, dict): for name, val in self.sample_input.items(): arr = np.asarray(val) - input_infos.append(TensorInfo( - name=name, shape=list(arr.shape), - dtype=str(arr.dtype), - size=int(np.prod(arr.shape)) if arr.shape else 1, - )) + input_infos.append( + TensorInfo( + name=name, + shape=list(arr.shape), + dtype=str(arr.dtype), + size=int(np.prod(arr.shape)) if arr.shape else 1, + ) + ) else: arr = np.asarray(self.sample_input) name = self.input_names[0] if self.input_names else "x" - input_infos.append(TensorInfo( - name=name, shape=list(arr.shape), - dtype=str(arr.dtype), - size=int(np.prod(arr.shape)) if arr.shape else 1, - )) + input_infos.append( + TensorInfo( + name=name, + shape=list(arr.shape), + dtype=str(arr.dtype), + size=int(np.prod(arr.shape)) if arr.shape else 1, + ) + ) # Determine the function to differentiate if self._loss_fn is not None: @@ -157,9 +168,11 @@ def scalar_fn(params, x): test_out = self.fn(self.params, self.sample_input) test_arr = np.asarray(test_out) if test_arr.ndim == 0 or 
test_arr.size == 1: + def scalar_fn(params, x): return jnp.sum(self.fn(params, x)) else: + def scalar_fn(params, x): return jnp.sum(self.fn(params, x)) @@ -168,11 +181,14 @@ def scalar_fn(params, x): # Extract output info from a forward pass output = self.fn(self.params, self.sample_input) out_arr = np.asarray(output) - output_infos = [TensorInfo( - name="output", shape=list(out_arr.shape), - dtype=str(out_arr.dtype), - size=int(np.prod(out_arr.shape)) if out_arr.shape else 1, - )] + output_infos = [ + TensorInfo( + name="output", + shape=list(out_arr.shape), + dtype=str(out_arr.dtype), + size=int(np.prod(out_arr.shape)) if out_arr.shape else 1, + ) + ] # Generate validation points rng = np.random.default_rng(self._seed) @@ -180,12 +196,14 @@ def scalar_fn(params, x): # Point 0: original params grads = grad_fn(self.params, self.sample_input) - validation_points.append(ValidationPoint( - params={k: np.asarray(v).tolist() for k, v in self.params.items()}, - inputs=self._input_to_dict(self.sample_input), - output=out_arr.tolist(), - grad_params={k: np.asarray(v).tolist() for k, v in grads.items()}, - )) + validation_points.append( + ValidationPoint( + params={k: np.asarray(v).tolist() for k, v in self.params.items()}, + inputs=self._input_to_dict(self.sample_input), + output=out_arr.tolist(), + grad_params={k: np.asarray(v).tolist() for k, v in grads.items()}, + ) + ) # Extra points: perturbed params for _ in range(self._n_extra_points): @@ -196,12 +214,14 @@ def scalar_fn(params, x): out = self.fn(perturbed, self.sample_input) grads = grad_fn(perturbed, self.sample_input) - validation_points.append(ValidationPoint( - params={k: np.asarray(v).tolist() for k, v in perturbed.items()}, - inputs=self._input_to_dict(self.sample_input), - output=np.asarray(out).tolist(), - grad_params={k: np.asarray(v).tolist() for k, v in grads.items()}, - )) + validation_points.append( + ValidationPoint( + params={k: np.asarray(v).tolist() for k, v in perturbed.items()}, + 
inputs=self._input_to_dict(self.sample_input), + output=np.asarray(out).tolist(), + grad_params={k: np.asarray(v).tolist() for k, v in grads.items()}, + ) + ) source = self._source_code or self._try_extract_source() diff --git a/pymc_rust_compiler/jax_pytorch_transpiler.py b/pymc_rust_compiler/jax_pytorch_transpiler.py index 7616cc3..86bc51b 100644 --- a/pymc_rust_compiler/jax_pytorch_transpiler.py +++ b/pymc_rust_compiler/jax_pytorch_transpiler.py @@ -10,7 +10,6 @@ from __future__ import annotations import functools -import json import os import time from dataclasses import dataclass, field @@ -19,7 +18,7 @@ import numpy as np -from pymc_rust_compiler.jax_exporter import ModelContext, ValidationPoint +from pymc_rust_compiler.jax_exporter import ModelContext _SKILLS_DIR = Path(__file__).parent / "skills" @@ -178,9 +177,11 @@ def forward(params: dict, x): # ── Result types ──────────────────────────────────────────────────────────── + @dataclass class TranspileResult: """Result of transpiling between JAX and PyTorch.""" + source_framework: str # "jax" or "pytorch" target_framework: str # "pytorch" or "jax" generated_code: str @@ -190,9 +191,13 @@ class TranspileResult: n_tool_calls: int = 0 conversation_turns: int = 0 timings: dict[str, float] = field(default_factory=dict) - token_usage: dict[str, int] = field(default_factory=lambda: { - "input_tokens": 0, "output_tokens": 0, "total_tokens": 0, - }) + token_usage: dict[str, int] = field( + default_factory=lambda: { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ) @property def success(self) -> bool: @@ -222,9 +227,11 @@ def get_model(self, params: dict | None = None): # ── Agent state ───────────────────────────────────────────────────────────── + @dataclass class _AgentState: """Mutable state for the agent loop.""" + direction: str # "jax_to_pytorch" or "pytorch_to_jax" source_context: ModelContext generated_code: str @@ -237,6 +244,7 @@ class _AgentState: # ── Skill loading 
─────────────────────────────────────────────────────────── + @functools.lru_cache(maxsize=None) def _load_skill(name: str) -> str: path = _SKILLS_DIR / f"{name}.md" @@ -247,6 +255,7 @@ def _load_skill(name: str) -> str: # ── Prompt building ───────────────────────────────────────────────────────── + def _build_system_prompt(direction: str) -> str: if direction == "jax_to_pytorch": prompt = JAX_TO_PYTORCH_SYSTEM @@ -256,7 +265,7 @@ def _build_system_prompt(direction: str) -> str: skill = _load_skill("pytorch_to_jax") if skill: - prompt += f"\n\n{'='*60}\n{skill}" + prompt += f"\n\n{'=' * 60}\n{skill}" return prompt @@ -342,8 +351,12 @@ def _build_user_prompt(ctx: ModelContext, direction: str) -> str: # ── Tool execution ────────────────────────────────────────────────────────── + def _execute_tool( - name: str, input_data: dict, state: _AgentState, verbose: bool, + name: str, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: if name == "write_code": return _tool_write_code(input_data, state, verbose) @@ -356,7 +369,9 @@ def _execute_tool( def _tool_write_code( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: code = input_data.get("code", "") if not code: @@ -429,7 +444,10 @@ def _validate_pytorch(namespace: dict, state: _AgentState, verbose: bool) -> str x = torch.tensor(np.array(inp["x"]), dtype=torch.float32) output = model(x) else: - tensors = {k: torch.tensor(np.array(v), dtype=torch.float32) for k, v in inp.items()} + tensors = { + k: torch.tensor(np.array(v), dtype=torch.float32) + for k, v in inp.items() + } output = model(**tensors) out_np = output.detach().cpu().numpy() @@ -437,12 +455,18 @@ def _validate_pytorch(namespace: dict, state: _AgentState, verbose: bool) -> str # Check output if out_np.shape != ref_out.shape: - errors.append(f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}") - report_lines.append(f"{label}: SHAPE MISMATCH 
{out_np.shape} vs {ref_out.shape}") + errors.append( + f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}" + ) + report_lines.append( + f"{label}: SHAPE MISMATCH {out_np.shape} vs {ref_out.shape}" + ) continue max_diff = float(np.max(np.abs(out_np - ref_out))) - rel_err = float(np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8))) + rel_err = float( + np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8)) + ) out_ok = rel_err <= 1e-4 report_lines.append( @@ -450,7 +474,9 @@ def _validate_pytorch(namespace: dict, state: _AgentState, verbose: bool) -> str f"[{'OK' if out_ok else 'MISMATCH'}]" ) if not out_ok: - errors.append(f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}") + errors.append( + f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}" + ) # Check gradients loss = output.sum() @@ -463,14 +489,20 @@ def _validate_pytorch(namespace: dict, state: _AgentState, verbose: bool) -> str if name == pname and param.grad is not None: got_g = param.grad.detach().cpu().numpy() grad_diff = float(np.max(np.abs(got_g - ref_g))) - grad_rel = float(np.max(np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8))) + grad_rel = float( + np.max( + np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8) + ) + ) grad_ok = grad_rel <= 1e-3 report_lines.append( f" grad['{name}']: max_diff={grad_diff:.2e} rel_err={grad_rel:.2e} " f"[{'OK' if grad_ok else 'MISMATCH'}]" ) if not grad_ok: - errors.append(f"{label}: grad['{name}'] mismatch: rel_err={grad_rel:.2e}") + errors.append( + f"{label}: grad['{name}'] mismatch: rel_err={grad_rel:.2e}" + ) found = True break if not found: @@ -504,7 +536,9 @@ def _validate_jax(namespace: dict, state: _AgentState, verbose: bool) -> str: if "forward" not in namespace: return "Error: generated code does not define `forward(params, x)` function." 
if "init_params" not in namespace: - return "Error: generated code does not define `init_params(param_data)` function." + return ( + "Error: generated code does not define `init_params(param_data)` function." + ) forward_fn = namespace["forward"] init_fn = namespace["init_params"] @@ -541,12 +575,18 @@ def scalar_fn(params, x): ref_out = np.array(vp.output) if out_np.shape != ref_out.shape: - errors.append(f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}") - report_lines.append(f"{label}: SHAPE MISMATCH {out_np.shape} vs {ref_out.shape}") + errors.append( + f"{label}: shape mismatch: got {out_np.shape}, expected {ref_out.shape}" + ) + report_lines.append( + f"{label}: SHAPE MISMATCH {out_np.shape} vs {ref_out.shape}" + ) continue max_diff = float(np.max(np.abs(out_np - ref_out))) - rel_err = float(np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8))) + rel_err = float( + np.max(np.abs(out_np - ref_out) / np.maximum(np.abs(ref_out), 1e-8)) + ) out_ok = rel_err <= 1e-4 report_lines.append( @@ -554,7 +594,9 @@ def scalar_fn(params, x): f"[{'OK' if out_ok else 'MISMATCH'}]" ) if not out_ok: - errors.append(f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}") + errors.append( + f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}" + ) # Check gradients grads = grad_fn(params, x) @@ -564,14 +606,18 @@ def scalar_fn(params, x): if pname in grads: got_g = np.asarray(grads[pname]) grad_diff = float(np.max(np.abs(got_g - ref_g))) - grad_rel = float(np.max(np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8))) + grad_rel = float( + np.max(np.abs(got_g - ref_g) / np.maximum(np.abs(ref_g), 1e-8)) + ) grad_ok = grad_rel <= 1e-3 report_lines.append( f" grad['{pname}']: max_diff={grad_diff:.2e} rel_err={grad_rel:.2e} " f"[{'OK' if grad_ok else 'MISMATCH'}]" ) if not grad_ok: - errors.append(f"{label}: grad['{pname}'] mismatch: rel_err={grad_rel:.2e}") + errors.append( + f"{label}: 
grad['{pname}'] mismatch: rel_err={grad_rel:.2e}" + ) else: errors.append(f"{label}: gradient for '{pname}' not found") @@ -604,6 +650,7 @@ def _tool_read_source(state: _AgentState, verbose: bool) -> str: # ── Main transpiler functions ─────────────────────────────────────────────── + def _run_agent_loop( state: _AgentState, system_prompt: str, @@ -664,11 +711,13 @@ def _run_agent_loop( elif block.type == "tool_use": state.tool_calls += 1 result = _execute_tool(block.name, block.input, state, verbose) - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) if state.validated: break @@ -725,15 +774,20 @@ def transpile_jax_to_pytorch( print("Extracting JAX model context...") t0 = time.time() exporter = JaxModelExporter( - fn, params, sample_input, - source_code=source_code, loss_fn=loss_fn, + fn, + params, + sample_input, + source_code=source_code, + loss_fn=loss_fn, ) ctx = exporter.context timings["extract"] = time.time() - t0 if verbose: n_params = sum(p.size for p in ctx.params) - print(f" {n_params} parameters, {len(ctx.validation_points)} validation points") + print( + f" {n_params} parameters, {len(ctx.validation_points)} validation points" + ) # Step 2: Build prompts and run agent system_prompt = _build_system_prompt("jax_to_pytorch") @@ -751,7 +805,12 @@ def transpile_jax_to_pytorch( print("\nStarting agent loop...") turns, token_usage = _run_agent_loop( - state, system_prompt, api_key, max_turns, model_name, verbose, + state, + system_prompt, + api_key, + max_turns, + model_name, + verbose, ) validation_errors = [] @@ -812,15 +871,19 @@ def transpile_pytorch_to_jax( print("Extracting PyTorch model context...") t0 = time.time() exporter = PytorchModelExporter( - module, sample_input, - source_code=source_code, loss_fn=loss_fn, + module, + sample_input, + source_code=source_code, + loss_fn=loss_fn, ) ctx = 
exporter.context timings["extract"] = time.time() - t0 if verbose: n_params = sum(p.size for p in ctx.params) - print(f" {n_params} parameters, {len(ctx.validation_points)} validation points") + print( + f" {n_params} parameters, {len(ctx.validation_points)} validation points" + ) # Step 2: Build prompts and run agent system_prompt = _build_system_prompt("pytorch_to_jax") @@ -838,7 +901,12 @@ def transpile_pytorch_to_jax( print("\nStarting agent loop...") turns, token_usage = _run_agent_loop( - state, system_prompt, api_key, max_turns, model_name, verbose, + state, + system_prompt, + api_key, + max_turns, + model_name, + verbose, ) validation_errors = [] diff --git a/pymc_rust_compiler/nutpie_bridge.py b/pymc_rust_compiler/nutpie_bridge.py index 5cb9c28..83cf6b1 100644 --- a/pymc_rust_compiler/nutpie_bridge.py +++ b/pymc_rust_compiler/nutpie_bridge.py @@ -50,7 +50,7 @@ def _build_shared_lib(build_dir: Path) -> Path: if not so_path.exists(): raise RuntimeError( f"Shared library not found at {so_path}. " - "Ensure Cargo.toml has [lib] crate-type = [\"cdylib\"]" + 'Ensure Cargo.toml has [lib] crate-type = ["cdylib"]' ) return so_path @@ -62,10 +62,10 @@ def _load_logp_fn(so_path: Path, n_dim: int): # C FFI signature: int logp_ffi(const double* x, double* grad, double* logp_out, int dim) lib.logp_ffi.restype = ctypes.c_int lib.logp_ffi.argtypes = [ - ctypes.c_void_p, # x (input) - ctypes.c_void_p, # grad (output) - ctypes.c_void_p, # logp_out (output) - ctypes.c_int, # dim + ctypes.c_void_p, # x (input) + ctypes.c_void_p, # grad (output) + ctypes.c_void_p, # logp_out (output) + ctypes.c_int, # dim ] # Pre-allocate output buffers @@ -91,7 +91,7 @@ def logp_fn(x): def to_nutpie( compile_result: CompilationResult, model: pm.Model, -) -> "nutpie.compiled_pyfunc.PyFuncModel": +) -> "nutpie.compiled_pyfunc.PyFuncModel": # noqa: F821 """Convert a CompilationResult into a nutpie-compatible model for sampling. 
Args: @@ -123,6 +123,7 @@ def to_nutpie( model_fn = model.logp_dlogp_function(ravel_inputs=True) ip = model.initial_point() from pymc.blocking import DictToArrayBijection + x0 = DictToArrayBijection.map( {v.name: ip[v.name] for v in model_fn._grad_vars} ).data @@ -156,9 +157,10 @@ def expand_fn(x): offset = 0 for name, shape in zip(var_names, var_shapes): size = int(np.prod(shape)) if shape else 1 - result[name] = x[offset:offset + size].reshape(shape) + result[name] = x[offset : offset + size].reshape(shape) offset += size return result + return expand_fn def make_initial_point(seed): diff --git a/pymc_rust_compiler/pytorch_exporter.py b/pymc_rust_compiler/pytorch_exporter.py index a24893a..bc77f79 100644 --- a/pymc_rust_compiler/pytorch_exporter.py +++ b/pymc_rust_compiler/pytorch_exporter.py @@ -11,7 +11,6 @@ import inspect import textwrap -from dataclasses import dataclass, field from typing import Any, Callable import numpy as np @@ -58,40 +57,55 @@ def _extract(self) -> ModelContext: # Extract parameter info param_infos = [] for name, param in module.named_parameters(): - param_infos.append(TensorInfo( - name=name, - shape=list(param.shape), - dtype=str(param.dtype).replace("torch.", ""), - size=int(param.numel()), - )) + param_infos.append( + TensorInfo( + name=name, + shape=list(param.shape), + dtype=str(param.dtype).replace("torch.", ""), + size=int(param.numel()), + ) + ) # Extract input info input_infos = [] if isinstance(self.sample_input, dict): for name, val in self.sample_input.items(): t = torch.as_tensor(val) - input_infos.append(TensorInfo( - name=name, shape=list(t.shape), - dtype=str(t.dtype).replace("torch.", ""), - size=int(t.numel()), - )) + input_infos.append( + TensorInfo( + name=name, + shape=list(t.shape), + dtype=str(t.dtype).replace("torch.", ""), + size=int(t.numel()), + ) + ) elif isinstance(self.sample_input, (tuple, list)): for i, val in enumerate(self.sample_input): t = torch.as_tensor(val) - name = self.input_names[i] if 
self.input_names and i < len(self.input_names) else f"x_{i}" - input_infos.append(TensorInfo( - name=name, shape=list(t.shape), - dtype=str(t.dtype).replace("torch.", ""), - size=int(t.numel()), - )) + name = ( + self.input_names[i] + if self.input_names and i < len(self.input_names) + else f"x_{i}" + ) + input_infos.append( + TensorInfo( + name=name, + shape=list(t.shape), + dtype=str(t.dtype).replace("torch.", ""), + size=int(t.numel()), + ) + ) else: t = torch.as_tensor(self.sample_input) name = self.input_names[0] if self.input_names else "x" - input_infos.append(TensorInfo( - name=name, shape=list(t.shape), - dtype=str(t.dtype).replace("torch.", ""), - size=int(t.numel()), - )) + input_infos.append( + TensorInfo( + name=name, + shape=list(t.shape), + dtype=str(t.dtype).replace("torch.", ""), + size=int(t.numel()), + ) + ) # Forward pass to get output info module.eval() @@ -99,11 +113,14 @@ def _extract(self) -> ModelContext: output = self._forward(module, self.sample_input) out_np = output.detach().cpu().numpy() - output_infos = [TensorInfo( - name="output", shape=list(out_np.shape), - dtype=str(output.dtype).replace("torch.", ""), - size=int(np.prod(out_np.shape)) if out_np.shape else 1, - )] + output_infos = [ + TensorInfo( + name="output", + shape=list(out_np.shape), + dtype=str(output.dtype).replace("torch.", ""), + size=int(np.prod(out_np.shape)) if out_np.shape else 1, + ) + ] # Generate validation points rng = np.random.default_rng(self._seed) @@ -143,15 +160,17 @@ def _extract(self) -> ModelContext: def _forward(self, module, inp): import torch + if isinstance(inp, dict): - return module(**{k: torch.as_tensor(v, dtype=torch.float32) for k, v in inp.items()}) + return module( + **{k: torch.as_tensor(v, dtype=torch.float32) for k, v in inp.items()} + ) elif isinstance(inp, (tuple, list)): return module(*[torch.as_tensor(v, dtype=torch.float32) for v in inp]) else: return module(torch.as_tensor(inp, dtype=torch.float32)) def 
_compute_validation_point(self, module, inp) -> ValidationPoint: - import torch module.zero_grad() module.train() @@ -183,10 +202,21 @@ def _compute_validation_point(self, module, inp) -> ValidationPoint: def _input_to_dict(self, inp) -> dict: import torch + if isinstance(inp, dict): - return {k: np.asarray(v).tolist() if not isinstance(v, torch.Tensor) else v.detach().cpu().numpy().tolist() for k, v in inp.items()} + return { + k: np.asarray(v).tolist() + if not isinstance(v, torch.Tensor) + else v.detach().cpu().numpy().tolist() + for k, v in inp.items() + } elif isinstance(inp, (tuple, list)): - return {f"x_{i}": np.asarray(v).tolist() if not isinstance(v, torch.Tensor) else v.detach().cpu().numpy().tolist() for i, v in enumerate(inp)} + return { + f"x_{i}": np.asarray(v).tolist() + if not isinstance(v, torch.Tensor) + else v.detach().cpu().numpy().tolist() + for i, v in enumerate(inp) + } else: if isinstance(inp, torch.Tensor): return {"x": inp.detach().cpu().numpy().tolist()} diff --git a/pymc_rust_compiler/pytorch_rust_transpiler.py b/pymc_rust_compiler/pytorch_rust_transpiler.py index 18ca326..1eba78d 100644 --- a/pymc_rust_compiler/pytorch_rust_transpiler.py +++ b/pymc_rust_compiler/pytorch_rust_transpiler.py @@ -13,11 +13,9 @@ from __future__ import annotations import functools -import json import os import subprocess import tempfile -import textwrap import time from dataclasses import dataclass, field from pathlib import Path @@ -25,7 +23,7 @@ import numpy as np -from pymc_rust_compiler.jax_exporter import ModelContext, TensorInfo, ValidationPoint +from pymc_rust_compiler.jax_exporter import ModelContext, ValidationPoint _SKILLS_DIR = Path(__file__).parent / "skills" @@ -393,9 +391,11 @@ # ── Result type ────────────────────────────────────────────────────────────── + @dataclass class RustTranspileResult: """Result of transpiling PyTorch to Rust.""" + generated_code: str validated: bool validation_errors: list[str] @@ -404,9 +404,13 @@ class 
RustTranspileResult: n_tool_calls: int = 0 conversation_turns: int = 0 timings: dict[str, float] = field(default_factory=dict) - token_usage: dict[str, int] = field(default_factory=lambda: { - "input_tokens": 0, "output_tokens": 0, "total_tokens": 0, - }) + token_usage: dict[str, int] = field( + default_factory=lambda: { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ) @property def success(self) -> bool: @@ -426,9 +430,11 @@ def binary_path(self) -> Path | None: # ── Agent state ────────────────────────────────────────────────────────────── + @dataclass class _AgentState: """Mutable state for the agent loop.""" + build_path: Path source_context: ModelContext source_code: str | None @@ -442,6 +448,7 @@ class _AgentState: # ── Skill loading ──────────────────────────────────────────────────────────── + @functools.lru_cache(maxsize=None) def _load_skill(name: str) -> str: path = _SKILLS_DIR / f"{name}.md" @@ -452,6 +459,7 @@ def _load_skill(name: str) -> str: # ── Rust project setup ─────────────────────────────────────────────────────── + def _setup_rust_project(build_path: Path, ctx: ModelContext, backend: str = "pure"): """Create a Rust project with parameter data baked in.""" src = build_path / "src" @@ -484,12 +492,14 @@ def _setup_rust_project(build_path: Path, ctx: ModelContext, backend: str = "pur # Write values in chunks of 8 for i in range(0, len(values), 8): - chunk = values[i:i+8] + chunk = values[i : i + 8] vals = ", ".join(f"{v:.9e}" for v in chunk) data_lines.append(f" {vals},") data_lines.append("];") - data_lines.append(f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};") + data_lines.append( + f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};" + ) data_lines.append("") (src / "data.rs").write_text("\n".join(data_lines)) @@ -506,6 +516,7 @@ def _setup_rust_project(build_path: Path, ctx: ModelContext, backend: str = "pur # ── Prompt building ────────────────────────────────────────────────────────── + 
def _build_user_prompt(ctx: ModelContext) -> str: parts = [] parts.append("Translate this PyTorch model to pure Rust.\n") @@ -521,17 +532,23 @@ def _build_user_prompt(ctx: ModelContext) -> str: # Data.rs mapping parts.append("## Parameter Constants in data.rs\n") - parts.append("Each parameter is available as a flat `&[f32]` constant plus a shape constant:\n") + parts.append( + "Each parameter is available as a flat `&[f32]` constant plus a shape constant:\n" + ) for p in ctx.params: safe_name = p.name.replace(".", "_").upper() - parts.append(f"- `{safe_name}`: &[f32] (len={p.size}), `{safe_name}_SHAPE`: &[usize] = {p.shape}") + parts.append( + f"- `{safe_name}`: &[f32] (len={p.size}), `{safe_name}_SHAPE`: &[usize] = {p.shape}" + ) parts.append("") # Input info parts.append("## Inputs\n") for i in ctx.inputs: parts.append(f"- `{i.name}`: shape={i.shape}, dtype={i.dtype}") - parts.append("The `forward()` function receives the input as a flat &[f32] array.\n") + parts.append( + "The `forward()` function receives the input as a flat &[f32] array.\n" + ) # Output info parts.append("## Outputs\n") @@ -582,8 +599,12 @@ def _build_user_prompt(ctx: ModelContext) -> str: # ── Tool execution ─────────────────────────────────────────────────────────── + def _execute_tool( - name: str, input_data: dict, state: _AgentState, verbose: bool, + name: str, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: if name == "write_code": return _tool_write_code(input_data, state, verbose) @@ -602,7 +623,9 @@ def _execute_tool( def _tool_write_code( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: code = input_data.get("code", "") if not code: @@ -614,7 +637,9 @@ def _tool_write_code( if verbose: print(f" [write_code] Wrote {len(code)} chars to generated.rs") - return f"Written {len(code)} chars to src/generated.rs. Use `cargo_build` to compile." 
+ return ( + f"Written {len(code)} chars to src/generated.rs. Use `cargo_build` to compile." + ) def _tool_cargo_build(state: _AgentState, verbose: bool) -> str: @@ -724,11 +749,17 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str: errors.append( f"{label}: output shape mismatch: got {rust_output.shape}, expected {ref_output.shape}" ) - report_lines.append(f"{label}: SHAPE MISMATCH {rust_output.shape} vs {ref_output.shape}") + report_lines.append( + f"{label}: SHAPE MISMATCH {rust_output.shape} vs {ref_output.shape}" + ) continue max_diff = float(np.max(np.abs(rust_output - ref_output))) - rel_err = float(np.max(np.abs(rust_output - ref_output) / np.maximum(np.abs(ref_output), 1e-8))) + rel_err = float( + np.max( + np.abs(rust_output - ref_output) / np.maximum(np.abs(ref_output), 1e-8) + ) + ) out_ok = rel_err <= 1e-4 report_lines.append( @@ -736,7 +767,9 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str: f"[{'OK' if out_ok else 'MISMATCH'}]" ) if not out_ok: - errors.append(f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}") + errors.append( + f"{label}: output mismatch: max_diff={max_diff:.2e}, rel_err={rel_err:.2e}" + ) # Test gradients for each parameter for pname, ref_grad in vp.grad_params.items(): @@ -779,7 +812,9 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str: continue grad_diff = float(np.max(np.abs(rust_grad - ref_g))) - grad_rel = float(np.max(np.abs(rust_grad - ref_g) / np.maximum(np.abs(ref_g), 1e-8))) + grad_rel = float( + np.max(np.abs(rust_grad - ref_g) / np.maximum(np.abs(ref_g), 1e-8)) + ) grad_ok = grad_rel <= 1e-3 report_lines.append( @@ -787,7 +822,9 @@ def _tool_validate(state: _AgentState, verbose: bool) -> str: f"[{'OK' if grad_ok else 'MISMATCH'}]" ) if not grad_ok: - errors.append(f"{label}: grad['{pname}'] mismatch: rel_err={grad_rel:.2e}") + errors.append( + f"{label}: grad['{pname}'] mismatch: rel_err={grad_rel:.2e}" + ) # Restore original data.rs if we modified it if 
len(ctx.validation_points) > 1: @@ -828,26 +865,34 @@ def _update_data_rs(build_path: Path, ctx: ModelContext, vp: ValidationPoint): data_lines.append(f"pub const {safe_name}: &[f32] = &[") for i in range(0, len(values), 8): - chunk = values[i:i+8] + chunk = values[i : i + 8] vals = ", ".join(f"{v:.9e}" for v in chunk) data_lines.append(f" {vals},") data_lines.append("];") - data_lines.append(f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};") + data_lines.append( + f"pub const {safe_name}_SHAPE: &[usize] = &{param_info.shape};" + ) data_lines.append("") (build_path / "src" / "data.rs").write_text("\n".join(data_lines)) def _tool_read_source(state: _AgentState, verbose: bool) -> str: - source = state.source_code or state.source_context.source_code or "(no source code available)" + source = ( + state.source_code + or state.source_context.source_code + or "(no source code available)" + ) if verbose: print(f" [read_source] {len(source)} chars") return source def _tool_read_file( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: rel_path = input_data.get("path", "") if not rel_path: @@ -871,7 +916,9 @@ def _tool_read_file( def _tool_add_cargo_dependency( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: """Add a crate dependency to Cargo.toml.""" name = input_data.get("name", "") @@ -928,6 +975,7 @@ def _tool_add_cargo_dependency( # ── Agent loop ─────────────────────────────────────────────────────────────── + def _run_agent_loop( state: _AgentState, system_prompt: str, @@ -988,11 +1036,13 @@ def _run_agent_loop( elif block.type == "tool_use": state.tool_calls += 1 result = _execute_tool(block.name, block.input, state, verbose) - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + 
"content": result, + } + ) if state.validated: break @@ -1011,6 +1061,7 @@ def _run_agent_loop( # ── Main entry point ───────────────────────────────────────────────────────── + def transpile_pytorch_to_rust( module: Any, # torch.nn.Module sample_input: Any, @@ -1057,15 +1108,19 @@ def transpile_pytorch_to_rust( print(f"Extracting PyTorch model context (backend={backend})...") t0 = time.time() exporter = PytorchModelExporter( - module, sample_input, - source_code=source_code, loss_fn=loss_fn, + module, + sample_input, + source_code=source_code, + loss_fn=loss_fn, ) ctx = exporter.context timings["extract"] = time.time() - t0 if verbose: n_params = sum(p.size for p in ctx.params) - print(f" {n_params} parameters, {len(ctx.validation_points)} validation points") + print( + f" {n_params} parameters, {len(ctx.validation_points)} validation points" + ) # Step 2: Set up Rust build directory if build_dir: @@ -1087,7 +1142,7 @@ def transpile_pytorch_to_rust( system_prompt = SYSTEM_PROMPT skill = _load_skill("pytorch_to_rust") if skill: - system_prompt += f"\n\n{'='*60}\n{skill}" + system_prompt += f"\n\n{'=' * 60}\n{skill}" # Step 4: Build user prompt and run agent user_prompt = _build_user_prompt(ctx) @@ -1104,7 +1159,12 @@ def transpile_pytorch_to_rust( print("\nStarting agent loop...") turns, token_usage = _run_agent_loop( - state, system_prompt, api_key, max_turns, model_name, verbose, + state, + system_prompt, + api_key, + max_turns, + model_name, + verbose, ) # Read final generated code diff --git a/pymc_rust_compiler/stan_compiler.py b/pymc_rust_compiler/stan_compiler.py index 1b4e506..f2aa866 100644 --- a/pymc_rust_compiler/stan_compiler.py +++ b/pymc_rust_compiler/stan_compiler.py @@ -7,7 +7,6 @@ from __future__ import annotations import functools -import json import os import subprocess import tempfile @@ -253,9 +252,13 @@ class StanCompilationResult: timings: dict[str, float] n_tool_calls: int = 0 conversation_turns: int = 0 - token_usage: dict[str, int] = 
field(default_factory=lambda: { - "input_tokens": 0, "output_tokens": 0, "total_tokens": 0, - }) + token_usage: dict[str, int] = field( + default_factory=lambda: { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ) us_per_eval: float | None = None @property @@ -283,8 +286,12 @@ def _detect_stan_skills(stan_code: str) -> list[str]: # GP detection gp_indicators = [ - "gp_exp_quad_cov", "gp_matern", "gp_periodic", - "cov_exp_quad", "multi_normal", "multi_normal_cholesky", + "gp_exp_quad_cov", + "gp_matern", + "gp_periodic", + "cov_exp_quad", + "multi_normal", + "multi_normal_cholesky", "cholesky_decompose", ] if any(kw in stan_code.lower() for kw in gp_indicators): @@ -308,7 +315,7 @@ def _build_system_prompt(skills: list[str]) -> str: for skill_name in skills: content = _load_skill(skill_name) if content: - prompt += f"\n\n{'='*60}\n{content}" + prompt += f"\n\n{'=' * 60}\n{content}" return prompt @@ -388,15 +395,17 @@ def compile_stan_model( state = _AgentState( build_path=build_path, ctx=ctx, - messages=[{ - "role": "user", - "content": ( - "Generate a Rust logp+gradient implementation for this Stan model.\n\n" - "Use your tools to write the code, build it, and validate it. " - "Iterate until validation passes.\n\n" - f"{prompt}" - ), - }], + messages=[ + { + "role": "user", + "content": ( + "Generate a Rust logp+gradient implementation for this Stan model.\n\n" + "Use your tools to write the code, build it, and validate it. 
" + "Iterate until validation passes.\n\n" + f"{prompt}" + ), + } + ], timings=timings, ) @@ -422,7 +431,9 @@ def compile_stan_model( total_input_tokens += response.usage.input_tokens total_output_tokens += response.usage.output_tokens if verbose: - print(f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens") + print( + f" Turn {turn}: {response.usage.input_tokens} in / {response.usage.output_tokens} out tokens" + ) if response.stop_reason == "end_turn": if verbose: @@ -446,11 +457,13 @@ def compile_stan_model( elif block.type == "tool_use": state.tool_calls += 1 result = _execute_tool(block.name, block.input, state, verbose) - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) if state.validated: break @@ -488,7 +501,10 @@ def compile_stan_model( def _execute_tool( - name: str, input_data: dict, state: _AgentState, verbose: bool, + name: str, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: """Execute a tool and return the result string.""" if name == "write_rust_code": @@ -506,7 +522,9 @@ def _execute_tool( def _tool_write_rust_code( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: code = input_data.get("code", "") if not code: @@ -576,10 +594,12 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: if result.returncode != 0: return f"Error: validator crashed: {result.stderr[:500]}" - output_lines = [l for l in result.stdout.strip().split("\n") if l] + output_lines = [line for line in result.stdout.strip().split("\n") if line] if len(output_lines) != len(all_points): - return f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" + return ( + f"Error: expected {len(all_points)} output lines, got {len(output_lines)}" + ) parsed = [] for 
output_line in output_lines: @@ -645,7 +665,9 @@ def _tool_validate_logp(state: _AgentState, verbose: bool) -> str: def _tool_read_file( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: rel_path = input_data.get("path", "") if not rel_path: @@ -674,7 +696,9 @@ def _tool_read_file( def _tool_add_cargo_dependency( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: """Add a crate dependency to Cargo.toml.""" name = input_data.get("name", "") @@ -751,9 +775,7 @@ def _generate_data_rs(data: dict | None) -> str: n = len(flat) lines.append(f"pub const {name.upper()}_N: usize = {n};") formatted = ", ".join(f"{v:.17e}" for v in flat) - lines.append( - f"pub const {name.upper()}_DATA: &[f64] = &[{formatted}];\n" - ) + lines.append(f"pub const {name.upper()}_DATA: &[f64] = &[{formatted}];\n") # Also export shape info for multi-dimensional arrays if arr.ndim > 1: @@ -785,7 +807,7 @@ def _setup_rust_project( ] for dep_name, dep_version in (extra_cargo_deps or {}).items(): if dep_version.startswith("{"): - deps_lines.append(f'{dep_name} = {dep_version}') + deps_lines.append(f"{dep_name} = {dep_version}") else: deps_lines.append(f'{dep_name} = "{dep_version}"') diff --git a/pymc_rust_compiler/stan_exporter.py b/pymc_rust_compiler/stan_exporter.py index 814bb21..6da4615 100644 --- a/pymc_rust_compiler/stan_exporter.py +++ b/pymc_rust_compiler/stan_exporter.py @@ -36,9 +36,9 @@ class StanModelContext: stan_code: str params: list[StanParamInfo] - param_names: list[str] # constrained names - unc_param_names: list[str] # unconstrained names - n_params: int # unconstrained count + param_names: list[str] # constrained names + unc_param_names: list[str] # unconstrained names + n_params: int # unconstrained count n_params_constrained: int data_json: str | None data_summary: dict[str, dict] # name → {shape, dtype, min, max, mean} @@ -177,7 
+177,9 @@ def _extract(self) -> StanModelContext: unc_param_names=list(unc_param_names), n_params=n_unconstrained, n_params_constrained=n_constrained, - data_json=json.dumps(self._data) if isinstance(self._data, dict) else self._data, + data_json=json.dumps(self._data) + if isinstance(self._data, dict) + else self._data, data_summary=data_summary, initial_point=initial, extra_points=extra_points, @@ -251,13 +253,17 @@ def to_prompt(self) -> str: for name, value in self._data.items(): arr = np.asarray(value) if arr.ndim == 0 and np.issubdtype(arr.dtype, np.integer): - parts.append(f"- `{name.upper()}: usize = {int(arr)}` (scalar integer)") + parts.append( + f"- `{name.upper()}: usize = {int(arr)}` (scalar integer)" + ) parts.append("") # Validation parts.append("## Validation") - parts.append("Your generated code MUST produce these exact values (within float64 precision):\n") + parts.append( + "Your generated code MUST produce these exact values (within float64 precision):\n" + ) parts.append( "NOTE: BridgeStan computes log_density with jacobian=True, propto=True.\n" "This means the log density INCLUDES Jacobian adjustments for constrained parameters\n" @@ -335,7 +341,8 @@ def save_all(self, output_dir: str | Path): def _build_param_info( - param_names: list[str], unc_param_names: list[str], + param_names: list[str], + unc_param_names: list[str], ) -> list[StanParamInfo]: """Build parameter info by grouping related constrained/unconstrained names. 
@@ -344,6 +351,7 @@ def _build_param_info( - Vector: "theta.1", "theta.2" (constrained), "theta.1", "theta.2" (unc) - Constrained: "sigma" (constrained), "sigma" (unconstrained, log-transformed) """ + # Group by base name (strip .N suffixes) def base_name(name: str) -> str: return re.sub(r"\.\d+$", "", name) diff --git a/pymc_rust_compiler/stan_to_pymc.py b/pymc_rust_compiler/stan_to_pymc.py index 233604c..05391b4 100644 --- a/pymc_rust_compiler/stan_to_pymc.py +++ b/pymc_rust_compiler/stan_to_pymc.py @@ -10,7 +10,6 @@ import functools import json import os -import tempfile import time from dataclasses import dataclass, field from pathlib import Path @@ -141,9 +140,13 @@ class StanToPyMCResult: timings: dict[str, float] n_tool_calls: int = 0 conversation_turns: int = 0 - token_usage: dict[str, int] = field(default_factory=lambda: { - "input_tokens": 0, "output_tokens": 0, "total_tokens": 0, - }) + token_usage: dict[str, int] = field( + default_factory=lambda: { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ) @property def success(self) -> bool: @@ -192,7 +195,7 @@ def _build_system_prompt() -> str: for skill_name in ("stan_to_pymc", "pymc_optimization"): content = _load_skill(skill_name) if content: - prompt += f"\n\n{'='*60}\n{content}" + prompt += f"\n\n{'=' * 60}\n{content}" return prompt @@ -217,9 +220,7 @@ def _build_user_prompt( if arr.ndim == 0: parts.append(f"- `{name}`: scalar = {value}") else: - parts.append( - f"- `{name}`: shape={list(arr.shape)}, dtype={arr.dtype}" - ) + parts.append(f"- `{name}`: shape={list(arr.shape)}, dtype={arr.dtype}") parts.append("") # Parameter info @@ -302,13 +303,14 @@ def transpile_stan_to_pymc( # Build reference points reference_points = [ - {"point": ctx.initial_point.point, "logp": ctx.initial_point.logp, - "dlogp": ctx.initial_point.dlogp}, + { + "point": ctx.initial_point.point, + "logp": ctx.initial_point.logp, + "dlogp": ctx.initial_point.dlogp, + }, ] for pt in ctx.extra_points: - 
reference_points.append( - {"point": pt.point, "logp": pt.logp, "dlogp": pt.dlogp} - ) + reference_points.append({"point": pt.point, "logp": pt.logp, "dlogp": pt.dlogp}) if verbose: print( @@ -319,7 +321,10 @@ def transpile_stan_to_pymc( # Step 2: Build prompts system_prompt = _build_system_prompt() user_prompt = _build_user_prompt( - stan_code, data, reference_points, list(ctx.unc_param_names), + stan_code, + data, + reference_points, + list(ctx.unc_param_names), ) # Step 3: Agent loop @@ -384,11 +389,13 @@ def transpile_stan_to_pymc( elif block.type == "tool_use": state.tool_calls += 1 result = _execute_tool(block.name, block.input, state, verbose) - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + } + ) if state.validated: break @@ -421,7 +428,10 @@ def transpile_stan_to_pymc( def _execute_tool( - name: str, input_data: dict, state: _AgentState, verbose: bool, + name: str, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: """Execute a tool and return the result string.""" if name == "write_pymc_code": @@ -435,7 +445,9 @@ def _execute_tool( def _tool_write_pymc_code( - input_data: dict, state: _AgentState, verbose: bool, + input_data: dict, + state: _AgentState, + verbose: bool, ) -> str: code = input_data.get("code", "") if not code: @@ -481,8 +493,6 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: print(f" [validate_model] Model construction error: {e}") return f"Error building PyMC model: {type(e).__name__}: {e}" - import pymc as pm - # Get the logp function in unconstrained space try: logp_fn = model.compile_logp() @@ -498,19 +508,20 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: # sampling but causes logp comparison mismatches. 
_HALF_RV_OPS = {"HalfNormalRV", "HalfCauchyRV", "HalfStudentTRV", "HalfFlatRV"} n_half = sum( - 1 for rv in model.free_RVs - if type(rv.owner.op).__name__ in _HALF_RV_OPS + 1 for rv in model.free_RVs if type(rv.owner.op).__name__ in _HALF_RV_OPS ) half_logp_correction = n_half * np.log(2) if verbose and n_half > 0: - print(f" [validate_model] Found {n_half} Half* distribution(s), " - f"applying log(2) correction of {half_logp_correction:.4f}") + print( + f" [validate_model] Found {n_half} Half* distribution(s), " + f"applying log(2) correction of {half_logp_correction:.4f}" + ) # Map unconstrained point to PyMC's internal variable order # We need to understand how PyMC orders its unconstrained parameters # vs how BridgeStan does it try: - ip = model.initial_point() + model.initial_point() unc_var_names = [v.name for v in model.value_vars] except Exception as e: return f"Error getting model variable info: {type(e).__name__}: {e}" @@ -534,7 +545,9 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: # Build a point dict for PyMC: map unconstrained values to PyMC variables try: point_dict = _map_unc_point_to_pymc( - model, unc_point, state.unc_param_names, + model, + unc_point, + state.unc_param_names, ) except Exception as e: report_lines.append(f"{label}: ERROR mapping point: {e}") @@ -567,10 +580,7 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: diffs = [ref - pymc for _, ref, pymc in point_results] mean_diff = sum(diffs) / len(diffs) - correction_note = ( - f" (Half* corrected, {n_half} dists)" - if n_half > 0 else "" - ) + correction_note = f" (Half* corrected, {n_half} dists)" if n_half > 0 else "" for label, ref_logp, corrected_logp in point_results: diff = ref_logp - corrected_logp @@ -607,10 +617,7 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: label, ref_logp, corrected_logp = point_results[0] rel_err = abs(corrected_logp - ref_logp) / max(abs(ref_logp), 1.0) status = "OK" if rel_err <= 1e-2 
else "MISMATCH" - correction_note = ( - f" (Half* corrected, {n_half} dists)" - if n_half > 0 else "" - ) + correction_note = f" (Half* corrected, {n_half} dists)" if n_half > 0 else "" report_lines.append( f"{label}: logp BridgeStan={ref_logp:.6f} PyMC={corrected_logp:.6f}" f"{correction_note} rel_err={rel_err:.2e} [{status}]" @@ -636,8 +643,7 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: print(f" [validate_model] FAILED ({len(errors)} errors)") return ( f"VALIDATION FAILED ({len(errors)} errors):\n\n{report}\n\n" - f"Errors:\n" + "\n".join(errors) - + "\n\nHints:\n" + f"Errors:\n" + "\n".join(errors) + "\n\nHints:\n" "- Check that all prior distributions match (including parameter names)\n" "- Ensure Stan constraints map to correct PyMC distributions " "(e.g. real with normal prior → HalfNormal)\n" @@ -655,7 +661,9 @@ def _tool_validate_model(state: _AgentState, verbose: bool) -> str: def _map_unc_point_to_pymc( - model, unc_point: np.ndarray, stan_param_names: list[str], + model, + unc_point: np.ndarray, + stan_param_names: list[str], ) -> dict: """Map a BridgeStan unconstrained point to a PyMC point dict. 
@@ -685,10 +693,16 @@ def _map_unc_point_to_pymc( for var in pymc_vars: var_name = var.name # Strip PyMC transform suffixes - base_name = re.sub(r"_(log|logodds|interval|circular|ordered|simplex)__$", "", var_name) + base_name = re.sub( + r"_(log|logodds|interval|circular|ordered|simplex)__$", "", var_name + ) # Determine size of this variable - var_size = int(np.prod(var.type.shape) if hasattr(var.type, 'shape') and var.type.shape else 1) + var_size = int( + np.prod(var.type.shape) + if hasattr(var.type, "shape") and var.type.shape + else 1 + ) # Find matching Stan param group matched = False diff --git a/tests/test_compiler.py b/tests/test_compiler.py index bd4712b..1fc9ee8 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -6,9 +6,6 @@ from pathlib import Path from unittest.mock import MagicMock -import numpy as np -import pymc as pm -import pytest from pymc_rust_compiler.compiler import ( CompilationResult, @@ -20,10 +17,7 @@ _tool_write_rust_code, ) from pymc_rust_compiler.exporter import ( - ModelContext, - ParamInfo, RustModelExporter, - ValidationPoint, ) @@ -57,8 +51,12 @@ def test_failure_when_not_validated(self): def test_default_token_usage(self): r = CompilationResult( - rust_code="", logp_validated=False, validation_errors=[], - n_attempts=0, build_dir=None, timings={}, + rust_code="", + logp_validated=False, + validation_errors=[], + n_attempts=0, + build_dir=None, + timings={}, ) assert r.token_usage["input_tokens"] == 0 assert r.token_usage["output_tokens"] == 0 @@ -97,28 +95,32 @@ def _make_ctx(self, observed_data, covariate_data=None): return ctx def test_basic_observed_data(self): - ctx = self._make_ctx({ - "y": { - "shape": [3], - "dtype": "float64", - "n": 3, - "values": [1.0, 2.0, 3.0], + ctx = self._make_ctx( + { + "y": { + "shape": [3], + "dtype": "float64", + "n": 3, + "values": [1.0, 2.0, 3.0], + } } - }) + ) rs = _generate_data_rs(ctx) assert "Y_N: usize = 3" in rs assert "Y_DATA: &[f64]" in rs def 
test_multidimensional_data_flattened(self): - ctx = self._make_ctx({ - "y": { - "shape": [2, 3], - "dtype": "float64", - "n": 6, - "values": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + ctx = self._make_ctx( + { + "y": { + "shape": [2, 3], + "dtype": "float64", + "n": 6, + "values": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + } } - }) + ) rs = _generate_data_rs(ctx) assert "Y_N: usize = 6" in rs @@ -156,13 +158,15 @@ def test_no_values_key(self): assert "Y_DATA" not in rs def test_full_precision(self): - ctx = self._make_ctx({ - "y": { - "shape": [1], - "n": 1, - "values": [3.141592653589793], + ctx = self._make_ctx( + { + "y": { + "shape": [1], + "n": 1, + "values": [3.141592653589793], + } } - }) + ) rs = _generate_data_rs(ctx) # Should contain full precision representation @@ -263,7 +267,9 @@ def test_write_code(self): (build_path / "src").mkdir() state = _AgentState( - build_path=build_path, ctx=None, messages=[], + build_path=build_path, + ctx=None, + messages=[], ) code = "pub fn hello() {}" @@ -278,7 +284,9 @@ def test_write_empty_code(self): (build_path / "src").mkdir() state = _AgentState( - build_path=build_path, ctx=None, messages=[], + build_path=build_path, + ctx=None, + messages=[], ) result = _tool_write_rust_code({"code": ""}, state, verbose=False) @@ -298,7 +306,9 @@ def test_read_existing_file(self): (build_path / "src" / "data.rs").write_text("// data") state = _AgentState( - build_path=build_path, ctx=None, messages=[], + build_path=build_path, + ctx=None, + messages=[], ) result = _tool_read_file({"path": "src/data.rs"}, state, verbose=False) @@ -309,7 +319,9 @@ def test_read_missing_file(self): build_path = Path(tmpdir) state = _AgentState( - build_path=build_path, ctx=None, messages=[], + build_path=build_path, + ctx=None, + messages=[], ) result = _tool_read_file({"path": "nonexistent.rs"}, state, verbose=False) @@ -318,7 +330,9 @@ def test_read_missing_file(self): def test_read_empty_path(self): with tempfile.TemporaryDirectory() as tmpdir: state = 
_AgentState( - build_path=Path(tmpdir), ctx=None, messages=[], + build_path=Path(tmpdir), + ctx=None, + messages=[], ) result = _tool_read_file({"path": ""}, state, verbose=False) assert "Error" in result @@ -329,7 +343,9 @@ def test_read_truncation(self): (build_path / "big.txt").write_text("x" * 10000) state = _AgentState( - build_path=build_path, ctx=None, messages=[], + build_path=build_path, + ctx=None, + messages=[], ) result = _tool_read_file({"path": "big.txt"}, state, verbose=False) @@ -345,7 +361,9 @@ def test_read_truncation(self): class TestExecuteTool: def test_unknown_tool(self): state = _AgentState( - build_path=Path("/tmp"), ctx=None, messages=[], + build_path=Path("/tmp"), + ctx=None, + messages=[], ) result = _execute_tool("nonexistent_tool", {}, state, verbose=False) assert "Unknown tool" in result @@ -356,7 +374,9 @@ def test_dispatches_write(self): (build_path / "src").mkdir() state = _AgentState( - build_path=build_path, ctx=None, messages=[], + build_path=build_path, + ctx=None, + messages=[], ) result = _execute_tool( @@ -370,7 +390,9 @@ def test_dispatches_read(self): (build_path / "test.rs").write_text("// hello") state = _AgentState( - build_path=build_path, ctx=None, messages=[], + build_path=build_path, + ctx=None, + messages=[], ) result = _execute_tool( @@ -385,14 +407,22 @@ def test_dispatches_read(self): class TestToolDefinitions: - def test_tools_list_has_four_tools(self): + def test_tools_list_has_five_tools(self): from pymc_rust_compiler.compiler import TOOLS - assert len(TOOLS) == 4 + + assert len(TOOLS) == 5 tool_names = {t["name"] for t in TOOLS} - assert tool_names == {"write_rust_code", "cargo_build", "validate_logp", "read_file"} + assert tool_names == { + "write_rust_code", + "cargo_build", + "validate_logp", + "read_file", + "add_cargo_dependency", + } def test_system_prompt_exists(self): from pymc_rust_compiler.compiler import SYSTEM_PROMPT + assert len(SYSTEM_PROMPT) > 100 assert "CpuLogpFunc" in SYSTEM_PROMPT diff 
--git a/tests/test_exporter.py b/tests/test_exporter.py index d9e4726..cf237b3 100644 --- a/tests/test_exporter.py +++ b/tests/test_exporter.py @@ -7,11 +7,8 @@ from pathlib import Path import numpy as np -import pymc as pm -import pytest from pymc_rust_compiler.exporter import ( - ModelContext, ParamInfo, RustModelExporter, ValidationPoint, @@ -27,31 +24,46 @@ class TestParamInfo: def test_scalar_param(self): p = ParamInfo( - name="mu", value_var="mu", transform=None, - shape=[], unc_shape=[], size=1, + name="mu", + value_var="mu", + transform=None, + shape=[], + unc_shape=[], + size=1, ) assert p.is_scalar is True def test_vector_param(self): p = ParamInfo( - name="offset", value_var="offset", transform=None, - shape=[4], unc_shape=[4], size=4, + name="offset", + value_var="offset", + transform=None, + shape=[4], + unc_shape=[4], + size=4, ) assert p.is_scalar is False def test_log_transformed_param(self): p = ParamInfo( - name="sigma", value_var="sigma_log__", transform="LogTransform", - shape=[], unc_shape=[], size=1, + name="sigma", + value_var="sigma_log__", + transform="LogTransform", + shape=[], + unc_shape=[], + size=1, ) assert p.transform == "LogTransform" assert p.is_scalar is True def test_zerosum_param(self): p = ParamInfo( - name="effect", value_var="effect_zerosum__", + name="effect", + value_var="effect_zerosum__", transform="ZeroSumTransform", - shape=[6], unc_shape=[5], size=5, + shape=[6], + unc_shape=[5], + size=5, zerosum_axes=[0], ) assert p.zerosum_axes == [0] @@ -187,8 +199,7 @@ def test_different_seeds_differ(self, normal_model): # At least one extra point should differ assert any( - p1.point != p2.point - for p1, p2 in zip(ctx1.extra_points, ctx2.extra_points) + p1.point != p2.point for p1, p2 in zip(ctx1.extra_points, ctx2.extra_points) ) @@ -373,7 +384,7 @@ def test_export_returns_exporter(self, normal_model): def test_export_with_output_dir(self, normal_model): with tempfile.TemporaryDirectory() as tmpdir: - exporter = 
export_model(normal_model, output_dir=tmpdir) + export_model(normal_model, output_dir=tmpdir) assert (Path(tmpdir) / "codegen_prompt.txt").exists() def test_export_with_source_code(self, normal_model): diff --git a/tests/test_jax_pytorch.py b/tests/test_jax_pytorch.py index 39ab99f..6e3a9d5 100644 --- a/tests/test_jax_pytorch.py +++ b/tests/test_jax_pytorch.py @@ -9,10 +9,21 @@ import numpy as np import pytest +_has_jax = pytest.importorskip is not None # helper below +try: + import jax # noqa: F401 + + _has_jax = True +except ImportError: + _has_jax = False + +jax_required = pytest.mark.skipif(not _has_jax, reason="JAX not installed") + # ── JAX Exporter Tests ────────────────────────────────────────────────────── +@jax_required class TestJaxExporter: """Test JaxModelExporter extracts correct model context.""" @@ -67,7 +78,6 @@ def test_extract_validation_points(self, simple_model): assert "b" in vp0.grad_params def test_forward_output_matches(self, simple_model): - import jax.numpy as jnp from pymc_rust_compiler.jax_exporter import JaxModelExporter fn, params, x = simple_model @@ -106,7 +116,9 @@ def __init__(self): super().__init__() self.fc = nn.Linear(2, 3) with torch.no_grad(): - self.fc.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + self.fc.weight.copy_( + torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) self.fc.bias.copy_(torch.tensor([0.1, 0.2, 0.3])) def forward(self, x): @@ -166,14 +178,20 @@ class TestTranspilerTools: """Test the transpiler's tool execution logic without API calls.""" def test_write_code_syntax_check(self): - from pymc_rust_compiler.jax_pytorch_transpiler import _tool_write_code, _AgentState + from pymc_rust_compiler.jax_pytorch_transpiler import ( + _tool_write_code, + _AgentState, + ) from pymc_rust_compiler.jax_exporter import ModelContext state = _AgentState( direction="jax_to_pytorch", source_context=ModelContext( - source_framework="jax", source_code=None, - params=[], inputs=[], outputs=[], + 
source_framework="jax", + source_code=None, + params=[], + inputs=[], + outputs=[], validation_points=[], ), generated_code="", @@ -189,14 +207,20 @@ def test_write_code_syntax_check(self): assert "Syntax error" in result def test_validate_no_code(self): - from pymc_rust_compiler.jax_pytorch_transpiler import _tool_validate, _AgentState + from pymc_rust_compiler.jax_pytorch_transpiler import ( + _tool_validate, + _AgentState, + ) from pymc_rust_compiler.jax_exporter import ModelContext state = _AgentState( direction="jax_to_pytorch", source_context=ModelContext( - source_framework="jax", source_code=None, - params=[], inputs=[], outputs=[], + source_framework="jax", + source_code=None, + params=[], + inputs=[], + outputs=[], validation_points=[], ), generated_code="", @@ -205,11 +229,16 @@ def test_validate_no_code(self): result = _tool_validate(state, verbose=False) assert "no code" in result.lower() + @jax_required def test_validate_pytorch_correct_model(self): """Test that validation passes for a correctly transpiled model.""" import jax.numpy as jnp from pymc_rust_compiler.jax_exporter import JaxModelExporter - from pymc_rust_compiler.jax_pytorch_transpiler import _tool_write_code, _tool_validate, _AgentState + from pymc_rust_compiler.jax_pytorch_transpiler import ( + _tool_write_code, + _tool_validate, + _AgentState, + ) # Create a simple JAX model params = { @@ -231,7 +260,7 @@ def forward(params, x): ) # Write correct PyTorch code - pytorch_code = ''' + pytorch_code = """ import torch import torch.nn as nn import numpy as np @@ -247,19 +276,24 @@ def forward(self, x): return x @ self.w + self.b return Model(params) -''' +""" _tool_write_code({"code": pytorch_code}, state, verbose=False) result = _tool_validate(state, verbose=False) assert "PASSED" in result assert state.validated is True + @jax_required def test_validate_jax_correct_model(self): """Test that validation passes for a correctly transpiled JAX model.""" import torch import torch.nn as nn from 
pymc_rust_compiler.pytorch_exporter import PytorchModelExporter - from pymc_rust_compiler.jax_pytorch_transpiler import _tool_write_code, _tool_validate, _AgentState + from pymc_rust_compiler.jax_pytorch_transpiler import ( + _tool_write_code, + _tool_validate, + _AgentState, + ) # Create a simple PyTorch model class Linear(nn.Module): @@ -284,7 +318,7 @@ def forward(self, x): ) # Write correct JAX code - jax_code = ''' + jax_code = """ import jax import jax.numpy as jnp import numpy as np @@ -294,7 +328,7 @@ def init_params(param_data): def forward(params, x): return x @ params["w"] + params["b"] -''' +""" _tool_write_code({"code": jax_code}, state, verbose=False) result = _tool_validate(state, verbose=False) diff --git a/tests/test_pytorch_rust.py b/tests/test_pytorch_rust.py index c0696ea..0bc3ed9 100644 --- a/tests/test_pytorch_rust.py +++ b/tests/test_pytorch_rust.py @@ -31,7 +31,9 @@ def __init__(self): super().__init__() self.fc = nn.Linear(2, 3) with torch.no_grad(): - self.fc.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + self.fc.weight.copy_( + torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) self.fc.bias.copy_(torch.tensor([0.1, 0.2, 0.3])) def forward(self, x): @@ -92,12 +94,17 @@ def agent_state(self): import torch import torch.nn as nn from pymc_rust_compiler.pytorch_exporter import PytorchModelExporter - from pymc_rust_compiler.pytorch_rust_transpiler import _AgentState, _setup_rust_project + from pymc_rust_compiler.pytorch_rust_transpiler import ( + _AgentState, + _setup_rust_project, + ) class Linear(nn.Module): def __init__(self): super().__init__() - self.w = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)) + self.w = nn.Parameter( + torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + ) self.b = nn.Parameter(torch.tensor([0.1, 0.2], dtype=torch.float32)) def forward(self, x): @@ -122,7 +129,9 @@ def test_write_code(self, agent_state): from pymc_rust_compiler.pytorch_rust_transpiler 
import _tool_write_code result = _tool_write_code( - {"code": "use crate::data::*;\npub fn forward(input: &[f32]) -> Vec { vec![] }\npub fn forward_with_grad(input: &[f32], _p: &str) -> (Vec, Vec) { (vec![], vec![]) }\n"}, + { + "code": "use crate::data::*;\npub fn forward(input: &[f32]) -> Vec { vec![] }\npub fn forward_with_grad(input: &[f32], _p: &str) -> (Vec, Vec) { (vec![], vec![]) }\n" + }, agent_state, verbose=False, ) @@ -194,7 +203,9 @@ def simple_model_context(self): class Simple(nn.Module): def __init__(self): super().__init__() - self.w = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)) + self.w = nn.Parameter( + torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + ) self.b = nn.Parameter(torch.tensor([0.1, 0.2], dtype=torch.float32)) def forward(self, x): @@ -217,7 +228,6 @@ def test_context_has_gradients(self, simple_model_context): def test_output_matches_pytorch(self, simple_model_context): """Verify that the exporter captured the correct forward pass output.""" - import torch ctx = simple_model_context vp = ctx.validation_points[0] @@ -251,12 +261,17 @@ def model_and_state(self): import torch import torch.nn as nn from pymc_rust_compiler.pytorch_exporter import PytorchModelExporter - from pymc_rust_compiler.pytorch_rust_transpiler import _AgentState, _setup_rust_project + from pymc_rust_compiler.pytorch_rust_transpiler import ( + _AgentState, + _setup_rust_project, + ) class Simple(nn.Module): def __init__(self): super().__init__() - self.w = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)) + self.w = nn.Parameter( + torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + ) self.b = nn.Parameter(torch.tensor([0.1, 0.2], dtype=torch.float32)) def forward(self, x): @@ -280,7 +295,9 @@ def forward(self, x): def test_correct_rust_validates(self, model_and_state): """Write manually correct Rust code and verify it passes validation.""" from pymc_rust_compiler.pytorch_rust_transpiler import 
( - _tool_write_code, _tool_cargo_build, _tool_validate, + _tool_write_code, + _tool_cargo_build, + _tool_validate, ) try: @@ -291,7 +308,7 @@ def test_correct_rust_validates(self, model_and_state): state = model_and_state # Write correct Rust implementation of y = x @ w + b - rust_code = ''' + rust_code = """ use crate::data::*; /// Forward pass: y = x @ w + b @@ -341,7 +358,7 @@ def test_correct_rust_validates(self, model_and_state): _ => (output, vec![]) } } -''' +""" _tool_write_code({"code": rust_code}, state, verbose=False) build_result = _tool_cargo_build(state, verbose=False) @@ -377,7 +394,11 @@ def test_skill_has_backprop(self): from pymc_rust_compiler.pytorch_rust_transpiler import _load_skill skill = _load_skill("pytorch_to_rust") - assert "backward" in skill.lower() or "backprop" in skill.lower() or "gradient" in skill.lower() + assert ( + "backward" in skill.lower() + or "backprop" in skill.lower() + or "gradient" in skill.lower() + ) # ── Result Type Tests ──────────────────────────────────────────────────────── @@ -432,7 +453,6 @@ class TestPromptBuilding: """Test that user prompts are built correctly.""" def test_prompt_contains_model_info(self): - import torch import torch.nn as nn from pymc_rust_compiler.pytorch_exporter import PytorchModelExporter from pymc_rust_compiler.pytorch_rust_transpiler import _build_user_prompt