Skip to content

Commit

Permalink
update 2025
Browse files Browse the repository at this point in the history
  • Loading branch information
tw-ilson committed Jan 9, 2025
1 parent cef1fe1 commit f9d1060
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 51 deletions.
2 changes: 2 additions & 0 deletions libludi/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ pub struct FuncSignature {
pub ret: Vec<Type>,
}

// --------------------------------------------------------
// AST productions for the lang of format:
// meta-symbol {
// production { attributes... }
Expand All @@ -138,6 +139,7 @@ ast! {
// | Condition
}
}
// --------------------------------------------------------

impl Display for FuncSignature {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down
3 changes: 2 additions & 1 deletion libludi/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ impl Shape {
}
}

// checks if the other shape is a valid subshape of this
// checks if the other shape is a valid subshape of this (shape agreement)
// returns the shape of the extra dimensions (the shape of the application)
// ex: [2 1] <= [2 2 1], but [2 1] !<= [2 1 1]
pub fn subshape_fit(&self, other: &Self) -> Option<&[usize]> {
if let Some(rank_diff) = self.rank().checked_sub(other.rank()) {
Expand Down
208 changes: 179 additions & 29 deletions libludi/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ use std::{fmt::Display, rc::Rc, str::FromStr};

pub type TypeEnv = Env<Type>;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ref<T> {
name: Name,
ty: Box<T>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Sort {
Dim,
Expand Down Expand Up @@ -47,13 +53,13 @@ pub struct Arr {

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Array {
ArrayRef(Name),
ArrayRef(Ref<Array>),
Arr(Arr),
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Atom {
AtomRef(Name),
AtomRef(Ref<Atom>),
Func(Rc<Func>),
Forall(Rc<Forall>),
Pi(Rc<Pi>),
Expand Down Expand Up @@ -93,8 +99,48 @@ pub enum AtomicDataType {
Boolean,
}

// common patterns for atomic data type matching:

macro_rules! numeric {
() => {
UInt8
| Int8
| UInt16
| Int16
| UInt32
| Int32
| UInt64
| Int64
| BFloat16
| Float16
| Float32
| Float64
| Complex
};
}
macro_rules! integer {
() => {
UInt8 | Int8 | UInt16 | Int16 | UInt32 | Int32 | UInt64 | Int64
};
}
macro_rules! integer_signed {
() => {
Int8 | Int16 | Int32 | Int64
};
}
macro_rules! integer_unsigned {
() => {
UInt8 | UInt16 | UInt32 | UInt64
};
}
macro_rules! float {
() => {
BFloat16 | Float16 | Float32 | Float64
};
}

// define primitives
#[derive(Clone, Eq, PartialEq, Debug)]
#[derive(derive_more::Display, Clone, Eq, PartialEq, Debug)]
pub enum PrimitiveFuncType {
// arithmetic
Add,
Expand Down Expand Up @@ -164,10 +210,16 @@ impl FromStr for AtomicDataType {
"complex" => Ok(AtomicDataType::Complex),
"char" => Ok(AtomicDataType::Character),
"bool" => Ok(AtomicDataType::Boolean),
"()" => Err(Error::type_err(TypeErrorKind::Unsupported,
"()" => Err(Error::type_err(
TypeErrorKind::Unsupported,
"Unit type not supported in function signature",
).into()),
_ => Err(Error::type_err(crate::err::TypeErrorKind::Unknown,"not a known builtin type").into()),
)
.into()),
_ => Err(Error::type_err(
crate::err::TypeErrorKind::Unknown,
"not a known builtin type",
)
.into()),
}
}
}
Expand Down Expand Up @@ -217,7 +269,11 @@ impl TryFrom<&str> for PrimitiveFuncType {
"iota" => Ok(Self::Iota),
"slice" => Ok(Self::Slice),
"scatter" => Ok(Self::Scatter),
_ => Err(Error::parse_err(ParseErrorKind::FnCall, "tried to parse a primitive function").into()),
_ => Err(Error::parse_err(
ParseErrorKind::FnCall,
"tried to parse a primitive function",
)
.into()),
}
}
}
Expand All @@ -239,8 +295,17 @@ impl Display for Type {
}
}

impl Atom {
fn upgrade(self) -> Array {
Array::Arr(Arr {
element: self,
shape: Shape::new(&[]),
})
}
}

pub trait GetType {
fn get_type(&self) -> Type;
fn get_type(&self) -> &Type;
}

pub mod typed_ast {
Expand Down Expand Up @@ -272,11 +337,12 @@ pub mod typed_ast {
$($variant {
$($childname: $childtype),*
}),+});

impl GetType for $base_name {
fn get_type(&self) -> Type {
fn get_type(&self) -> &Type {
match self {
$(
Self::$variant {ty, ..} => ty.clone(),
Self::$variant {ty, ..} => ty,
)+
}
}
Expand Down Expand Up @@ -361,17 +427,24 @@ pub mod typed_ast {
}
}
pub mod typecheck {
use crate::{ast, err::ParseErrorKind};

use super::*;
use super::{
typed_ast, Arr, Array, Atom, AtomicDataType, Func, GetType, PrimitiveFuncType, Type,
TypeEnv,
};
use crate::{
ast::{self, FuncSignature},
atomic::Literal,
err::{Error, ParseErrorKind, Result, TypeErrorKind},
shape::Shape,
};
use itertools::Itertools;
use typed_ast::*;
use std::rc::Rc;

pub trait TypeCheck<T> {
fn type_check(self, table: &mut TypeEnv) -> Result<T>;
}

impl TypeCheck<typed_ast::TypedExpr> for Expr {
impl TypeCheck<typed_ast::TypedExpr> for ast::Expr {
fn type_check(self, table: &mut TypeEnv) -> Result<typed_ast::TypedExpr> {
match self {
Self::AtomLiteral(node) => {
Expand All @@ -398,24 +471,28 @@ pub mod typecheck {
match t1 {
Type::Atom(atom) => {
let ty = Type::Array(Array::Arr(Arr {
element: atom,
element: atom.clone(),
shape: Shape::new(&[frame_n]),
}));
Ok(typed_ast::frame_node(ty, checked_arr))
}
Type::Array(Array::Arr(arr)) => {
// this is the case for n-d array literal
let ty = Type::Array(Array::Arr(Arr {
element: arr.element,
shape: Shape::new(&[frame_n]).concat(arr.shape),
element: arr.element.clone(),
shape: Shape::new(&[frame_n]).concat(arr.shape.clone()),
}));
Ok(typed_ast::frame_node(ty, checked_arr)) // is it necessary to retain the
// inner information?
}
Type::Array(Array::ArrayRef(_array_ref)) => todo!(),
}
} else {
Err(Error::type_err(TypeErrorKind::TypeMismatch, &format!("Found non-conforming type in frame; expected {}", t1)).into())
Err(Error::type_err(
TypeErrorKind::TypeMismatch,
&format!("Found non-conforming type in frame; expected {}", t1),
)
.into())
}
} else {
Err(Error::type_err(TypeErrorKind::Unsupported, "empty frame").into())
Expand All @@ -427,24 +504,27 @@ pub mod typecheck {
}
Self::Let(node) => {
let initializer = node.initializer.type_check(table)?;
table.put(node.name.clone(), initializer.get_type());
table.put(node.name.clone(), initializer.get_type().clone());
let region = match node.region {
Some(body) => body.type_check(table)?,
None => {
return Err(Error::parse_err(ParseErrorKind::LetExpr, "dangling let body not allowed here").into())
return Err(Error::parse_err(
ParseErrorKind::LetExpr,
"dangling let body not allowed here",
)
.into()); // alternatively, return Unit
}
}; // No restrictions on region type
Ok(typed_ast::let_node(
region.get_type(),
region.get_type().clone(),
node.name,
initializer,
region,
))
}
Self::FnDef(node) => {
// TODO: IMPORTANT: Some kind of pi or sigma?
let body = node.body.type_check(table)?;
if body.get_type() != node.signature.ret[0] {
if *body.get_type() != node.signature.ret[0] {
return Err(anyhow::anyhow!(
"type error - expected function to return X"
));
Expand All @@ -454,6 +534,7 @@ pub mod typecheck {
}
Self::FnCall(node) => {
// look up type of callee
use AtomicDataType::*;
match node.callee {
ast::Callee::Expression(call_expr) => {
let typed_call_expr = call_expr.type_check(table)?;
Expand All @@ -464,17 +545,49 @@ pub mod typecheck {
.args
.into_iter()
.map(|arg_expr| arg_expr.type_check(table))
.collect::<Result<Vec<TypedExpr>>>()?;
.collect::<Result<Vec<typed_ast::TypedExpr>>>()?;
Ok(typed_ast::fn_call_node(ty, callee, args))
} else {
Err(anyhow::anyhow!("type error - not a function!"))
}
}
ast::Callee::Primitive(prim) => match prim {
PrimitiveFuncType::Add => todo!(),
PrimitiveFuncType::Sub => todo!(),
PrimitiveFuncType::Mul => todo!(),
PrimitiveFuncType::Div => todo!(),
PrimitiveFuncType::Add
| PrimitiveFuncType::Sub
| PrimitiveFuncType::Mul
| PrimitiveFuncType::Div => {
let mut arg_iter = node.args.into_iter();
let (arg1, arg2) = (
arg_iter.next().expect("missing arg 1?").type_check(table)?,
arg_iter.next().expect("missing arg 2?").type_check(table)?,
);
let (ty1, ty2) = (arg1.get_type(), arg2.get_type());
let frame = ty1.agreement(ty2, table);
if !frame.is_some()
|| !matches!(
ty1,
Type::Atom(Atom::Literal(numeric!()))
| Type::Array(Array::Arr(Arr {
element: Atom::Literal(numeric!()),
..
}))
)
{
return Err(Error::type_err(
TypeErrorKind::TypeMismatch,
&format!(
"types for {} must match: got {:?}, {:?}",
prim, ty1, ty2
),
)
.into());
}
Ok(typed_ast::fn_call_node(
ty1.clone(),
typed_ast::Callee::Primitive(prim),
vec![arg1, arg2],
))
}
PrimitiveFuncType::Mod => todo!(),
PrimitiveFuncType::Neg => todo!(),
PrimitiveFuncType::Inv => todo!(),
Expand Down Expand Up @@ -510,6 +623,43 @@ pub mod typecheck {
}
}
}

pub trait FrameAgreement {
// returns the prinicipal frame if there is agreement
fn agreement(&self, other: &Self, table: &mut TypeEnv) -> Option<Shape>;
}
impl FrameAgreement for Type {
fn agreement(&self, other: &Self, table: &mut TypeEnv) -> Option<Shape> {
match (self, other) {
(Type::Array(a), Type::Array(b)) => {
let arr_a: &Arr = match a {
Array::ArrayRef(_name) => todo!(),
Array::Arr(arr) => arr,
};
let arr_b: &Arr = match b {
Array::ArrayRef(_name) => todo!(),
Array::Arr(arr) => arr,
};
arr_a
.shape
.subshape_fit(&arr_b.shape)
.or(arr_b.shape.subshape_fit(&arr_a.shape))
.map(Shape::new)
.filter(|_| arr_a.element == arr_b.element)
}
(Type::Atom(a), Type::Atom(b)) => (a == b).then_some(Shape::new(&[])),
(Type::Atom(atom), array @ Type::Array(..))
| (array @ Type::Array(..), Type::Atom(atom)) => {
FrameAgreement::agreement(&Type::Array(atom.clone().upgrade()), &array, table)
}
}
}
}
impl std::cmp::PartialOrd for Array {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
todo!()
}
}
impl TryInto<Func> for FuncSignature {
type Error = anyhow::Error;
fn try_into(self) -> Result<Func> {
Expand Down
Loading

0 comments on commit f9d1060

Please sign in to comment.