diff --git a/Cargo.toml b/Cargo.toml index 946ebc3..47700be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,5 +2,6 @@ members=[ "libludi", "interpreter", + "ludic" ] resolver="2" diff --git a/examples/basic_arithmetic.ludi b/examples/basic_arithmetic.ludi new file mode 100644 index 0000000..c341379 --- /dev/null +++ b/examples/basic_arithmetic.ludi @@ -0,0 +1,5 @@ +fn main() { + let a = 12; + let b = 4; + a+b +} diff --git a/interpreter/src/array.rs b/interpreter/src/array.rs index 982b62b..9ca2a58 100644 --- a/interpreter/src/array.rs +++ b/interpreter/src/array.rs @@ -1,6 +1,6 @@ use crate::datatypes::{ArrayType, AtomicType}; use libludi::ast::FrameNode; -use libludi::err::{Error, LudiError, Result}; +use libludi::err::{Error, LudiError, RuntimeErrorKind, Result}; use libludi::shape::ambassador_impl_ArrayProps; use libludi::shape::ambassador_impl_ShapeOps; use libludi::shape::{ArrayProps, Shape, ShapeOps}; @@ -131,7 +131,7 @@ impl IntoIterator for Array { } } impl TryInto for Array { - type Error = Error; + type Error = anyhow::Error; fn try_into(self) -> Result { if self.rank() != 1 { Err(anyhow::anyhow!( @@ -149,7 +149,7 @@ impl TryInto for Array { } impl TryInto for Array { - type Error = Error; + type Error = anyhow::Error; fn try_into(self) -> Result { if self.rank() != 1 { Err(anyhow::anyhow!( @@ -208,7 +208,7 @@ impl FromIterator for Result { newaxis +=1; Ok(array_base.data_raw().into_iter()) } else { - Err(Error::runtime_err("Frame error: mismatched types or shape in frame")) + Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "Frame error: mismatched types or shape in frame").into()) }).flatten_ok().collect::>>()?, shape: Shape::new(&[newaxis]).concat(rest_of_shape), } diff --git a/interpreter/src/datatypes.rs b/interpreter/src/datatypes.rs index 4eb89d5..2a20235 100644 --- a/interpreter/src/datatypes.rs +++ b/interpreter/src/datatypes.rs @@ -4,7 +4,7 @@ use crate::interpret::DynamicEnv; use libludi::ast::{Arg, FnDefNode, FuncSignature}; use libludi::atomic::Literal; -use libludi::err::{Error, LudiError, Result}; +use libludi::err::{Error, LudiError, Result, RuntimeErrorKind}; use libludi::shape::ambassador_impl_ArrayProps; use libludi::shape::ambassador_impl_ShapeOps; use libludi::shape::{ArrayProps, Shape, ShapeOps}; @@ -35,7 +35,7 @@ pub trait Data // BinaryOp + pub enum DataType { Array(ArrayType), Atomic(AtomicType), - Unit + Unit, } impl Data for DataType {} @@ -80,7 +80,6 @@ pub enum ArrayType { Fn(Array), } - #[repr(u8)] #[derive(derive_more::Display, Eq, Debug, Copy, Clone, PartialEq)] pub enum DataTypeTag { @@ -123,8 +122,9 @@ impl FromIterator for Result { } else { //TODO: add better error information Err(Error::runtime_err( + RuntimeErrorKind::InterpretError, "Frame error: found non conforming value in frame", - )) + ).into()) } }) .collect::>>()??, @@ -134,13 +134,17 @@ impl FromIterator for Result { Ok(a.upgrade()) } else { Err(Error::runtime_err( + RuntimeErrorKind::InterpretError, "Frame error: found non conforming value in frame", - )) + ).into()) } }) .collect::>>()??, Some(DataType::Unit) => return Ok(DataType::Unit), - None => Err(Error::runtime_err("Frame error: empty frame"))?, + None => Err(Error::runtime_err( + RuntimeErrorKind::InterpretError, + "Frame error: empty frame" + ))?, } })) } @@ -153,7 +157,7 @@ impl FromIterator for Result { // } impl FromStr for DataTypeTag { - type Err = Error; + type Err = anyhow::Error; fn from_str(s: &str) -> Result { match s { "u8" => Ok(DataTypeTag::UInt8), @@ -172,15 +176,21 @@ impl FromStr for DataTypeTag { "char" => Ok(DataTypeTag::Character), "bool" => Ok(DataTypeTag::Boolean), "box" => Err(Error::runtime_err( + RuntimeErrorKind::InterpretError, "Box type not supported in function signature", - )), + ).into()), "fn" => Err(Error::runtime_err( + RuntimeErrorKind::InterpretError, "Fn type not suppored in function signature", - )), + ).into()), "()" => Err(Error::runtime_err( + RuntimeErrorKind::InterpretError, "Unit type not supported in function signature", - )), - _ => Err(Error::runtime_err("not a known builtin type")), + ).into()), + _ => Err(Error::runtime_err( + RuntimeErrorKind::InterpretError, + "not a known builtin type", + ).into()), } } } @@ -192,12 +202,13 @@ pub struct TypeSignature(pub DataTypeTag, pub Shape); pub struct OptionalTypeSignature(pub Option, pub Shape); impl TryFrom for TypeSignature { - type Error = Error; + type Error = anyhow::Error; fn try_from(value: OptionalTypeSignature) -> Result { Ok(TypeSignature( - value - .0 - .ok_or(Error::runtime_err("expected explicit type annotations"))?, + value.0.ok_or(Error::runtime_err( + RuntimeErrorKind::InterpretError, + "expected explicit type annotations", + ))?, value.1, )) } diff --git a/interpreter/src/function.rs b/interpreter/src/function.rs index 0f60ff3..324f69c 100644 --- a/interpreter/src/function.rs +++ b/interpreter/src/function.rs @@ -41,7 +41,7 @@ impl Callable for FunctionData { )); } e.push_with( - zip(self.params().into_iter(), arguments.into_iter()).map(|(Arg(name, ty), value)| { + zip(self.params().into_iter(), arguments.into_iter()).map(|(Arg(name, _ty), value)| { //TODO: check type agreement (name.clone(), value.into()) }), diff --git a/interpreter/src/interpret.rs b/interpreter/src/interpret.rs index e7655bf..3e1ee94 100644 --- a/interpreter/src/interpret.rs +++ b/interpreter/src/interpret.rs @@ -10,7 +10,7 @@ use itertools::Itertools; use libludi::{ ast::*, env::Env, - err::{Error, LudiError, Result}, + err::{Error, LudiError, Result, RuntimeErrorKind}, shape::{ArrayProps, Shape, ShapeOps}, token::Token, types::{self, Atom, AtomicDataType, PrimitiveFuncType, Type}, @@ -177,9 +177,9 @@ impl Interpret for FnCallNode { // interpreter should prevent branches diverge... } } - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "if expression expected boolean condition", - )), + ).into()), } } PrimitiveFuncType::Gt => { @@ -195,9 +195,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Float(a)), DataType::Atomic(AtomicType::Float(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a > b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: '>' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::GtEq => { @@ -213,9 +213,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Float(a)), DataType::Atomic(AtomicType::Float(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a >= b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: '>=' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::Lt => { @@ -231,9 +231,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Float(a)), DataType::Atomic(AtomicType::Float(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a < b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: '<' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::LtEq => { @@ -249,9 +249,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Float(a)), DataType::Atomic(AtomicType::Float(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a <= b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: '<=' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::Ne => { @@ -271,9 +271,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Character(a)), DataType::Atomic(AtomicType::Character(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a != b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: '!=' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::Eq => { @@ -293,9 +293,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Character(a)), DataType::Atomic(AtomicType::Character(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a == b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: '==' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::And => { @@ -308,9 +308,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Boolean(a)), DataType::Atomic(AtomicType::Boolean(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a && b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: 'and' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::Or => { @@ -323,9 +323,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Boolean(a)), DataType::Atomic(AtomicType::Boolean(b)), ) => Ok(DataType::Atomic(AtomicType::Boolean(a || b))), - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: 'or' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::Not => { @@ -335,9 +335,9 @@ impl Interpret for FnCallNode { DataType::Atomic(AtomicType::Boolean(a)) => { Ok(DataType::Atomic(AtomicType::Boolean(!a))) } - _ => Err(Error::runtime_err( + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "type error: 'not' op is not defined for between these types", - )), + ).into()), } } PrimitiveFuncType::Neg => { @@ -353,7 +353,7 @@ impl Interpret for FnCallNode { crate::datatypes::ArrayType::Int(Iota::iota(i.try_into()?)), )), DataType::Array(ArrayType::Int(_a_i)) => todo!(), - _ => Err(Error::runtime_err("error: Iota expects integer")), + _ => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "error: Iota expects integer").into()), } } PrimitiveFuncType::Reshape => { diff --git a/interpreter/src/main.rs b/interpreter/src/main.rs index 5d1b47c..050c916 100644 --- a/interpreter/src/main.rs +++ b/interpreter/src/main.rs @@ -25,7 +25,7 @@ use array::{Array, Iota}; use datatypes::{ArrayType, AtomicType, DataType}; use libludi::err::Result; use libludi::shape::{ArrayProps, Shape, ShapeOps}; -use libludi::lex::lex; +use libludi::lex::Lex; use libludi::parser::expression; use ops::*; @@ -73,7 +73,7 @@ fn test_frame() -> Result<()>{ fn basic_unary() -> Result<()> { // assert_eq!(8, std::mem::size_of::()); let mut e = DynamicEnv::new(); - let expr = expression(&mut lex("let a = 2 in -a"))?; + let expr = expression(&mut Lex::lex("let a = 2 in -a"))?; // dbg!(&expr); let r = expr.interpret(&mut e)?; assert_eq!(format!("{}",r), "-2"); @@ -83,7 +83,7 @@ fn basic_unary() -> Result<()> { #[test] fn basic_arithmetic1() -> Result<()> { let mut e = DynamicEnv::new(); - let r = expression(&mut lex("2+2"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("2+2"))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "4"); Ok(()) } @@ -91,59 +91,59 @@ fn basic_arithmetic1() -> Result<()> { #[test] fn basic_arithmetic2() -> Result<()> { let mut e = DynamicEnv::new(); - let r = expression(&mut lex("1./2."))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("1./2."))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "0.5"); Ok(()) } #[test] fn basic_arithmetic3() -> Result<()> { let mut e = DynamicEnv::new(); - let r = expression(&mut lex("1.0/(2.0+3.0)"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("1.0/(2.0+3.0)"))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "0.2"); Ok(()) } #[test] fn basic_arithmetic4() ->Result<()> { let mut e = DynamicEnv::new(); - let r = expression(&mut lex("5.0 * 1.0/(2.0+3.0)"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("5.0 * 1.0/(2.0+3.0)"))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "1"); Ok(()) } #[test] fn basic_arithmetic5() -> Result<()> { let mut e = DynamicEnv::new(); - let r = expression(&mut lex("2.0 + 0.3"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("2.0 + 0.3"))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "2.3"); Ok(()) } #[test] fn basic_arithmetic6() -> Result<()> { let mut e = DynamicEnv::new(); - let r = expression(&mut lex("10.0 + 17.0 - 12.2/4.0"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("10.0 + 17.0 - 12.2/4.0"))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "23.95"); Ok(()) } #[test] fn assignment1() -> Result<()> { let mut e = DynamicEnv::new(); - let _ = expression(&mut lex("let a = 2.0+0.3;"))?.interpret(&mut e)?; - let r = expression(&mut lex("a"))?.interpret(&mut e)?; + let _ = expression(&mut Lex::lex("let a = 2.0+0.3;"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("a"))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "2.3"); - let r = expression(&mut lex("let a = a + a; a"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("let a = a + a; a"))?.interpret(&mut e)?; assert_eq!(format!("{}",r), "4.6"); Ok(()) } #[test] fn assignment_err1() -> Result<()> { let mut e = DynamicEnv::new(); - let _ = expression(&mut lex("notathing"))?.interpret(&mut e).expect_err("failed to detect unbound symbol!"); + let _ = expression(&mut Lex::lex("notathing"))?.interpret(&mut e).expect_err("failed to detect unbound symbol!"); Ok(()) } #[test] fn automap() -> Result<()> { let mut e = DynamicEnv::new(); - let r = expression(&mut lex("reshape(iota(8), [4 2]) * [0 2]"))?.interpret(&mut e)?; + let r = expression(&mut Lex::lex("reshape(iota(8), [4 2]) * [0 2]"))?.interpret(&mut e)?; match r { DataType::Array(ArrayType::Int(int_array)) => { assert_eq!(int_array.shape_slice(), &[4, 2]); diff --git a/interpreter/src/ops.rs b/interpreter/src/ops.rs index 9fb1014..13b7d89 100644 --- a/interpreter/src/ops.rs +++ b/interpreter/src/ops.rs @@ -5,9 +5,9 @@ use crate::array::Array; use libludi::shape::ArrayProps; // use crate::atomic::{AtomicType}; +use libludi::err::{Error, LudiError, RuntimeErrorKind, Result}; use crate::datatypes::{ArrayType, AtomicType, Data, DataType}; use itertools::izip; -use libludi::err::{Error, LudiError, Result}; use num::complex::ComplexFloat; use num::Complex; @@ -63,7 +63,7 @@ macro_rules! delegate_binops_data { match self { DataType::Array(a) => Ok(DataType::Array(a.$fname()?)), DataType::Atomic(a) => Ok(DataType::Atomic(a.$fname()?)), - DataType::Unit => Err(Error::runtime_err("operating on unit type is not allowed")) + DataType::Unit => Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "operating on unit type is not allowed").into()) } } } @@ -147,7 +147,7 @@ macro_rules! delegate_binops_numbertype { (AtomicType::Index(a), AtomicType::Index(b)) => Ok(AtomicType::Index(a.$fname(b)?)), (AtomicType::Float(a), AtomicType::Float(b)) => Ok(AtomicType::Float(a.$fname(b)?)), (AtomicType::Complex(a), AtomicType::Complex(b)) => Ok(AtomicType::Complex(a.$fname(b)?)), - _ => {Err(Error::msg(format!("incompatible types for {}", stringify!($fname))))} + _ => {Err(Error::runtime_err(RuntimeErrorKind::InterpretError, &format!("incompatible types for {}", stringify!($fname))).into())} } } } @@ -205,7 +205,7 @@ macro_rules! delegate_binops_std_array { } ) ), - None => return Err(Error::msg("shape error"))}, + None => return Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "shape error").into())}, } } } @@ -272,7 +272,7 @@ impl Neg for AtomicType { Int(a) => Int(-a), Float(a) => Float(-a), // Complex(a) => Complex(num::Complex::new(-a.re(), a.im())), //subtracts the real - _ => return Err(Error::msg("unsupported op")), + _ => return Err(Error::runtime_err(RuntimeErrorKind::InterpretError, "unsupported type for negate").into()), }) } } diff --git a/interpreter/src/run.rs b/interpreter/src/run.rs index 3ad1364..abf300e 100644 --- a/interpreter/src/run.rs +++ b/interpreter/src/run.rs @@ -10,7 +10,7 @@ use libludi::ast::ParseTree; use libludi::ast::Stmt; use libludi::env::Env; use libludi::err::{Error, Result}; -use libludi::lex::lex; +use libludi::lex::Lex; use libludi::parser::Parser; use libludi::parser::{expression, statement}; use rustyline::error::ReadlineError; @@ -52,7 +52,7 @@ pub fn repl() -> Result<()> { pub fn run(source: &str, e: &mut DynamicEnv) -> Result { let dump_ast: bool = env::var("DUMP_AST").is_ok(); let dump_tokens: bool = env::var("DUMP_TOKENS").is_ok(); - let mut tokens = lex(source); + let mut tokens = Lex::lex(source); if dump_tokens { tokens.clone().for_each(|t| { println!("{:?}", t.token); diff --git a/libludi/Cargo.toml b/libludi/Cargo.toml index ec2d9f4..5393f46 100644 --- a/libludi/Cargo.toml +++ b/libludi/Cargo.toml @@ -13,8 +13,9 @@ num = "0.4.3" # Complex numbers & big integers smallvec = "1.13.2" # stack-allocated Vec thiserror = "1.0.61" # derive macro for std::Error unicode-segmentation = "1.11.0" # unicode character for lexer -melior = { version = "0.18.5", features = ["ods-dialects"], path="../../melior/melior" } # MLIR +melior = { version="0.19.1", features = ["ods-dialects"], path="../../melior/melior" } # MLIR paste = "1.0.15" # identifier concat in macros +mlir-sys = "0.3.0" # serde = { version = "1.0.210", features=["derive"] } # bytemuck = "1.18.0" diff --git a/libludi/src/ast.rs b/libludi/src/ast.rs index 2e983a1..e7a3ce2 100644 --- a/libludi/src/ast.rs +++ b/libludi/src/ast.rs @@ -4,7 +4,7 @@ use std::ops::Deref; use crate::atomic::Literal; use crate::env::Name; -use crate::err::{Error, ErrorKind, LudiError, Result}; +use crate::err::{Error, LudiError, Result}; use crate::shape::{ArrayProps, Shape, ShapeOps}; use crate::token::TokenData; use crate::types::{PrimitiveFuncType, Type}; diff --git a/libludi/src/atomic.rs b/libludi/src/atomic.rs index 63f306f..55d25e1 100644 --- a/libludi/src/atomic.rs +++ b/libludi/src/atomic.rs @@ -1,4 +1,4 @@ -use crate::err::{Error, ErrorKind, LudiError, Result}; +use crate::err::{Error, LudiError, ParseErrorKind, Result}; use crate::token::{Location, Token, TokenData}; use std::fmt::{write, Debug, Display}; use std::hash::Hash; @@ -13,7 +13,7 @@ pub enum Literal { } impl TryFrom for Literal { - type Error = Error; + type Error = anyhow::Error; fn try_from(value: TokenData) -> Result { let loc = value.loc; Ok(match value.token { @@ -21,7 +21,13 @@ impl TryFrom for Literal { Token::FLOAT_LITERAL(atom) => Literal::Float { loc, atom }, Token::TRUE => Literal::Bool { loc, atom: true }, Token::FALSE => Literal::Bool { loc, atom: false }, - _ => Err(Error::at_token(value, "Expected atomic value"))?, + _ => { + return Err( + Error::parse_err(ParseErrorKind::Literal, "expected literal value") + .at_token(value) + .into(), + ) + } }) } } diff --git a/libludi/src/codegen/mod.rs b/libludi/src/codegen/mod.rs index 839cede..fbbcc11 100644 --- a/libludi/src/codegen/mod.rs +++ b/libludi/src/codegen/mod.rs @@ -1,4 +1,5 @@ -pub mod writer; +mod writer; mod types; +pub use writer::*; // pub use dialect; diff --git a/libludi/src/codegen/types.rs b/libludi/src/codegen/types.rs index a7496b2..d3885eb 100644 --- a/libludi/src/codegen/types.rs +++ b/libludi/src/codegen/types.rs @@ -1,3 +1,5 @@ +use melior::ir::TypeLike; + use super::writer::{self, MLIRGen}; use crate::err::Result; use crate::shape::ArrayProps; @@ -83,8 +85,8 @@ impl<'c> writer::MLIRGen<'c, melior::ir::Type<'c>> for types::AtomicDataType { types::AtomicDataType::BFloat16 => melior::ir::r#type::Type::bfloat16(context), types::AtomicDataType::Float32 => melior::ir::r#type::Type::float32(context), types::AtomicDataType::Float64 => melior::ir::r#type::Type::float64(context), - types::AtomicDataType::Complex => { - melior::ir::r#type::Type::complex(melior::ir::r#type::Type::float32(context)) + types::AtomicDataType::Complex => unsafe { + melior::ir::r#type::Type::from_raw(mlir_sys::mlirComplexTypeGet(melior::ir::r#type::Type::float32(context).to_raw())) } types::AtomicDataType::UInt8 => { melior::ir::r#type::IntegerType::unsigned(context, 8).into() diff --git a/libludi/src/codegen/writer.rs b/libludi/src/codegen/writer.rs index f9ab0bb..4594d3e 100644 --- a/libludi/src/codegen/writer.rs +++ b/libludi/src/codegen/writer.rs @@ -1,4 +1,5 @@ use itertools::Itertools; +use melior::ir::operation::OperationPrintingFlags; use melior::ir::r#type::FunctionType; use melior::ir::Location; @@ -9,20 +10,35 @@ use crate::types; use crate::types::typed_ast; use crate::types::GetType; +pub trait WriteMLIR { + fn write(&self) -> String; +} + +impl WriteMLIR for typed_ast::TypedExpr { + fn write(&self) -> String { + let writer = CodeWriter::new(); + let module = writer.write_ast(self) + .expect("error: failed to convert AST to MLIR"); + module.as_operation().to_string_with_flags(OperationPrintingFlags::default()) + .expect("error: failed to convert MLIR Module to String") + } +} + // an object which contains an MLIR context and produces a module pub struct CodeWriter { // manages a single thread of MLIR core context pub context: melior::Context, } impl CodeWriter { - pub fn new(// what are the arguments + pub fn new( + // ne argz? ) -> Self { let context = load_builtin_dialects(); // debug locations in code CodeWriter { context } } - pub fn write_ast(&self, ast: typed_ast::TypedExpr) -> Result { + pub fn write_ast(&self, ast: &typed_ast::TypedExpr) -> Result { let builder = melior::dialect::ods::builtin::ModuleOperationBuilder::new( &self.context, melior::ir::Location::unknown(&self.context), diff --git a/libludi/src/env.rs b/libludi/src/env.rs index 70bf33e..1efd269 100644 --- a/libludi/src/env.rs +++ b/libludi/src/env.rs @@ -1,8 +1,8 @@ -use crate::err::{Error, ErrorKind, LudiError, Result}; +use crate::err::{Error, LudiError, ParseErrorKind, Result}; use crate::token::{Location, Token}; use crate::token::{Token::IDENTIFIER, TokenData}; use std::cell::{OnceCell, RefCell}; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; use std::rc::Rc; use std::str::FromStr; @@ -24,10 +24,13 @@ pub struct Name { pub loc: Location, } +// type Map = HashMap; +type Map = BTreeMap; + pub type EnvLink = Option>>; pub struct EnvMap { // A scoped symbol table - table: HashMap, + table: Map, // The outer scope prev: EnvLink, } @@ -35,8 +38,7 @@ pub struct EnvMap { pub struct Env { head: EnvLink, } -impl Env -{ +impl Env { pub fn new() -> Self { Self { head: Some(EnvMap::new(None).into()), @@ -72,7 +74,7 @@ impl Env pub fn get(&self, ident: &Name) -> Result<&S> { match &self.head { Some(table) => table.get(ident), - None => Err(Error::msg(format!("Unknown symbol name: {}", ident.name))), + None => Err(anyhow::anyhow!("Unknown symbol name: {}", ident.name)), } } pub fn put(&mut self, ident: Name, val: S) -> Option { @@ -92,11 +94,10 @@ impl Drop for Env { } } -impl EnvMap -{ +impl EnvMap { pub fn new(prev: EnvLink) -> Self { Self { - table: HashMap::new(), + table: Map::new(), prev, } } @@ -105,7 +106,7 @@ impl EnvMap I: Iterator, { Self { - table: HashMap::from_iter(list), + table: Map::from_iter(list), prev, } } @@ -118,7 +119,7 @@ impl EnvMap } else if let Some(p) = &self.prev { p.get(ident) } else { - Err(Error::msg(format!("Unknown symbol name: {}", ident.name))) + Err(anyhow::anyhow!("Unknown symbol name: {}", ident.name)) } } } @@ -141,8 +142,18 @@ impl std::hash::Hash for Name { self.name.hash(state) } } +impl std::cmp::Ord for Name { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.name.cmp(&other.name) + } +} +impl std::cmp::PartialOrd for Name { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.name.cmp(&other.name)) + } +} impl TryFrom for Name { - type Error = Error; + type Error = anyhow::Error; fn try_from(value: TokenData) -> Result { if let Token::IDENTIFIER(name) = value.token { Ok(Name { @@ -150,15 +161,13 @@ impl TryFrom for Name { loc: value.loc, }) } else { - Err(Error::at_token( - value, - "Trying to use non-identifier token as name", - )) + Err(Error::parse_err(ParseErrorKind::Ident, "Trying to use non-identifier token as name") + .at_token(value).into()) } } } impl FromStr for Name { - type Err = crate::err::Error; + type Err = anyhow::Error; fn from_str(s: &str) -> Result { Ok(Self { name: s.into(), diff --git a/libludi/src/err.rs b/libludi/src/err.rs index da5754f..db13a38 100644 --- a/libludi/src/err.rs +++ b/libludi/src/err.rs @@ -1,3 +1,5 @@ +use std::fmt::write; + use crate::{ // allocator::BlockError, env::Name, @@ -6,72 +8,124 @@ use crate::{ use anyhow; use thiserror; -// For errors related to the parsing of grammer -#[derive(thiserror::Error, Debug, Clone)] -pub enum ParseError {} -// For errors related to semantics -#[derive(thiserror::Error, Debug, Clone)] -pub enum CompileError {} -// For errors related to code generation & backends -#[derive(thiserror::Error, Debug, Clone)] -pub enum CodeGenError {} -// For errors related to runtime & interpreter -#[derive(thiserror::Error, Debug, Clone)] -pub enum RuntimeError {} +// re-export of anyhow result type +pub type Result = anyhow::Result; + +#[derive(thiserror::Error, derive_more::Display, Debug, Clone)] +pub enum LexErrorKind { + TokenError, +} +#[derive(thiserror::Error, derive_more::Display, Debug, Clone)] +pub enum ParseErrorKind { + Expr, + Literal, + Ident, + Frame, + LetExpr, + FnDef, + FnCall, +} +#[derive(thiserror::Error, derive_more::Display, Debug, Clone)] +pub enum CompileErrorKind { + TailCallOptError, +} +#[derive(thiserror::Error, derive_more::Display, Debug, Clone)] +pub enum TypeErrorKind { + TypeMismatch, + Unknown, + Unsupported, +} #[derive(thiserror::Error, derive_more::Display, Debug, Clone)] -pub enum ErrorKind { - LexErr, - ParseErr, // additional information - CompileErr, //additional information - CodeGenErr, // additional information - RuntimeErr, // additional information +pub enum CodeGenErrorKind { + MLIRError, +} +#[derive(thiserror::Error, derive_more::Display, Debug, Clone)] +pub enum RuntimeErrorKind { + InterpretError, } -// impl std::fmt::Display for LangError { -// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -// match self { -// Self::LexErr(msg) => write!(f, "Lexical Error: {}", msg), -// Self::ParseErr(msg) => write!(f, "Parsing Error: {}", msg), -// Self::CompileErr(msg) => write!(f, "Compile Error: {}", msg), -// Self::RuntimeErr(msg) => write!(f, "Runtime Error: {}", msg), -// // Self::AllocErr(msg) => write!(f, "Allocation Error"), -// } -// } -// } +#[derive(thiserror::Error, Debug, Clone)] +pub enum LudiError { + LexError(LexErrorKind), + CompileError(CompileErrorKind), + ParseError(ParseErrorKind), + TypeError(TypeErrorKind), + CodeGenError(CodeGenErrorKind), + RuntimeError(RuntimeErrorKind), + // Add additional information +} -// re-export of anyhow result type -pub type Result = anyhow::Result; -pub type Error = anyhow::Error; -pub trait LudiError { - fn with_name(name: Name, msg: &str) -> Self; - fn at_token(tok: TokenData, msg: &str) -> Self; - fn runtime_err(msg: &'static str) -> Self; - fn parse_err(msg: &'static str) -> Self; - fn compile_err(msg: &'static str) -> Self; - fn codegen_err(msg: &'static str) -> Self; +impl std::fmt::Display for LudiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::LexError(kind) => write!(f, "Lexical Error: {}", kind), + Self::ParseError(kind) => write!(f, "Parsing Error: {}", kind), + Self::CompileError(kind) => write!(f, "Compile Error: {}", kind), + Self::TypeError(kind) => write!(f, "Type Error: {}", kind), + Self::CodeGenError(kind) => write!(f, "Codegen Error: {}", kind), + Self::RuntimeError(kind) => write!(f, "Runtime Error: {}", kind), + // Self::AllocErr(msg) => write!(f, "Allocation Error"), + } + } +} + + +#[derive(thiserror::Error, Debug, Clone)] +pub struct Error { + error: LudiError, + msg: String, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} -- {}", self.error, self.msg) + } } -impl LudiError for Error { - fn with_name(name: Name, msg: &str) -> Self { - Error::msg(format!("Error at {}: {}", name, msg)) + +impl Error { + // pub fn with_name(name: Name, msg: &str) -> Self { + // Error::msg(format!("Error at {}: {}", name, msg)) + // } + pub fn at_token(mut self, tok: TokenData) -> Self { + self.msg = format!("Error line {} at \"{}\" -- {}", tok.loc, tok.token, self.msg); + return self; + } + pub fn lex_err(kind: LexErrorKind, msg: &str) -> Self { + Self { + error: LudiError::LexError(kind), + msg: msg.to_string(), + } } - fn at_token(tok: TokenData, msg: &str) -> Self { - Error::msg(format!( - "Error line {}: unexpected {:?}; {}", - tok.loc, tok.token, msg - )) + pub fn runtime_err(kind: RuntimeErrorKind, msg: &str) -> Self { + Self { + error: LudiError::RuntimeError(kind), + msg: msg.to_string(), + } } - fn runtime_err(msg: &'static str) -> Self { - Error::new(ErrorKind::RuntimeErr).context(msg) + pub fn parse_err(kind: ParseErrorKind, msg: &str) -> Self { + Self { + error: LudiError::ParseError(kind), + msg: msg.to_string(), + } } - fn parse_err(msg: &'static str) -> Self { - Error::new(ErrorKind::ParseErr).context(msg) + pub fn compile_err(kind: CompileErrorKind, msg: &str) -> Self { + Self { + error: LudiError::CompileError(kind), + msg: msg.to_string(), + } } - fn compile_err(msg: &'static str) -> Self { - Error::new(ErrorKind::CompileErr).context(msg) + pub fn type_err(kind: TypeErrorKind, msg: &str) -> Self { + Self { + error: LudiError::TypeError(kind), + msg: msg.to_string(), + } } - fn codegen_err(msg: &'static str) -> Self { - Error::new(ErrorKind::CodeGenErr).context(msg) + pub fn codegen_err(kind: CodeGenErrorKind, msg: &str) -> Self { + Self { + error: LudiError::CodeGenError(kind), + msg: msg.to_string(), + } } } diff --git a/libludi/src/lex.rs b/libludi/src/lex.rs index e501352..d90876b 100644 --- a/libludi/src/lex.rs +++ b/libludi/src/lex.rs @@ -25,8 +25,16 @@ use unicode_segmentation::{Graphemes, UnicodeSegmentation}; use Token::*; pub type Lexer<'s> = Peekable>; -pub fn lex<'s>(s: &'s str) -> Lexer { - TokenStream::new(s).peekable() + +pub trait Lex { + fn lex(&self) -> Lexer; +} + + +impl Lex for str { + fn lex(&self) -> Lexer { + TokenStream::new(self).peekable() + } } #[derive(Clone)] diff --git a/libludi/src/lib.rs b/libludi/src/lib.rs index 4f0b8fb..e93df8b 100644 --- a/libludi/src/lib.rs +++ b/libludi/src/lib.rs @@ -17,6 +17,5 @@ pub mod ast; pub mod atomic; pub mod env; pub mod types; -pub mod pipeline; pub mod codegen; pub mod optimize; diff --git a/libludi/src/optimize/normalize.rs b/libludi/src/optimize/normalize.rs index cc0559a..c412af1 100644 --- a/libludi/src/optimize/normalize.rs +++ b/libludi/src/optimize/normalize.rs @@ -1,6 +1,5 @@ -use crate::ast::{self}; -use crate::ast::{define_ast, define_constructors, define_enum, define_nodes}; -use crate::data::DataType; +use crate::ast::{ast, define_constructors, define_enum, define_nodes}; +use crate::type::DataType; use crate::env::Name; use crate::err::{LangError, Result}; use crate::parse_err; diff --git a/libludi/src/parser.rs b/libludi/src/parser.rs index f79b70a..278c8c7 100644 --- a/libludi/src/parser.rs +++ b/libludi/src/parser.rs @@ -4,8 +4,8 @@ use crate::{ ast::*, atomic::Literal, env::Name, - err::{Error, LudiError, Result}, - lex::{lex, Lexer}, + err::{Error, LudiError, ParseErrorKind, Result}, + lex::Lexer, shape::{ArrayProps, Shape, ShapeVec}, token::{Location, Token, TokenData}, types::{Arr, Array, Atom, PrimitiveFuncType, Type}, @@ -28,15 +28,12 @@ impl Parser for Lexer<'_> { macro_rules! parse_failure { ($tokens:ident, $msg:expr) => { if let Some(bad_tok) = $tokens.peek() { - Error::at_token(bad_tok.clone(), $msg) + Error::parse_err(ParseErrorKind::Expr, $msg).at_token(bad_tok.clone()) } else { - Error::at_token( - TokenData { - token: EOF, - loc: Location { line: 0 }, - }, - $msg, - ) + Error::parse_err(ParseErrorKind::Expr, $msg).at_token(TokenData { + token: EOF, + loc: Location { line: 0 }, + }) } }; } @@ -59,12 +56,17 @@ macro_rules! expect_next { }; } +pub fn program(tokens: &mut Lexer) -> Result> { + (0..) + .map_while(|_| tokens.peek().is_some().then(|| statement(tokens))) + .collect() +} + pub fn statement(tokens: &mut Lexer) -> Result { if tokens.peek().is_some() { if match_next!(tokens, PRINT).is_some() { Ok(Stmt::Print({ let expression = expression(tokens)?; - expect_next!(tokens, SEMICOLON)?; PrintNode { expression }.into() })) // } else if match_next!(tokens, OPEN_BRACE).is_some() { @@ -194,9 +196,11 @@ fn typesignature(tokens: &mut Lexer) -> Result { token: INTEGER_LITERAL(n_str), .. }) => Some( - n_str - .parse::() - .or(Err(Error::parse_err("shape expects an unsigned int"))), + n_str.parse::().or(Err(Error::parse_err( + ParseErrorKind::FnDef, + "shape expects an unsigned int", + ) + .into())), ), _ => None, }) @@ -377,7 +381,12 @@ fn fncall(tokens: &mut Lexer) -> Result { if match_next!(tokens, CLOSE_PAREN).is_none() { loop { if args.len() > 255 { - return Err(Error::at_token(tokens.next().unwrap(), "Functions with >255 arguments are not allowed")) + return Err(Error::parse_err( + ParseErrorKind::FnCall, + "Functions with >255 arguments are not allowed", + ) + .at_token(tokens.next().unwrap()) + .into()); } let a = expression(tokens)?; args.push(a); diff --git a/libludi/src/token.rs b/libludi/src/token.rs index df05945..7f76973 100644 --- a/libludi/src/token.rs +++ b/libludi/src/token.rs @@ -1,6 +1,6 @@ use std::hash::Hash; -use crate::err::{ErrorKind, Result}; +use crate::err::{LudiError, Result}; #[derive(derive_more::Display, Debug, Clone, PartialEq, Eq)] pub enum Token { diff --git a/libludi/src/types.rs b/libludi/src/types.rs index 79eb7f3..a64d9fd 100644 --- a/libludi/src/types.rs +++ b/libludi/src/types.rs @@ -1,11 +1,10 @@ -use anyhow::Error; - +use crate::err::{Error, ParseErrorKind, TypeErrorKind}; use crate::token::TokenData; use crate::{ ast::{Expr, FuncSignature}, atomic::Literal, env::{Env, Name}, - err::{ErrorKind, LudiError, Result}, + err::{LudiError, Result}, shape::Shape, }; use std::{fmt::Display, rc::Rc, str::FromStr}; @@ -147,7 +146,7 @@ impl Type { } impl FromStr for AtomicDataType { - type Err = crate::err::Error; + type Err = anyhow::Error; fn from_str(s: &str) -> Result { match s { "u8" => Ok(AtomicDataType::UInt8), @@ -165,13 +164,19 @@ impl FromStr for AtomicDataType { "complex" => Ok(AtomicDataType::Complex), "char" => Ok(AtomicDataType::Character), "bool" => Ok(AtomicDataType::Boolean), - "()" => Err(Error::parse_err( + "()" => Err(Error::type_err(TypeErrorKind::Unsupported, "Unit type not supported in function signature", - )), - _ => Err(Error::parse_err("not a known builtin type")), + ).into()), + _ => Err(Error::type_err(crate::err::TypeErrorKind::Unknown,"not a known builtin type").into()), } } } +impl TryFrom<&str> for AtomicDataType { + type Error = anyhow::Error; + fn try_from(value: &str) -> Result { + FromStr::from_str(value) + } +} impl Into for Atom { fn into(self) -> Array { @@ -183,7 +188,7 @@ impl Into for Atom { } impl TryFrom<&str> for Atom { - type Error = Error; + type Error = anyhow::Error; fn try_from(value: &str) -> Result { use std::str::FromStr; if let Ok(atomic_type) = AtomicDataType::from_str(value) { @@ -196,7 +201,7 @@ impl TryFrom<&str> for Atom { } impl TryFrom<&str> for PrimitiveFuncType { - type Error = crate::err::Error; + type Error = anyhow::Error; fn try_from(value: &str) -> crate::err::Result { use crate::token::Token; match value { @@ -212,7 +217,7 @@ impl TryFrom<&str> for PrimitiveFuncType { "iota" => Ok(Self::Iota), "slice" => Ok(Self::Slice), "scatter" => Ok(Self::Scatter), - _ => Err(Error::parse_err("tried to parse a primitive function")), + _ => Err(Error::parse_err(ParseErrorKind::FnCall, "tried to parse a primitive function").into()), } } } @@ -356,7 +361,7 @@ pub mod typed_ast { } } pub mod typecheck { - use crate::ast; + use crate::{ast, err::ParseErrorKind}; use super::*; use itertools::Itertools; @@ -410,12 +415,10 @@ pub mod typecheck { Type::Array(Array::ArrayRef(_array_ref)) => todo!(), } } else { - // return type_err!() - panic!() + Err(Error::type_err(TypeErrorKind::TypeMismatch, &format!("Found non-conforming type in frame; expected {}", t1)).into()) } } else { - // return type_err!() - panic!() + Err(Error::type_err(TypeErrorKind::Unsupported, "empty frame").into()) } } Self::Term(node) => { @@ -428,7 +431,7 @@ pub mod typecheck { let region = match node.region { Some(body) => body.type_check(table)?, None => { - return Err(Error::compile_err("dangling let body not allowed here")) + return Err(Error::parse_err(ParseErrorKind::LetExpr, "dangling let body not allowed here").into()) } }; // No restrictions on region type Ok(typed_ast::let_node( @@ -439,7 +442,7 @@ pub mod typecheck { )) } Self::FnDef(node) => { - // TODO: IMPORTANT: Some kind of pi or sigma expression? + // TODO: IMPORTANT: Some kind of pi or sigma? let body = node.body.type_check(table)?; if body.get_type() != node.signature.ret[0] { return Err(anyhow::anyhow!( @@ -467,30 +470,48 @@ pub mod typecheck { Err(anyhow::anyhow!("type error - not a function!")) } } - ast::Callee::Primitive(prim) => todo!(), + ast::Callee::Primitive(prim) => match prim { + PrimitiveFuncType::Add => todo!(), + PrimitiveFuncType::Sub => todo!(), + PrimitiveFuncType::Mul => todo!(), + PrimitiveFuncType::Div => todo!(), + PrimitiveFuncType::Mod => todo!(), + PrimitiveFuncType::Neg => todo!(), + PrimitiveFuncType::Inv => todo!(), + PrimitiveFuncType::Eq => todo!(), + PrimitiveFuncType::Ne => todo!(), + PrimitiveFuncType::Gt => todo!(), + PrimitiveFuncType::GtEq => todo!(), + PrimitiveFuncType::Lt => todo!(), + PrimitiveFuncType::LtEq => todo!(), + PrimitiveFuncType::Or => todo!(), + PrimitiveFuncType::And => todo!(), + PrimitiveFuncType::Not => todo!(), + PrimitiveFuncType::If => todo!(), + PrimitiveFuncType::Reshape => todo!(), + PrimitiveFuncType::Reduce => todo!(), + PrimitiveFuncType::Scan => todo!(), + PrimitiveFuncType::Fold => todo!(), + PrimitiveFuncType::Trace => todo!(), + PrimitiveFuncType::Reverse => todo!(), + PrimitiveFuncType::Filter => todo!(), + PrimitiveFuncType::Append => todo!(), + PrimitiveFuncType::Rotate => todo!(), + PrimitiveFuncType::Iota => todo!(), + PrimitiveFuncType::Slice => todo!(), + PrimitiveFuncType::Scatter => todo!(), + PrimitiveFuncType::IntToFloat => todo!(), + PrimitiveFuncType::IntToBool => todo!(), + PrimitiveFuncType::BoolToInt => todo!(), + PrimitiveFuncType::FloatToInt => todo!(), + }, } } } } } - - // impl TypeCheck for PrimitiveFuncType { - // fn type_check(self, _table: &mut TypeEnv) -> Result { - // let ty = Type::Atom(Atom::Func(match self { - // Self::Add => Func { - // parameters: - // }.into(), - // // Self::Sub => (), - // // Self::Mul => (), - // // Self::Div => (), - // _ => todo!() - // })); - // Ok(ty) - // } - // } - impl TryInto for FuncSignature { - type Error = Error; + type Error = anyhow::Error; fn try_into(self) -> Result { let parameters: Vec = self .args diff --git a/libludi/tests/codegen_test.rs b/libludi/tests/codegen_test.rs index 9f2847c..800976c 100644 --- a/libludi/tests/codegen_test.rs +++ b/libludi/tests/codegen_test.rs @@ -1,8 +1,9 @@ use libludi::{ - codegen::writer::{load_builtin_dialects, CodeWriter}, - lex::lex, + codegen::{load_builtin_dialects, CodeWriter}, + lex::Lex, parser::expression, types::{typecheck::TypeCheck, TypeEnv}, + err::Result, }; use melior::{ dialect::{arith, func}, @@ -14,32 +15,40 @@ use melior::{ }, }; -fn verify_codegen(program: &str) -> anyhow::Result<()> { - let expr = expression(&mut lex(program))?.type_check(&mut TypeEnv::new())?; +fn verify_codegen(program: &str) -> Result { + let expr = expression(&mut program.lex())?.type_check(&mut TypeEnv::new())?; let writer = CodeWriter::new(); - let module = writer.write_ast(expr)?; + let module = writer.write_ast(&expr)?; println!( "{}", module .as_operation() .to_string_with_flags(OperationPrintingFlags::new())? ); - assert!(module.as_operation().verify()); + Ok(module.as_operation().verify()) +} + +#[test] +fn return_constant() -> anyhow::Result<()> { + assert!(verify_codegen("fn main() -> i64 { 1 }")?); Ok(()) } #[test] -fn atom_literal() -> anyhow::Result<()> { - verify_codegen("1") +fn identity_func() -> anyhow::Result<()> { + assert!(verify_codegen("fn id(a) { a }")?); + Ok(()) } + #[test] fn fn_simple() -> anyhow::Result<()> { - verify_codegen( + assert!(verify_codegen( "fn add(x, y) { x + y }" - ) + )?); + Ok(()) } #[test] @@ -104,7 +113,6 @@ fn melior_example_simple() -> anyhow::Result<()> { .as_operation() .to_string_with_flags(OperationPrintingFlags::new())? ); - // assert!(module.as_operation().verify()); - panic!(); + assert!(module.as_operation().verify()); Ok(()) } diff --git a/libludi/tests/parser_test.rs b/libludi/tests/parser_test.rs index 01cbde0..d189b4c 100644 --- a/libludi/tests/parser_test.rs +++ b/libludi/tests/parser_test.rs @@ -1,7 +1,7 @@ use libludi::ast::*; use libludi::atomic::Literal; use libludi::env::Name; -use libludi::lex::lex; +use libludi::lex::Lex; use libludi::parser::*; use libludi::shape::Shape; use libludi::token::*; @@ -28,13 +28,13 @@ fn atom_float(lf: &str) -> Expr { fn scan_arithmetic_tokens() { let src = "+ - * /"; let src2 = "+-*/"; - let s = lex(src); + let s = src.lex(); use libludi::token::Token::*; assert_eq!( s.map(|tok| tok.token).collect::>(), vec![PLUS, MINUS, STAR, SLASH] ); - let s = lex(src2); + let s = src2.lex(); assert_eq!( s.map(|tok| tok.token).collect::>(), vec![PLUS, MINUS, STAR, SLASH] @@ -48,21 +48,21 @@ fn scan_number_literal() { let src3 = "2.0 + 0.3"; use libludi::token::Token::*; assert_eq!( - lex(&src).next().unwrap(), + src.lex().next().unwrap(), TokenData { token: INTEGER_LITERAL("12".into()), loc: Location { line: 1 } } ); assert_eq!( - lex(&src2).next().unwrap(), + src2.lex().next().unwrap(), TokenData { token: FLOAT_LITERAL("4712.08".into()), loc: Location { line: 1 } } ); assert_eq!( - lex(&src3).map(|tok| tok.token).collect::>(), + src3.lex().map(|tok| tok.token).collect::>(), vec![ FLOAT_LITERAL("2.0".into()), PLUS, @@ -75,7 +75,7 @@ fn scan_number_literal() { fn scan_string_literal() { let src = "\"Hello,\" + \" world!\""; use libludi::token::Token::*; - let toks: Vec = lex(src).map(|d| d.token).collect::>(); + let toks: Vec = src.lex().map(|d| d.token).collect::>(); assert_eq!( toks.as_slice(), &[ @@ -88,7 +88,7 @@ fn scan_string_literal() { #[test] fn let_basic() -> anyhow::Result<()> { - let expr = expression(&mut lex("let a = 10;"))?; + let expr = expression(&mut "let a = 10;".lex())?; assert_eq!( expr, let_node( @@ -105,7 +105,7 @@ fn let_basic() -> anyhow::Result<()> { #[test] fn let_with_body() -> anyhow::Result<()> { - let expr = expression(&mut lex("let a = 2 in a+2"))?; + let expr = expression(&mut "let a = 2 in a+2".lex())?; assert_eq!( expr, Expr::Let( @@ -141,9 +141,7 @@ fn let_with_body() -> anyhow::Result<()> { } #[test] fn let_with_body_complex() -> anyhow::Result<()> { - let expr = expression(&mut lex( - "let a = 2 in let b = 4 in let c = a + b in foo(a, b, c)", - ))?; + let expr = expression(&mut "let a = 2 in let b = 4 in let c = a + b in foo(a, b, c)".lex())?; assert_eq!( expr, Expr::Let( @@ -252,7 +250,7 @@ fn let_with_body_complex() -> anyhow::Result<()> { fn scan_stmts() { let src = "print 2+2;"; use libludi::token::Token::*; - let toks: Vec = lex(src).map(|d| d.token).collect::>(); + let toks: Vec = src.lex().map(|d| d.token).collect::>(); assert_eq!( toks.as_slice(), &[ @@ -269,8 +267,8 @@ fn scan_stmts() { fn binary_expr1() -> anyhow::Result<()> { use itertools::Itertools; let src = "5 < 3"; - dbg!(lex(src).into_iter().collect_vec()); - let s: Expr = expression(&mut lex(src))?; + dbg!(src.lex().into_iter().collect_vec()); + let s: Expr = expression(&mut src.lex())?; let s_test = Expr::FnCall( FnCallNode { callee: Callee::Primitive(PrimitiveFuncType::Lt), @@ -286,7 +284,7 @@ fn binary_expr1() -> anyhow::Result<()> { #[test] fn binary_expr2() -> anyhow::Result<()> { use Expr::*; - let s = expression(&mut lex("75.4 + 1.006"))?; + let s = expression(&mut "75.4 + 1.006".lex())?; let s_test = FnCall( FnCallNode { @@ -302,7 +300,7 @@ fn binary_expr2() -> anyhow::Result<()> { #[test] fn test_binary_operation() -> anyhow::Result<()> { use std::str::FromStr; - let prg = expression(&mut lex("a + b * c - d"))?; + let prg = expression(&mut "a + b * c - d".lex())?; assert_eq!( prg, Expr::FnCall( @@ -360,7 +358,7 @@ fn test_binary_operation() -> anyhow::Result<()> { #[test] fn lambda_expr() -> anyhow::Result<()> { use std::str::FromStr; - let prg = expression(&mut lex("|a[3], b[3]| -> [3] { a + b }"))?; + let prg = expression(&mut "|a[3], b[3]| -> [3] { a + b }".lex())?; assert_eq!( prg, Expr::FnDef( @@ -416,7 +414,7 @@ fn lambda_expr() -> anyhow::Result<()> { #[test] fn fndef_complex_body() -> anyhow::Result<()> { use std::str::FromStr; - let prg = expression(&mut lex("fn foo(x[5], y[5]) -> [5] { (x + y) * 2 }"))?; + let prg = expression(&mut "fn foo(x[5], y[5]) -> [5] { (x + y) * 2 }".lex())?; assert_eq!( prg, Expr::Let( @@ -497,7 +495,7 @@ fn fndef_complex_body() -> anyhow::Result<()> { #[test] fn fncall() -> anyhow::Result<()> { use std::str::FromStr; - let prg = expression(&mut lex("foo(1, 2)"))?; + let prg = expression(&mut "foo(1, 2)".lex())?; assert_eq!( prg, Expr::FnCall( @@ -537,7 +535,7 @@ fn fncall() -> anyhow::Result<()> { #[test] fn nested_fncall() -> anyhow::Result<()> { use std::str::FromStr; - let prg = expression(&mut lex("foo(bar(1), baz(2))"))?; + let prg = expression(&mut "foo(bar(1), baz(2))".lex())?; assert_eq!( prg, Expr::FnCall( @@ -598,9 +596,7 @@ fn nested_fncall() -> anyhow::Result<()> { } #[test] fn curried_fncalls() -> anyhow::Result<()> { - - - let prg = expression(&mut lex("foo()()()"))?; + let prg = expression(&mut "foo()()()".lex())?; assert_eq!( prg, @@ -616,15 +612,15 @@ fn curried_fncalls() -> anyhow::Result<()> { } .into() )), - args: vec![] + args: vec![] } .into() )), - args: vec![] + args: vec![] } .into() )), - args: vec![] + args: vec![] } .into() ) @@ -632,11 +628,9 @@ fn curried_fncalls() -> anyhow::Result<()> { Ok(()) } - #[test] fn parse_array1() -> anyhow::Result<()> { - let mut tokens = lex("[ 1 2 3 2 1 ]"); - let expr = expression(&mut tokens)?; + let expr = expression(&mut "[ 1 2 3 2 1 ]".lex())?; assert_eq!( expr, @@ -653,7 +647,7 @@ fn parse_array1() -> anyhow::Result<()> { #[test] fn frame1() -> anyhow::Result<()> { - let expr = expression(&mut lex("let a = 1; let b = 2; [ a b ]"))?; + let expr = expression(&mut "let a = 1; let b = 2; [ a b ]".lex())?; assert_eq!( expr, let_node( @@ -696,11 +690,14 @@ fn frame1() -> anyhow::Result<()> { // ), #[test] fn diff_square() -> anyhow::Result<()> { - let expr = expression(&mut lex(" + let expr = expression( + &mut " fn diff_square(x, y) { x*x - y*y } - "))?; + " + .lex(), + )?; assert_eq!( expr, Expr::Let( @@ -803,11 +800,11 @@ fn diff_square() -> anyhow::Result<()> { #[test] fn diff_square_typesignatures() -> anyhow::Result<()> { - let expr = expression(&mut lex(" + let expr = expression(&mut " fn diff_square(x[u32], y[u32]) -> [u32] { x*x - y*y } - "))?; + ".lex())?; assert_eq!( expr, Expr::Let( @@ -916,7 +913,7 @@ fn diff_square_typesignatures() -> anyhow::Result<()> { #[test] fn condition1() -> anyhow::Result<()> { - let expr = expression(&mut lex("if a>b { a } else { b }"))?; + let expr = expression(&mut "if a>b { a } else { b }".lex())?; assert_eq!( expr, fn_call_node( @@ -936,7 +933,7 @@ fn condition1() -> anyhow::Result<()> { #[test] fn condition2() -> anyhow::Result<()> { - let expr = expression(&mut lex("if a>b {a} else { if b>c {b} else {c} }"))?; + let expr = expression(&mut "if a>b {a} else { if b>c {b} else {c} }".lex())?; assert_eq!( expr, fn_call_node( diff --git a/libludi/tests/types_test.rs b/libludi/tests/types_test.rs index 48709f1..d21c43b 100644 --- a/libludi/tests/types_test.rs +++ b/libludi/tests/types_test.rs @@ -1,4 +1,4 @@ -use libludi::{lex::lex, parser::expression, shape::Shape, types::*}; +use libludi::{lex::Lex, parser::expression, shape::Shape, types::*}; use pretty_assertions::assert_eq; use typecheck::TypeCheck; @@ -8,13 +8,13 @@ fn basic_types() -> anyhow::Result<()> { let prg2 = "true"; let prg3 = "1.0"; let mut table = TypeEnv::new(); - let t_expr1 = expression(&mut lex(prg1))? + let t_expr1 = expression(&mut prg1.lex())? .type_check(&mut table)? .get_type(); - let t_expr2 = expression(&mut lex(prg2))? + let t_expr2 = expression(&mut prg2.lex())? .type_check(&mut table)? .get_type(); - let t_expr3 = expression(&mut lex(prg3))? + let t_expr3 = expression(&mut prg3.lex())? .type_check(&mut table)? .get_type(); assert_eq!(Type::Atom(Atom::Literal(AtomicDataType::Int64)), t_expr1); @@ -27,8 +27,8 @@ fn basic_types() -> anyhow::Result<()> { fn letexpr_types() -> anyhow::Result<()> { let prg = "let a = true in a"; let mut table = TypeEnv::new(); - // let expr = expression(&mut lex(prg))?; - let t_expr = expression(&mut lex(prg))? + // let expr = expression(&mut prg))?; + let t_expr = expression(&mut prg.lex())? .type_check(&mut table)? .get_type(); assert_eq!(Type::Atom(Atom::Literal(AtomicDataType::Boolean)), t_expr); @@ -38,13 +38,13 @@ fn letexpr_types() -> anyhow::Result<()> { #[test] fn frame_type() -> anyhow::Result<()> { - let expr1 = expression(&mut lex("[[1 2] [3 4]]"))?; - let expr2 = expression(&mut lex(" + let expr1 = expression(&mut "[[1 2] [3 4]]".lex())?; + let expr2 = expression(&mut " let a = true in let b = false in [[a b] [b a]] - "))?; + ".lex())?; let ty1 = expr1.type_check(&mut TypeEnv::new())?.get_type(); let ty2 = expr2.type_check(&mut TypeEnv::new())?.get_type(); assert_eq!( @@ -66,11 +66,11 @@ fn frame_type() -> anyhow::Result<()> { #[test] fn fn_def_type() -> anyhow::Result<()> { - let expr = expression(&mut lex(" + let expr = expression(&mut " fn diff_square(x[u32], y[u32]) -> [u32] { x*x - y*y } - "))?; + ".lex())?; let ty = expr.type_check(&mut TypeEnv::new())?.get_type(); // assert_eq!( // ty, diff --git a/ludic/Cargo.toml b/ludic/Cargo.toml new file mode 100644 index 0000000..2008794 --- /dev/null +++ b/ludic/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "ludic" +version = "0.1.0" +edition = "2021" + +[dependencies] +clap = "4.5.21" +libludi = { version = "0.1.0", path = "../libludi" } diff --git a/ludic/src/cli.rs b/ludic/src/cli.rs new file mode 100644 index 0000000..ab76832 --- /dev/null +++ b/ludic/src/cli.rs @@ -0,0 +1,9 @@ +use crate::pipeline::{BasicCompiler, Pipeline}; + +pub fn cli() { + let args: Vec = std::env::args().collect(); + assert!(args.len() == 2); + let file = std::fs::read_to_string(args[1].clone()).expect("no such file"); + let output = BasicCompiler::apply(file).expect("failed"); + println!("{}", output); +} diff --git a/ludic/src/main.rs b/ludic/src/main.rs new file mode 100644 index 0000000..a817b03 --- /dev/null +++ b/ludic/src/main.rs @@ -0,0 +1,6 @@ +mod pipeline; +mod cli; + +fn main() { + cli::cli() +} diff --git a/ludic/src/pipeline.rs b/ludic/src/pipeline.rs new file mode 100644 index 0000000..5333716 --- /dev/null +++ b/ludic/src/pipeline.rs @@ -0,0 +1,23 @@ +use libludi::lex::Lex; +use libludi::parser::Parser; +use libludi::codegen::WriteMLIR; +use libludi::types::typecheck::TypeCheck; +use libludi::err::Result; +use libludi::types::TypeEnv; + +pub trait Pipeline { + fn apply(source: Source) -> Result; +} + +pub struct BasicCompiler; + +impl Pipeline for BasicCompiler { + fn apply(source: String) -> Result { + Ok(source + .lex() + .parse()? + // .optimize() + .type_check(&mut TypeEnv::new())? + .write()) + } +}