Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/goth-eval/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl Evaluator {
("det", PrimFn::Det), ("inv", PrimFn::Inv),
("diag", PrimFn::Diag), ("eye", PrimFn::Eye),
("solve", PrimFn::Solve), ("solveWith", PrimFn::SolveWith),
("eig", PrimFn::Eig), ("eigvecs", PrimFn::EigVecs),
];
for (name, prim) in prims { self.globals.borrow_mut().insert(name.to_string(), Value::Primitive(*prim)); }
// Register stream constants
Expand Down Expand Up @@ -648,7 +649,7 @@ fn prim_arity(prim: PrimFn) -> usize {
PrimFn::Flush | PrimFn::RawModeEnter | PrimFn::RawModeExit => 1, // Terminal control (take unit)
PrimFn::Lines | PrimFn::Words | PrimFn::Bytes => 1, // String splitting (unary)
PrimFn::Re | PrimFn::Im | PrimFn::Conj | PrimFn::Arg => 1, // Complex decomposition
PrimFn::Trace | PrimFn::Det | PrimFn::Inv | PrimFn::Diag | PrimFn::Eye => 1, // Matrix utilities
PrimFn::Trace | PrimFn::Det | PrimFn::Inv | PrimFn::Diag | PrimFn::Eye | PrimFn::Eig | PrimFn::EigVecs => 1, // Matrix utilities
PrimFn::Solve => 2, // Linear solve (default LU)
PrimFn::SolveWith => 3, // Linear solve with method string
PrimFn::WriteFile | PrimFn::ReadBytes | PrimFn::WriteBytes => 2, // Binary I/O takes 2 args
Expand Down
261 changes: 261 additions & 0 deletions crates/goth-eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1836,4 +1836,265 @@ mod tests {
let expr = Expr::app(Expr::app(Expr::app(Expr::name("solveWith"), a), b), method);
assert!(e.eval(&expr).is_err());
}

// ── Eigenvalue tests ──

fn assert_eigenvalue_approx(val: &Value, idx: usize, expected_re: f64, expected_im: f64, tol: f64, label: &str) {
match val {
Value::Tensor(t) => {
let elem = t.get_flat(idx).expect(&format!("{}: index {} out of bounds", label, idx));
match elem {
Value::Float(f) => {
assert!(expected_im.abs() < tol, "{}: expected complex but got Float({})", label, f.0);
assert!((f.0 - expected_re).abs() < tol, "{}: expected re={}, got {} (diff {})", label, expected_re, f.0, (f.0 - expected_re).abs());
}
Value::Complex(re, im) => {
assert!((re - expected_re).abs() < tol, "{}: expected re={}, got {} (diff {})", label, expected_re, re, (re - expected_re).abs());
assert!((im - expected_im).abs() < tol, "{}: expected im={}, got {} (diff {})", label, expected_im, im, (im - expected_im).abs());
}
other => panic!("{}: expected Float or Complex, got {:?}", label, other),
}
}
_ => panic!("{}: expected Tensor, got {:?}", label, val),
}
}

#[test]
fn test_eig_identity() {
let mut e = Evaluator::new();
let expr = Expr::app(Expr::name("eig"), Expr::app(Expr::name("eye"), Expr::int(3)));
let result = e.eval(&expr).unwrap();
if let Value::Tensor(t) = &result {
assert_eq!(t.shape, vec![3]);
// All eigenvalues should be 1.0
for i in 0..3 {
assert_eigenvalue_approx(&result, i, 1.0, 0.0, 1e-10, "eig(eye(3))");
}
// Should be Float tensor (all real)
assert!(matches!(t.data, TensorData::Float(_)), "identity eigenvalues should be Float tensor");
} else {
panic!("expected Tensor");
}
}

#[test]
fn test_eig_1x1() {
let mut e = Evaluator::new();
let mat = Expr::array(vec![Expr::array(vec![Expr::float(7.0)])]);
let expr = Expr::app(Expr::name("eig"), mat);
let result = e.eval(&expr).unwrap();
assert_eigenvalue_approx(&result, 0, 7.0, 0.0, 1e-10, "eig([[7]])");
}

#[test]
fn test_eig_symmetric_2x2() {
let mut e = Evaluator::new();
let expr = Expr::app(Expr::name("eig"), mat2x2(2.0, 1.0, 1.0, 2.0));
let result = e.eval(&expr).unwrap();
// Eigenvalues should be 3 and 1 (sorted descending)
assert_eigenvalue_approx(&result, 0, 3.0, 0.0, 1e-10, "eig sym 2x2 [0]");
assert_eigenvalue_approx(&result, 1, 1.0, 0.0, 1e-10, "eig sym 2x2 [1]");
}

#[test]
fn test_eig_diagonal() {
let mut e = Evaluator::new();
let expr = Expr::app(Expr::name("eig"),
Expr::app(Expr::name("diag"), Expr::array(vec![Expr::float(1.0), Expr::float(2.0), Expr::float(3.0)])));
let result = e.eval(&expr).unwrap();
// Sorted descending: 3, 2, 1
assert_eigenvalue_approx(&result, 0, 3.0, 0.0, 1e-10, "eig diag [0]");
assert_eigenvalue_approx(&result, 1, 2.0, 0.0, 1e-10, "eig diag [1]");
assert_eigenvalue_approx(&result, 2, 1.0, 0.0, 1e-10, "eig diag [2]");
}

#[test]
fn test_eig_rotation_complex() {
let mut e = Evaluator::new();
// [[0, -1], [1, 0]] has eigenvalues i and -i
let expr = Expr::app(Expr::name("eig"), mat2x2(0.0, -1.0, 1.0, 0.0));
let result = e.eval(&expr).unwrap();
if let Value::Tensor(t) = &result {
assert!(matches!(t.data, TensorData::Generic(_)), "rotation eigenvalues should be Generic (complex)");
}
// Eigenvalues are ±i, sorted by real part then imaginary descending
assert_eigenvalue_approx(&result, 0, 0.0, 1.0, 1e-10, "eig rot [0]");
assert_eigenvalue_approx(&result, 1, 0.0, -1.0, 1e-10, "eig rot [1]");
}

#[test]
fn test_eig_trace_invariant() {
// Sum of eigenvalues = trace(A)
let mut e = Evaluator::new();
let a = mat3x3([2.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 4.0]);
let trace_expr = Expr::app(Expr::name("trace"), a.clone());
let trace_val = e.eval(&trace_expr).unwrap().coerce_float().unwrap();

let eig_expr = Expr::app(Expr::name("eig"), a);
let eig_result = e.eval(&eig_expr).unwrap();
let mut eig_sum = 0.0;
if let Value::Tensor(t) = &eig_result {
for i in 0..t.shape[0] {
let v = t.get_flat(i).unwrap();
eig_sum += v.coerce_float().unwrap();
}
}
assert_approx(eig_sum, trace_val, 1e-10, "sum(eig) == trace");
}

#[test]
fn test_eig_det_invariant() {
// Product of eigenvalues = det(A)
let mut e = Evaluator::new();
let a = mat2x2(4.0, 2.0, 1.0, 3.0);
let det_expr = Expr::app(Expr::name("det"), a.clone());
let det_val = e.eval(&det_expr).unwrap().coerce_float().unwrap();

let eig_expr = Expr::app(Expr::name("eig"), a);
let eig_result = e.eval(&eig_expr).unwrap();
let mut eig_prod = 1.0;
if let Value::Tensor(t) = &eig_result {
for i in 0..t.shape[0] {
let v = t.get_flat(i).unwrap();
eig_prod *= v.coerce_float().unwrap();
}
}
assert_approx(eig_prod, det_val, 1e-10, "prod(eig) == det");
}

#[test]
fn test_eig_non_square_error() {
let mut e = Evaluator::new();
let mat = Expr::array(vec![
Expr::array(vec![Expr::float(1.0), Expr::float(2.0), Expr::float(3.0)]),
Expr::array(vec![Expr::float(4.0), Expr::float(5.0), Expr::float(6.0)]),
]);
assert!(e.eval(&Expr::app(Expr::name("eig"), mat)).is_err());
}

// ── Eigenvector tests ──

#[test]
fn test_eigvecs_returns_tuple() {
let mut e = Evaluator::new();
let expr = Expr::app(Expr::name("eigvecs"), Expr::app(Expr::name("eye"), Expr::int(2)));
let result = e.eval(&expr).unwrap();
if let Value::Tuple(vs) = &result {
assert_eq!(vs.len(), 2, "eigvecs should return a 2-tuple");
} else {
panic!("expected Tuple, got {:?}", result);
}
}

#[test]
fn test_eigvecs_identity() {
let mut e = Evaluator::new();
let expr = Expr::app(Expr::name("eigvecs"), Expr::app(Expr::name("eye"), Expr::int(3)));
let result = e.eval(&expr).unwrap();
if let Value::Tuple(vs) = &result {
// All eigenvalues should be 1.0
for i in 0..3 {
assert_eigenvalue_approx(&vs[0], i, 1.0, 0.0, 1e-10, "eigvecs(eye(3)) eval");
}
} else {
panic!("expected Tuple");
}
}

#[test]
fn test_eigvecs_av_equals_lambda_v() {
// Fundamental property: A*v = λ*v for each eigenvalue/eigenvector pair
let mut e = Evaluator::new();
let a_data = [4.0, 1.0, 2.0, 3.0];
let a_expr = mat2x2(a_data[0], a_data[1], a_data[2], a_data[3]);
let expr = Expr::app(Expr::name("eigvecs"), a_expr);
let result = e.eval(&expr).unwrap();
if let Value::Tuple(vs) = &result {
let evals = &vs[0];
let evecs = &vs[1];
if let (Value::Tensor(eval_t), Value::Tensor(evec_t)) = (evals, evecs) {
let n = 2;
for col in 0..n {
let lambda = eval_t.get_flat(col).unwrap().coerce_float().unwrap();
// Extract eigenvector column
let mut v = vec![0.0; n];
for row in 0..n { v[row] = evec_t.get(&[row, col]).unwrap().coerce_float().unwrap(); }
// Compute A*v
let mut av = vec![0.0; n];
for i in 0..n {
for j in 0..n {
av[i] += a_data[i * n + j] * v[j];
}
}
// Check A*v = λ*v
for i in 0..n {
assert_approx(av[i], lambda * v[i], 1e-8, &format!("A*v = λ*v, col={}, row={}", col, i));
}
}
}
} else {
panic!("expected Tuple");
}
}

#[test]
fn test_eigvecs_symmetric_orthogonal() {
// Eigenvectors of symmetric matrix should be orthogonal
let mut e = Evaluator::new();
let a_expr = mat2x2(2.0, 1.0, 1.0, 2.0);
let expr = Expr::app(Expr::name("eigvecs"), a_expr);
let result = e.eval(&expr).unwrap();
if let Value::Tuple(vs) = &result {
if let Value::Tensor(evec_t) = &vs[1] {
let n = 2;
// Get columns
let mut v0 = vec![0.0; n];
let mut v1 = vec![0.0; n];
for i in 0..n {
v0[i] = evec_t.get(&[i, 0]).unwrap().coerce_float().unwrap();
v1[i] = evec_t.get(&[i, 1]).unwrap().coerce_float().unwrap();
}
let dot: f64 = v0.iter().zip(v1.iter()).map(|(a, b)| a * b).sum();
assert!(dot.abs() < 1e-8, "eigenvectors should be orthogonal, dot = {}", dot);
}
} else {
panic!("expected Tuple");
}
}

#[test]
fn test_eigvecs_diagonal() {
let mut e = Evaluator::new();
let a_expr = Expr::app(Expr::name("diag"), Expr::array(vec![Expr::float(5.0), Expr::float(3.0)]));
let expr = Expr::app(Expr::name("eigvecs"), a_expr);
let result = e.eval(&expr).unwrap();
if let Value::Tuple(vs) = &result {
// Eigenvalues should be 5 and 3 (sorted descending)
assert_eigenvalue_approx(&vs[0], 0, 5.0, 0.0, 1e-10, "eigvecs diag eval[0]");
assert_eigenvalue_approx(&vs[0], 1, 3.0, 0.0, 1e-10, "eigvecs diag eval[1]");
// Eigenvectors should be unit vectors
if let Value::Tensor(evec_t) = &vs[1] {
for col in 0..2 {
let mut norm_sq = 0.0;
for row in 0..2 {
let v = evec_t.get(&[row, col]).unwrap().coerce_float().unwrap();
norm_sq += v * v;
}
assert_approx(norm_sq, 1.0, 1e-10, &format!("eigvec col {} should be unit", col));
}
}
} else {
panic!("expected Tuple");
}
}

#[test]
fn test_eigvecs_non_square_error() {
let mut e = Evaluator::new();
let mat = Expr::array(vec![
Expr::array(vec![Expr::float(1.0), Expr::float(2.0), Expr::float(3.0)]),
Expr::array(vec![Expr::float(4.0), Expr::float(5.0), Expr::float(6.0)]),
]);
assert!(e.eval(&Expr::app(Expr::name("eigvecs"), mat)).is_err());
}
}
Loading
Loading