diff --git a/aeon/backend/evaluator.py b/aeon/backend/evaluator.py index 47f1e02e..c23beeee 100644 --- a/aeon/backend/evaluator.py +++ b/aeon/backend/evaluator.py @@ -13,24 +13,35 @@ from aeon.core.terms import Term from aeon.core.terms import Var from aeon.utils.name import Name +from aeon.decorators.api import Metadata +from aeon.llvm.core import LLVMPipeline real_eval = eval class EvaluationContext: variables: dict[Name, Any] - - def __init__(self, prev: dict[Name, Any] | None = None): + metadata: Metadata | None + pipeline: LLVMPipeline | None + + def __init__( + self, + prev: dict[Name, Any] | None = None, + metadata: Metadata | None = None, + pipeline: LLVMPipeline | None = None, + ): if prev: self.variables = {k: v for (k, v) in prev.items()} else: self.variables = {} + self.metadata = metadata + self.pipeline = pipeline def with_var(self, name: Name, value: Any): assert isinstance(name, Name) v = self.variables.copy() v.update({name: value}) - return EvaluationContext(v) + return EvaluationContext(v, metadata=self.metadata, pipeline=self.pipeline) def get(self, name: Name): return self.variables[name] @@ -76,6 +87,23 @@ def eval(t: Term, ctx: EvaluationContext = EvaluationContext()) -> Any: case Let(var_name, var_value, body): return eval(body, ctx.with_var(var_name, eval(var_value, ctx))) case Rec(var_name, _, var_value, body): + found_llvm = False + if ctx.pipeline and ctx.metadata: + name_str = var_name.name + for k, v in ctx.metadata.items(): + k_name = k.name if isinstance(k, Name) else str(k) + if k_name == name_str and v.get("llvm"): + found_llvm = True + break + + if found_llvm: + try: + v = ctx.pipeline.get_curried_function(var_name) + if v is not None: + return eval(body, ctx.with_var(var_name, v)) + except Exception: + pass + if isinstance(var_value, Abstraction): fun = var_value @@ -87,7 +115,7 @@ def v(x): else: v = eval(var_value, ctx) - return eval(t.body, ctx.with_var(t.var_name, v)) + return eval(body, ctx.with_var(var_name, v)) case 
If(cond, then, otherwise): c = eval(cond, ctx) if c: diff --git a/aeon/core/types.py b/aeon/core/types.py index 22edf846..3b519832 100644 --- a/aeon/core/types.py +++ b/aeon/core/types.py @@ -159,8 +159,10 @@ def __hash__(self) -> int: t_int = TypeConstructor(Name("Int", 0), []) t_float = TypeConstructor(Name("Float", 0), []) t_string = TypeConstructor(Name("String", 0), []) +t_tensor = TypeConstructor(Name("Tensor", 0), []) +t_gpu_config = TypeConstructor(Name("GpuConfig", 0), []) -builtin_core_types = [t_unit, t_bool, t_int, t_float, t_string] +builtin_core_types = [t_unit, t_bool, t_int, t_float, t_string, t_tensor, t_gpu_config] top = Top() diff --git a/aeon/decorators/__init__.py b/aeon/decorators/__init__.py index 529315dd..fafdde00 100644 --- a/aeon/decorators/__init__.py +++ b/aeon/decorators/__init__.py @@ -9,8 +9,9 @@ def fun(...) { ... } eventual complementary definitions. """ -from aeon.decorators.api import DecoratorType -from aeon.decorators.api import Metadata +from aeon.decorators.api import Metadata, DecoratorType +from aeon.llvm.decorators.gpu import gpu +from aeon.llvm.decorators.llvm import llvm from aeon.sugar.program import Definition from aeon.synthesis.decorators import ( minimize_int, @@ -40,6 +41,8 @@ def fun(...) { ... 
} "error_fitness": error_fitness, "disable_control_flow": disable_control_flow, "prompt": prompt, + "llvm": llvm, + "gpu": gpu, "csv_data": csv_data, "csv_file": csv_file, } @@ -53,7 +56,7 @@ def apply_decorators(fun: Definition, metadata: Metadata) -> tuple[Definition, l for decorator in fun.decorators: dname = decorator.name.name if dname not in decorators_environment: - raise Exception(f"Unknown decorator named {dname}, in function {fun.name}.") + raise Exception(f"Unknown decorator named {dname}, in function {fun.name.pretty()}.") decorator_processor = decorators_environment[dname] (fun, extra, metadata) = decorator_processor(decorator.macro_args, fun, metadata) total_extra.extend(extra) diff --git a/aeon/facade/driver.py b/aeon/facade/driver.py index a1c25eb9..0e902861 100644 --- a/aeon/facade/driver.py +++ b/aeon/facade/driver.py @@ -32,6 +32,10 @@ from aeon.utils.name import Name from aeon.utils.pprint import pretty_print_node from aeon.utils.time_utils import RecordTime +from aeon.llvm.cpu.pipeline import CPULLVMPipeline +from aeon.llvm.cpu.executor import CPULLVMExecutionEngine +from aeon.llvm.cpu.converter import CPULLVMIRGenerator +from aeon.llvm.cpu.lowerer import CPULLVMLowerer def read_file(filename: str) -> str: @@ -107,17 +111,25 @@ def parse(self, filename: str = None, aeon_code: str = None) -> Iterable[AeonErr return type_errors with RecordTime("Preparing execution env"): - evaluation_ctx = EvaluationContext(evaluation_vars) + executor = CPULLVMExecutionEngine() + generator = CPULLVMIRGenerator() + lowerer = CPULLVMLowerer() + pipeline = CPULLVMPipeline(executor, generator, lowerer, metadata=metadata) + evaluation_ctx = EvaluationContext(evaluation_vars, metadata=metadata, pipeline=pipeline) self.metadata = metadata self.core = core_ast_anf self.typing_ctx = typing_ctx self.evaluation_ctx = evaluation_ctx + + with RecordTime("LLVM compilation"): + pipeline.compile(self.core) + return [] - def run(self) -> None: + def run(self) -> Any: with 
RecordTime("Evaluation"): - eval(self.core, self.evaluation_ctx) + return eval(self.core, self.evaluation_ctx) def has_synth(self) -> bool: with RecordTime("DetectSynthesis"): diff --git a/aeon/llvm/core.py b/aeon/llvm/core.py new file mode 100644 index 00000000..571f1ce6 --- /dev/null +++ b/aeon/llvm/core.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from dataclasses import dataclass +from abc import ABC, abstractmethod +from typing import Dict, Any, List +from aeon.core.terms import Term +from aeon.utils.name import Name +from aeon.llvm.llvm_ast import LLVMTerm, LLVMType + + +class LLVMBackendError(Exception): + pass + + +class LLVMValidationError(LLVMBackendError): + pass + + +@dataclass(frozen=True) +class ValidationContext: + pass + + +class ValidationStep(ABC): + @abstractmethod + def validate(self, t: Term, ctx: ValidationContext) -> None: + pass + + +class LLVMLowerer(ABC): + def validate(self, t: Term, ctx: ValidationContext) -> None: + for step in self.get_validation_steps(): + step.validate(t, ctx) + + @abstractmethod + def get_validation_steps(self) -> List[ValidationStep]: + pass + + @abstractmethod + def lower( + self, + t: Term, + expected_type: LLVMType = None, + type_env: Dict[Name, LLVMType] = None, + env: Dict[Name, LLVMTerm] = None, + ) -> LLVMTerm: + pass + + @abstractmethod + def get_signature(self, ty: LLVMType) -> tuple[List[LLVMType], LLVMType]: + pass + + +class LLVMIRGenerator(ABC): + @abstractmethod + def generate_ir(self, definitions: List[LLVMTerm]) -> str: + pass + + +class LLVMOptimizer(ABC): + @abstractmethod + def optimize(self, llvm_ir: str, opt_level: int = 3) -> str: + pass + + +class LLVMExecutionEngine(ABC): + @abstractmethod + def execute( + self, llvm_ir: str, func_name: str, args: List[Any], arg_types: List[LLVMType], ret_type: LLVMType + ) -> Any: + pass + + +class LLVMPipeline(ABC): + @abstractmethod + def compile(self, program: Term) -> None: + pass + + @abstractmethod + def get_curried_function(self, name: 
Name) -> Any: + pass + + @abstractmethod + def invoke(self, name_id: Name, arguments: List[Any]) -> Any: + pass diff --git a/aeon/llvm/cpu/converter.py b/aeon/llvm/cpu/converter.py new file mode 100644 index 00000000..a22b5e1a --- /dev/null +++ b/aeon/llvm/cpu/converter.py @@ -0,0 +1,566 @@ +from __future__ import annotations + +import llvmlite.binding as llvm +import llvmlite.ir as ir + +from aeon.llvm.core import LLVMIRGenerator, LLVMBackendError +from aeon.llvm.llvm_ast import ( + LLVMTerm, + LLVMType, + LLVMIntType, + LLVMFloatType, + LLVMDoubleType, + LLVMBoolType, + LLVMCharType, + LLVMVoidType, + LLVMPointerType, + LLVMArrayType, + LLVMFunctionType, + LLVMLiteral, + LLVMVar, + LLVMIf, + LLVMLet, + LLVMFunction, + LLVMCall, + LLVMGetElementPtr, + LLVMLoad, + LLVMStore, + LLVMAlloc, + LLVMVectorMap, + LLVMVectorReduce, + LLVMVectorIMap, + LLVMVectorFilter, + LLVMVectorZipWith, + LLVMVectorCount, + VECTOR_OPERATIONS, + LLVMCast, +) +from aeon.llvm.utils import BINARY_OPS, UNARY_OPS, sanitize_name +from aeon.utils.name import Name +from typing import Dict, Any, Callable + + +class LLVMIRGenerationError(LLVMBackendError): + pass + + +class CPULLVMIRGenerator(LLVMIRGenerator): + def __init__(self): + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + + self.module = ir.Module(name="aeon_cpu_module") + self.module.triple = llvm.get_process_triple() + target = llvm.Target.from_triple(self.module.triple) + target_machine = target.create_target_machine() + self.module.data_layout = str(target_machine.target_data) + self.target_data = target_machine.target_data + + self.builder = None + self.env = {} + self.fn_count = 0 + + def to_ir_type(self, ty: LLVMType) -> ir.Type: + match ty: + case LLVMIntType(bits): + return ir.IntType(bits) + case LLVMFloatType(): + return ir.FloatType() + case LLVMDoubleType(): + return ir.DoubleType() + case LLVMBoolType(): + return ir.IntType(1) + case LLVMCharType(): + return ir.IntType(8) + case LLVMVoidType(): + 
return ir.IntType(32) + case LLVMFunctionType(arg_types, return_type): + ir_return_type = ( + ir.VoidType() if isinstance(return_type, LLVMVoidType) else self.to_ir_type(return_type) + ) + ir_arg_types = [self.to_ir_type(arg) for arg in arg_types] + return ir.FunctionType(ir_return_type, ir_arg_types) + case LLVMPointerType(element_type, address_space): + ir_base = self.to_ir_type(element_type) + if isinstance(ir_base, ir.VoidType): + ir_base = ir.IntType(8) + return ir.PointerType(ir_base, address_space.value) + case LLVMArrayType(element_type, size): + ir_base = self.to_ir_type(element_type) + return ir.ArrayType(ir_base, size if size is not None else 0) + case _: + raise LLVMIRGenerationError(f"unsupported LLVM type {ty}") + + def _heap_alloc(self, element_ty: ir.Type, count: ir.Value) -> ir.Value: + element_size = self.target_data.get_abi_size(element_ty) + + count_i64 = self.builder.sext(count, ir.IntType(64)) if count.type.width < 64 else count + total_size = self.builder.mul(count_i64, ir.Constant(ir.IntType(64), element_size)) + + malloc_ty = ir.FunctionType(ir.PointerType(ir.IntType(8)), [ir.IntType(64)]) + malloc_func = self.module.globals.get("malloc") + if not malloc_func: + malloc_func = ir.Function(self.module, malloc_ty, name="malloc") + + raw_ptr = self.builder.call(malloc_func, [total_size]) + return self.builder.bitcast(raw_ptr, ir.PointerType(element_ty)) + + def generate_ir(self, definitions: list[LLVMTerm], initial_env: Dict[str, Any] = None) -> str: + if initial_env: + self.env.update(initial_env) + + for kernel_ast in definitions: + if isinstance(kernel_ast, LLVMFunction) and kernel_ast.name: + func_name = sanitize_name(kernel_ast.name) + if func_name not in self.module.globals: + func_type = self.to_ir_type(kernel_ast.type) + func = ir.Function(self.module, func_type, name=func_name) + self.env[func_name] = func + + for kernel_ast in definitions: + self.to_ir(kernel_ast, is_top_level=True) + return str(self.module) + + def 
declare_external(self, name: Name, ty: LLVMType): + str_name = sanitize_name(name) + if str_name in self.module.globals: + return + ir_type = self.to_ir_type(ty) + ir.Function(self.module, ir_type, name=str_name) + + def to_ir(self, llvm_ast: LLVMTerm, is_top_level: bool) -> ir.Value | None: + if llvm_ast is None: + return None + + match llvm_ast: + case LLVMLiteral(type=ty, value=val): + return self.to_ir_literal(ty, val) + + case LLVMVar(type=ty, name=name): + return self.to_ir_variable(ty, name) + + case LLVMIf(type=ty, cond=cond, then_t=then_t, else_t=else_t): + return self.to_ir_if(ty, cond, then_t, else_t) + + case LLVMLet(type=ty, var_name=name, var_value=val, body=body): + return self.to_ir_let(name, val, body, is_top_level) + + case LLVMFunction(type=ty, arg_names=names, body=body, name=opt_name): + return self.to_ir_function(ty, names, body, opt_name) + + case LLVMCall(type=_, target=target, args=args): + return self.to_ir_call(target, args) + + case LLVMCast(type=ty, val=val): + v_val = self.to_ir(val, False) + target_ty = self.to_ir_type(ty) + if v_val.type == target_ty: + return v_val + if isinstance(v_val.type, ir.IntType) and isinstance(target_ty, (ir.FloatType, ir.DoubleType)): + return self.builder.sitofp(v_val, target_ty) + if isinstance(v_val.type, (ir.FloatType, ir.DoubleType)) and isinstance(target_ty, ir.IntType): + return self.builder.fptosi(v_val, target_ty) + if isinstance(v_val.type, ir.FloatType) and isinstance(target_ty, ir.DoubleType): + return self.builder.fpext(v_val, target_ty) + if isinstance(v_val.type, ir.DoubleType) and isinstance(target_ty, ir.FloatType): + return self.builder.fptrunc(v_val, target_ty) + return self.builder.bitcast(v_val, target_ty) + + case LLVMGetElementPtr(ptr=ptr, indices=indices): + return self.builder.gep(self.to_ir(ptr, False), [self.to_ir(i, False) for i in indices]) + + case LLVMLoad(ptr=ptr): + return self.builder.load(self.to_ir(ptr, False)) + + case LLVMStore(value=value, ptr=ptr): + v_val = 
self.to_ir(value, False) + p_val = self.to_ir(ptr, False) + return p_val if isinstance(v_val.type, ir.VoidType) else self.builder.store(v_val, p_val) + + case LLVMAlloc(type=ty): + alloc_ty = self.to_ir_type(ty.element_type if isinstance(ty, LLVMPointerType) else ty) + return self.builder.alloca(alloc_ty) + + case LLVMVectorMap(type=res_ty, f=f, v=v, size=size): + return self.to_ir_vector_map(res_ty, f, v, size) + + case LLVMVectorReduce(type=ty, f=f, initial=initial, v=v, size=size): + return self.to_ir_vector_reduce(ty, f, initial, v, size) + + case LLVMVectorIMap(type=res_ty, f=f, v=v, size=size): + return self.to_ir_vector_imap(res_ty, f, v, size) + + case LLVMVectorFilter(type=res_ty, f=f, v=v, size=size): + return self.to_ir_vector_filter(res_ty, f, v, size) + + case LLVMVectorZipWith(type=res_ty, f=f, v1=v1, v2=v2, size=size): + return self.to_ir_vector_zipwith(res_ty, f, v1, v2, size) + + case LLVMVectorCount(f=f, v=v, size=size): + return self.to_ir_vector_count(f, v, size) + + case _: + raise LLVMIRGenerationError(f"unsupported LLVM node {type(llvm_ast)}") + + def to_ir_literal(self, result_type: LLVMType, value: Any) -> ir.Value: + ir_type = self.to_ir_type(result_type) + match result_type: + case LLVMBoolType(): + return ir.Constant(ir.IntType(1), 1 if value else 0) + case LLVMIntType(bits): + return ir.Constant(ir.IntType(bits), int(value)) + case LLVMFloatType() | LLVMDoubleType(): + return ir.Constant(ir_type, float(value)) + case LLVMCharType(): + return ir.Constant(ir.IntType(8), ord(value)) + case LLVMPointerType(element_type=LLVMCharType()): + if isinstance(value, str): + text = value + "\0" + c_str = ir.Constant(ir.ArrayType(ir.IntType(8), len(text)), bytearray(text, "utf-8")) + gv = ir.GlobalVariable(self.module, c_str.type, name=f"str_const_{self.fn_count}") + self.fn_count += 1 + gv.global_constant = True + gv.initializer = c_str + zero = ir.Constant(ir.IntType(32), 0) + return self.builder.gep(gv, [zero, zero]) if self.builder else gv + 
raise LLVMIRGenerationError(f"unsupported pointer literal {value}") + case _: + raise LLVMIRGenerationError(f"unsupported literal type {result_type}") + + def to_ir_variable(self, result_type: LLVMType, var_name: Name) -> ir.Value: + str_name = sanitize_name(var_name) + if str_name in self.env: + return self.env[str_name] + if str_name in self.module.globals: + return self.module.globals[str_name] + + base_name = var_name.name + if base_name == "Math_PI": + return ir.Constant(ir.DoubleType(), 3.141592653589793) + + builtin_map = { + "Math_pow": "pow", + "Math_sqrt": "sqrt", + "Math_sqrtf": "sqrt", + "Math_sin": "sin", + "Math_cos": "cos", + "Math_exp": "exp", + "Math_log": "log", + } + + name_parts = str_name.rsplit("_", 1) + lookup_name = name_parts[0] if len(name_parts) > 1 and name_parts[1].isdigit() else str_name + + actual_name = builtin_map.get(lookup_name, lookup_name) + + if ( + actual_name in {"pow", "sqrt", "sin", "cos", "exp", "log", "malloc", "free", "printf", "native"} + or lookup_name in VECTOR_OPERATIONS + ): + if actual_name in self.module.globals: + return self.module.globals[actual_name] + + actual_ty = result_type + if actual_name == "native" and not isinstance(actual_ty, LLVMFunctionType): + actual_ty = LLVMFunctionType( + [LLVMPointerType(element_type=LLVMCharType())], LLVMPointerType(element_type=LLVMCharType()) + ) + + return ir.Function(self.module, self.to_ir_type(actual_ty), name=actual_name) + + raise LLVMIRGenerationError(f"undefined variable {str_name}") + + def to_ir_if(self, result_type: LLVMType, cond: LLVMTerm, then_t: LLVMTerm, else_t: LLVMTerm) -> ir.Value | None: + if self.builder is None: + return None + cond_val = self.to_ir(cond, False) + + with self.builder.if_else(cond_val) as (then_block, else_block): + with then_block: + then_val = self.to_ir(then_t, False) + then_exit = self.builder.basic_block + with else_block: + else_val = self.to_ir(else_t, False) + else_exit = self.builder.basic_block + + if isinstance(result_type, 
LLVMVoidType): + return None + + phi = self.builder.phi(self.to_ir_type(result_type), name="if_res") + phi.add_incoming(then_val if then_val is not None else ir.Constant(phi.type, 0), then_exit) + phi.add_incoming(else_val if else_val is not None else ir.Constant(phi.type, 0), else_exit) + return phi + + def to_ir_let(self, var_name: Name, var_value: LLVMTerm, body: LLVMTerm, is_top_level: bool) -> ir.Value | None: + str_name = sanitize_name(var_name) + if isinstance(var_value, LLVMFunction): + var_value.name = var_name + func = self.to_ir(var_value, False) + self.env[str_name] = func + return self.to_ir(body, is_top_level) + + val_gen = self.to_ir(var_value, False) + old_val = self.env.get(str_name) + self.env[str_name] = val_gen + res = self.to_ir(body, is_top_level) + + if old_val is not None: + self.env[str_name] = old_val + else: + del self.env[str_name] + return res + + def to_ir_function( + self, function_type: LLVMType, arg_names: list[Name], body: LLVMTerm, function_name: Name | None + ) -> ir.Function: + func_name = sanitize_name(function_name) if function_name else f"anon_func_{self.fn_count}" + if not function_name: + self.fn_count += 1 + + func = self.module.globals.get(func_name) or ir.Function( + self.module, self.to_ir_type(function_type), name=func_name + ) + + old_builder, old_env = self.builder, self.env.copy() + self.env[func_name] = func + + self.builder = ir.IRBuilder(func.append_basic_block(name="entry")) + for i, arg_name in enumerate(arg_names): + str_arg_name = sanitize_name(arg_name) + func.args[i].name = str_arg_name + self.env[str_arg_name] = func.args[i] + + ret_val = self.to_ir(body, False) + if isinstance(function_type, LLVMFunctionType) and isinstance(function_type.return_type, LLVMVoidType): + self.builder.ret_void() + else: + self.builder.ret(ret_val) + + self.builder, self.env = old_builder, old_env + return func + + def to_ir_call(self, target: LLVMTerm, args: list[LLVMTerm]) -> ir.Value | None: + if self.builder is None: + 
return None + + if isinstance(target, LLVMVar) and (target.name.name in BINARY_OPS or target.name.name in UNARY_OPS): + return self.to_ir_operator(target.name.name, args) + + target_func = self.to_ir(target, False) + arg_vals = [self.to_ir(arg, False) for arg in args] + return self.builder.call(target_func, arg_vals) if target_func else None + + def to_ir_operator(self, op: str, args: list[LLVMTerm]) -> ir.Value | None: + vals = [self.to_ir(arg, False) for arg in args] + is_f = isinstance(vals[0].type, (ir.FloatType, ir.DoubleType)) + + match op: + case "+" if is_f: + return self.builder.fadd(vals[0], vals[1]) + case "+": + return self.builder.add(vals[0], vals[1]) + case "-" if is_f: + return self.builder.fsub(vals[0], vals[1]) if len(vals) == 2 else self.builder.fneg(vals[0]) + case "-": + return ( + self.builder.sub(vals[0], vals[1]) + if len(vals) == 2 + else self.builder.sub(ir.Constant(vals[0].type, 0), vals[0]) + ) + case "*" if is_f: + return self.builder.fmul(vals[0], vals[1]) + case "*": + return self.builder.mul(vals[0], vals[1]) + case "/" if is_f: + return self.builder.fdiv(vals[0], vals[1]) + case "/": + return self.builder.sdiv(vals[0], vals[1]) + case "%" if is_f: + return self.builder.frem(vals[0], vals[1]) + case "%": + return self.builder.srem(vals[0], vals[1]) + case "==": + return ( + self.builder.fcmp_ordered("==", vals[0], vals[1]) + if is_f + else self.builder.icmp_signed("==", vals[0], vals[1]) + ) + case "!=": + return ( + self.builder.fcmp_ordered("!=", vals[0], vals[1]) + if is_f + else self.builder.icmp_signed("!=", vals[0], vals[1]) + ) + case "<": + return ( + self.builder.fcmp_ordered("<", vals[0], vals[1]) + if is_f + else self.builder.icmp_signed("<", vals[0], vals[1]) + ) + case "<=": + return ( + self.builder.fcmp_ordered("<=", vals[0], vals[1]) + if is_f + else self.builder.icmp_signed("<=", vals[0], vals[1]) + ) + case ">": + return ( + self.builder.fcmp_ordered(">", vals[0], vals[1]) + if is_f + else 
self.builder.icmp_signed(">", vals[0], vals[1]) + ) + case ">=": + return ( + self.builder.fcmp_ordered(">=", vals[0], vals[1]) + if is_f + else self.builder.icmp_signed(">=", vals[0], vals[1]) + ) + case "&&": + return self.builder.and_(vals[0], vals[1]) + case "||": + return self.builder.or_(vals[0], vals[1]) + case "!": + return self.builder.not_(vals[0]) + return None + + def to_ir_loop(self, size: ir.Value, name: str, body_fn: Callable[[ir.Value], None]): + idx_ptr = self.builder.alloca(ir.IntType(32), name=f"{name}_idx") + self.builder.store(ir.Constant(ir.IntType(32), 0), idx_ptr) + + cond_bb = self.builder.append_basic_block(f"{name}_cond") + body_bb = self.builder.append_basic_block(f"{name}_body") + end_bb = self.builder.append_basic_block(f"{name}_end") + + self.builder.branch(cond_bb) + self.builder.position_at_end(cond_bb) + + curr_idx = self.builder.load(idx_ptr) + is_less = self.builder.icmp_signed("<", curr_idx, size) + self.builder.cbranch(is_less, body_bb, end_bb) + + self.builder.position_at_end(body_bb) + body_fn(curr_idx) + + self.builder.store(self.builder.add(curr_idx, ir.Constant(ir.IntType(32), 1)), idx_ptr) + self.builder.branch(cond_bb) + self.builder.position_at_end(end_bb) + + def to_ir_vector_map(self, res_ty: LLVMType, f: LLVMTerm, v: LLVMTerm, size: LLVMTerm) -> ir.Value: + f_val, v_val, size_val = self.to_ir(f, False), self.to_ir(v, False), self.to_ir(size, False) + res_base_ty = self.to_ir_type(res_ty.element_type if isinstance(res_ty, LLVMPointerType) else res_ty) + if isinstance(res_base_ty, ir.VoidType): + res_base_ty = ir.IntType(32) + + new_v = self._heap_alloc(res_base_ty, size_val) + + def body(idx): + mapped_val = self.builder.call(f_val, [self.builder.load(self.builder.gep(v_val, [idx]))]) + if not isinstance(mapped_val.type, ir.VoidType): + self.builder.store(mapped_val, self.builder.gep(new_v, [idx])) + + self.to_ir_loop(size_val, "map", body) + return new_v + + def to_ir_vector_reduce( + self, ty: LLVMType, f: LLVMTerm, 
initial: LLVMTerm, v: LLVMTerm, size: LLVMTerm + ) -> ir.Value: + f_val, init_val, v_val, size_val = ( + self.to_ir(f, False), + self.to_ir(initial, False), + self.to_ir(v, False), + self.to_ir(size, False), + ) + acc_ty = self.to_ir_type(ty) + if isinstance(acc_ty, ir.VoidType): + acc_ty = ir.IntType(32) + + acc_ptr = self.builder.alloca(acc_ty, name="reduce_acc") + if init_val and not isinstance(init_val.type, ir.VoidType): + self.builder.store(init_val, acc_ptr) + + def body(idx): + new_acc = self.builder.call( + f_val, [self.builder.load(acc_ptr), self.builder.load(self.builder.gep(v_val, [idx]))] + ) + if not isinstance(new_acc.type, ir.VoidType): + self.builder.store(new_acc, acc_ptr) + + self.to_ir_loop(size_val, "reduce", body) + return self.builder.load(acc_ptr) + + def to_ir_vector_imap(self, res_ty: LLVMType, f: LLVMTerm, v: LLVMTerm, size: LLVMTerm) -> ir.Value: + f_val, v_val, size_val = self.to_ir(f, False), self.to_ir(v, False), self.to_ir(size, False) + res_base_ty = self.to_ir_type(res_ty.element_type if isinstance(res_ty, LLVMPointerType) else res_ty) + if isinstance(res_base_ty, ir.VoidType): + res_base_ty = ir.IntType(32) + + new_v = self._heap_alloc(res_base_ty, size_val) + + def body(idx): + mapped_val = self.builder.call(f_val, [idx, self.builder.load(self.builder.gep(v_val, [idx]))]) + if not isinstance(mapped_val.type, ir.VoidType): + self.builder.store(mapped_val, self.builder.gep(new_v, [idx])) + + self.to_ir_loop(size_val, "imap", body) + return new_v + + def to_ir_vector_filter(self, res_ty: LLVMType, f: LLVMTerm, v: LLVMTerm, size: LLVMTerm) -> ir.Value: + f_val, v_val, size_val = self.to_ir(f, False), self.to_ir(v, False), self.to_ir(size, False) + res_base_ty = self.to_ir_type(res_ty.element_type if isinstance(res_ty, LLVMPointerType) else res_ty) + if isinstance(res_base_ty, ir.VoidType): + res_base_ty = ir.IntType(32) + + new_v = self._heap_alloc(res_base_ty, size_val) + new_idx_ptr = self.builder.alloca(ir.IntType(32), 
name="filter_new_idx") + self.builder.store(ir.Constant(ir.IntType(32), 0), new_idx_ptr) + + def body(idx): + val = self.builder.load(self.builder.gep(v_val, [idx])) + keep = self.builder.call(f_val, [val]) + with self.builder.if_then(keep): + new_idx = self.builder.load(new_idx_ptr) + self.builder.store(val, self.builder.gep(new_v, [new_idx])) + self.builder.store(self.builder.add(new_idx, ir.Constant(ir.IntType(32), 1)), new_idx_ptr) + + self.to_ir_loop(size_val, "filter", body) + return new_v + + def to_ir_vector_zipwith( + self, res_ty: LLVMType, f: LLVMTerm, v1: LLVMTerm, v2: LLVMTerm, size: LLVMTerm + ) -> ir.Value: + f_val, v1_val, v2_val, size_val = ( + self.to_ir(f, False), + self.to_ir(v1, False), + self.to_ir(v2, False), + self.to_ir(size, False), + ) + res_base_ty = self.to_ir_type(res_ty.element_type if isinstance(res_ty, LLVMPointerType) else res_ty) + if isinstance(res_base_ty, ir.VoidType): + res_base_ty = ir.IntType(32) + + new_v = self._heap_alloc(res_base_ty, size_val) + + def body(idx): + val1 = self.builder.load(self.builder.gep(v1_val, [idx])) + val2 = self.builder.load(self.builder.gep(v2_val, [idx])) + res = self.builder.call(f_val, [val1, val2]) + self.builder.store(res, self.builder.gep(new_v, [idx])) + + self.to_ir_loop(size_val, "zip", body) + return new_v + + def to_ir_vector_count(self, f: LLVMTerm, v: LLVMTerm, size: LLVMTerm) -> ir.Value: + f_val, v_val, size_val = self.to_ir(f, False), self.to_ir(v, False), self.to_ir(size, False) + count_ptr = self.builder.alloca(ir.IntType(32), name="count_res") + self.builder.store(ir.Constant(ir.IntType(32), 0), count_ptr) + + def body(idx): + val = self.builder.load(self.builder.gep(v_val, [idx])) + is_match = self.builder.call(f_val, [val]) + with self.builder.if_then(is_match): + self.builder.store( + self.builder.add(self.builder.load(count_ptr), ir.Constant(ir.IntType(32), 1)), count_ptr + ) + + self.to_ir_loop(size_val, "count", body) + return self.builder.load(count_ptr) diff --git 
a/aeon/llvm/cpu/executor.py b/aeon/llvm/cpu/executor.py new file mode 100644 index 00000000..aba83e17 --- /dev/null +++ b/aeon/llvm/cpu/executor.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import ctypes +from typing import Any, List, Dict + +import llvmlite.binding as llvm + +from aeon.llvm.core import LLVMExecutionEngine, LLVMBackendError +from aeon.llvm.llvm_ast import ( + LLVMType, + LLVMIntType, + LLVMFloatType, + LLVMDoubleType, + LLVMBoolType, + LLVMCharType, + LLVMVoidType, + LLVMPointerType, +) + + +class LLVMExecutionError(LLVMBackendError): + pass + + +class CPULLVMExecutionEngine(LLVMExecutionEngine): + def __init__(self): + self._init_llvm() + self.target_machine = self._create_target_machine() + self._keep_alive = [] + + def _init_llvm(self): + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + + def _create_target_machine(self): + target = llvm.Target.from_triple(llvm.get_process_triple()) + return target.create_target_machine() + + def _get_ctypes_type(self, ty: LLVMType) -> Any: + match ty: + case LLVMIntType(bits): + types_map = { + 1: ctypes.c_bool, + 8: ctypes.c_int8, + 16: ctypes.c_int16, + 32: ctypes.c_int32, + 64: ctypes.c_int64, + } + if bits in types_map: + return types_map[bits] + raise LLVMExecutionError(f"unsupported integer width: {bits} bits") + case LLVMBoolType(): + return ctypes.c_bool + case LLVMFloatType(): + return ctypes.c_float + case LLVMDoubleType(): + return ctypes.c_double + case LLVMCharType(): + return ctypes.c_char + case LLVMVoidType(): + return None + case LLVMPointerType(): + return ctypes.c_void_p + case _: + raise LLVMExecutionError(f"unsupported LLVM type for execution: {ty}") + + def _flatten_list(self, val: Any) -> List[Any]: + if not isinstance(val, (list, tuple)): + return [val] + res = [] + for item in val: + if isinstance(item, (list, tuple)): + res.extend(self._flatten_list(item)) + else: + res.append(item) + return res + + def _convert_to_ctypes(self, val: Any, ty: 
LLVMType) -> Any: + if isinstance(ty, LLVMPointerType) and isinstance(val, list): + flat_val = self._flatten_list(val) + base_ty = ty.element_type + element_cty = self._get_ctypes_type(base_ty) + processed_flat_val = [self._convert_to_ctypes(item, base_ty) for item in flat_val] + array_type = element_cty * len(processed_flat_val) + c_array = array_type(*processed_flat_val) + self._keep_alive.append(c_array) + return ctypes.cast(c_array, ctypes.c_void_p) + + if isinstance(ty, LLVMCharType) and isinstance(val, str): + return ord(val) + + return val + + def _get_vector_impl(self, arg_types: List[LLVMType], ret_type: LLVMType) -> Dict[str, Any]: + def vector_get(ptr: ctypes.c_void_p, idx: int) -> Any: + el_ty = self._get_ctypes_type(ret_type) + return ctypes.cast(ptr, ctypes.POINTER(el_ty))[idx] + + def vector_set(ptr: ctypes.c_void_p, idx: int, val: Any) -> ctypes.c_void_p: + # val is already converted to the correct type by ctypes + el_ty = self._get_ctypes_type(arg_types[2]) if len(arg_types) > 2 else ctypes.c_int32 + ctypes.cast(ptr, ctypes.POINTER(el_ty))[idx] = val + return ptr + + def native_dummy(code: ctypes.c_char_p) -> ctypes.c_void_p: + return ctypes.c_void_p(None) + + return { + "Vector_get": vector_get, + "Vector_set": vector_set, + "native": native_dummy, + } + + def execute( + self, + llvm_ir: str, + func_name: str, + args: List[Any], + arg_types: List[LLVMType], + ret_type: LLVMType, + debug: bool = False, + ) -> Any: + self._keep_alive = [] + + # We need the actual function type to register the correct callback types + # But we can also register them with a generic signature if needed. + # However, for Vector_get/set, we need the element type. + + vector_impls = self._get_vector_impl(arg_types, ret_type) + # We don't register them as global symbols if they are specialized per call? + # Actually, the JIT needs to find them. 
If we have multiple calls with different types, + # we might need different names or a generic implementation that uses the type info. + # But here 'execute' is for a specific function. + + # For now, let's register them. If there are multiple Vector_gets, they will conflict. + # A better way would be to let the lowerer emit the implementation if it's not provided by a library. + + # For 'native', it's always the same. + llvm.add_symbol( + "native", + ctypes.cast( + ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_char_p)(vector_impls["native"]), ctypes.c_void_p + ).value, + ) + + backing_mod = llvm.parse_assembly(llvm_ir) + + backing_mod.verify() + with llvm.create_mcjit_compiler(backing_mod, self.target_machine) as engine: + engine.finalize_object() + func_ptr = engine.get_function_address(func_name) + if not func_ptr: + raise LLVMExecutionError(f"failed to find function address for {func_name}") + + ctypes_args = [self._get_ctypes_type(t) for t in arg_types] + ctypes_ret = self._get_ctypes_type(ret_type) if not isinstance(ret_type, LLVMVoidType) else None + + cfunc = ctypes.CFUNCTYPE(ctypes_ret, *ctypes_args)(func_ptr) + processed_args = [self._convert_to_ctypes(val, ty) for val, ty in zip(args, arg_types)] + result = cfunc(*processed_args) + + if isinstance(ret_type, LLVMCharType): + return chr(result) + + return result diff --git a/aeon/llvm/cpu/lowerer.py b/aeon/llvm/cpu/lowerer.py new file mode 100644 index 00000000..0f36729b --- /dev/null +++ b/aeon/llvm/cpu/lowerer.py @@ -0,0 +1,696 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from typing import Dict, List + +from aeon.core.terms import ( + Abstraction, + Application, + If, + Let, + Literal, + Rec, + Term, + Var, + TypeApplication, + TypeAbstraction, + Annotation, + Hole, +) +from aeon.core.types import Type +from aeon.llvm.core import LLVMLowerer, LLVMValidationError, ValidationStep, ValidationContext, LLVMBackendError +from aeon.llvm.llvm_ast import ( + LLVMType, 
+    LLVMTerm,
+    LLVMLiteral,
+    LLVMInt,
+    LLVMDouble,
+    LLVMFloatType,
+    LLVMDoubleType,
+    LLVMBool,
+    LLVMVar,
+    LLVMIf,
+    LLVMCall,
+    LLVMFunctionType,
+    LLVMLet,
+    LLVMFunction,
+    LLVMGetElementPtr,
+    LLVMLoad,
+    LLVMStore,
+    LLVMPointerType,
+    LLVMVoid,
+    LLVMVoidType,
+    LLVMCharType,
+    LLVMVectorMap,
+    LLVMVectorReduce,
+    LLVMVectorIMap,
+    LLVMVectorFilter,
+    LLVMVectorZipWith,
+    LLVMVectorCount,
+    VECTOR_OPERATIONS,
+)
+from aeon.llvm.utils import (
+    validate_type,
+    to_llvm_type,
+    UNARY_OPS,
+    BINARY_OPS,
+    get_builtin_op_type,
+    sanitize_name,
+)
+from aeon.utils.name import Name
+
+
+class LLVMLoweringError(LLVMBackendError):
+    pass
+
+
+_generic_ptr = LLVMPointerType(LLVMCharType())
+_func_i_i = LLVMFunctionType([LLVMInt], LLVMInt)
+_func_ii_i = LLVMFunctionType([LLVMInt, LLVMInt], LLVMInt)
+_func_i_b = LLVMFunctionType([LLVMInt], LLVMBool)
+
+BUILTIN_FUNCTION_TYPES: Dict[str, LLVMFunctionType] = {
+    "malloc": LLVMFunctionType([LLVMInt], _generic_ptr),
+    "free": LLVMFunctionType([_generic_ptr], LLVMVoid),
+    "printf": LLVMFunctionType([_generic_ptr], LLVMInt),
+    "Math_pow": LLVMFunctionType([LLVMDouble, LLVMDouble], LLVMDouble),
+    "Math_sqrt": LLVMFunctionType([LLVMDouble], LLVMDouble),
+    "Math_sqrtf": LLVMFunctionType([LLVMDouble], LLVMDouble),
+    "Math_sin": LLVMFunctionType([LLVMDouble], LLVMDouble),
+    "Math_cos": LLVMFunctionType([LLVMDouble], LLVMDouble),
+    "Math_exp": LLVMFunctionType([LLVMDouble], LLVMDouble),
+    "Math_log": LLVMFunctionType([LLVMDouble], LLVMDouble),
+    "Vector_new": LLVMFunctionType([], _generic_ptr),
+    "Vector_append": LLVMFunctionType([_generic_ptr, LLVMInt], _generic_ptr),
+    "Vector_get": LLVMFunctionType([_generic_ptr, LLVMInt], LLVMInt),
+    "Vector_set": LLVMFunctionType([_generic_ptr, LLVMInt, LLVMInt], _generic_ptr),
+    "Vector_map": LLVMFunctionType([LLVMPointerType(_func_i_i), _generic_ptr, LLVMInt], _generic_ptr),
+    "Vector_reduce": LLVMFunctionType([LLVMPointerType(_func_ii_i), LLVMInt, _generic_ptr,
LLVMInt], LLVMInt), + "Vector_imap": LLVMFunctionType([LLVMPointerType(_func_ii_i), _generic_ptr, LLVMInt], _generic_ptr), + "Vector_filter": LLVMFunctionType([LLVMPointerType(_func_i_b), _generic_ptr, LLVMInt], _generic_ptr), + "Vector_zipWith": LLVMFunctionType( + [LLVMPointerType(_func_ii_i), _generic_ptr, _generic_ptr, LLVMInt], _generic_ptr + ), + "Vector_count": LLVMFunctionType([LLVMPointerType(_func_i_b), _generic_ptr, LLVMInt], LLVMInt), +} + +POLYMORPHIC_FUNCTIONS: set[str] = { + "Math_pow", + "Math_exp", + "Math_sqrt", + "Math_sqrtf", + "Math_sin", + "Math_cos", + "Math_log", + "Vector_get", + "Vector_set", + "Vector_new", + "Vector_map", + "Vector_reduce", + "Vector_imap", + "Vector_filter", + "Vector_zipWith", + "Vector_count", +} + + +@dataclass(frozen=True) +class CPUValidationContext(ValidationContext): + allowed_func_calls: set[Name] = field(default_factory=set) + type_env: Dict[Name, LLVMType] = field(default_factory=dict) + env_names: set[str] = field(default_factory=set) + is_top_level: bool = True + strict: bool = False + in_vector_op: bool = False + + +class CPUTypeValidationStep(ValidationStep): + def validate(self, t: Term, ctx: ValidationContext) -> None: + match t: + case Literal(_, ty): + validate_type(ty) + case Rec(_, var_type, var_value, body): + validate_type(var_type) + self.validate(var_value, ctx) + self.validate(body, ctx) + case Annotation(expr, ty) | TypeApplication(expr, ty): + validate_type(ty) + self.validate(expr, ctx) + case Abstraction(_, body) | TypeAbstraction(_, _, body): + self.validate(body, ctx) + case Let(_, var_value, body): + self.validate(var_value, ctx) + self.validate(body, ctx) + case Application(f, arg): + self.validate(f, ctx) + self.validate(arg, ctx) + case If(cond, then_t, else_t): + for sub in (cond, then_t, else_t): + self.validate(sub, ctx) + case Hole(name): + raise LLVMValidationError(f"unresolved hole {name}") + case _: + pass + + +class CPUFunctionCallValidationStep(ValidationStep): + def 
validate(self, t: Term, ctx: ValidationContext) -> None: + assert isinstance(ctx, CPUValidationContext) + match t: + case Var(name): + self._validate_var(name, ctx) + case Rec(var_name, _, var_value, body): + self._validate_rec(var_name, var_value, body, ctx) + case Let(var_name, var_value, body): + self._validate_let(var_name, var_value, body, ctx) + case Abstraction(var_name, body): + self.validate( + body, replace(ctx, env_names=ctx.env_names | {sanitize_name(var_name)}, is_top_level=False) + ) + case Application(f, arg): + self.validate(f, replace(ctx, is_top_level=False)) + self.validate(arg, replace(ctx, is_top_level=False)) + case If(cond, then_t, else_t): + for sub in (cond, then_t, else_t): + self.validate(sub, replace(ctx, is_top_level=False)) + case Annotation(expr, _) | TypeApplication(expr, _) | TypeAbstraction(_, _, expr): + self.validate(expr, ctx) + case _: + pass + + def _validate_var(self, name: Name, ctx: CPUValidationContext) -> None: + if not ctx.strict: + return + is_local = name in ctx.type_env or sanitize_name(name) in ctx.env_names + is_op = name.name in BINARY_OPS or name.name in UNARY_OPS + is_anf = name.name.startswith("anf") + is_allowed = any(name.name == allowed.name for allowed in ctx.allowed_func_calls) + is_builtin = name.name in BUILTIN_FUNCTION_TYPES or name.name in VECTOR_OPERATIONS + if not (is_local or is_op or is_anf or is_allowed or is_builtin): + raise LLVMValidationError(f"function or variable {name.name} is not allowed in CPU LLVM functions.") + + def _validate_rec(self, var_name: Name, var_value: Term, body: Term, ctx: CPUValidationContext) -> None: + if ctx.is_top_level: + self.validate(var_value, replace(ctx, is_top_level=False)) + self.validate(body, replace(ctx, is_top_level=True)) + else: + new_ctx = replace(ctx, env_names=ctx.env_names | {sanitize_name(var_name)}, is_top_level=False) + self.validate(var_value, new_ctx) + self.validate(body, new_ctx) + + def _validate_let(self, var_name: Name, var_value: Term, body: 
Term, ctx: CPUValidationContext) -> None: + if ctx.is_top_level: + self.validate(var_value, replace(ctx, is_top_level=False)) + self.validate(body, replace(ctx, is_top_level=True)) + else: + self.validate(var_value, replace(ctx, is_top_level=False)) + self.validate(body, replace(ctx, env_names=ctx.env_names | {sanitize_name(var_name)}, is_top_level=False)) + + +class CPUFullApplicationValidationStep(ValidationStep): + def validate(self, t: Term, ctx: ValidationContext) -> None: + assert isinstance(ctx, CPUValidationContext) + no_top = replace(ctx, is_top_level=False) + match t: + case Application(fun, arg): + arguments = [arg] + base = fun + while isinstance(base, Application): + arguments.append(base.arg) + base = base.fun + for a in arguments: + self.validate(a, no_top) + self.validate(base, no_top) + case Let(var_name, var_value, body): + llvm_var_type = self._infer_let_type(var_value) + self.validate(var_value, no_top) + self.validate(body, replace(ctx, type_env=ctx.type_env | {var_name: llvm_var_type}, is_top_level=False)) + case Rec(var_name, var_ty, var_value, body): + llvm_ty = to_llvm_type(var_ty) + new_ctx = replace(ctx, type_env=ctx.type_env | {var_name: llvm_ty}, is_top_level=False) + self.validate(var_value, new_ctx) + self.validate(body, new_ctx) + case Abstraction(_, body) | Annotation(body, _) | TypeApplication(body, _) | TypeAbstraction(_, _, body): + self.validate(body, no_top) + case If(cond, then_t, else_t): + for sub in (cond, then_t, else_t): + self.validate(sub, no_top) + case _: + pass + + @staticmethod + def _infer_let_type(var_value: Term) -> LLVMType: + if isinstance(var_value, Annotation): + return to_llvm_type(var_value.type) + if isinstance(var_value, Rec): + return to_llvm_type(var_value.var_type) + if isinstance(var_value, Var) and var_value.name.name in BINARY_OPS: + return get_builtin_op_type(var_value.name.name) + return LLVMInt + + +class CPULLVMLowerer(LLVMLowerer): + def get_validation_steps(self) -> List[ValidationStep]: + 
return [CPUTypeValidationStep(), CPUFunctionCallValidationStep(), CPUFullApplicationValidationStep()] + + def lower( + self, + term: Term, + expected_type: LLVMType | None = None, + type_env: Dict[Name, LLVMType] | None = None, + env: Dict[Name, LLVMTerm] | None = None, + allowed_func_calls: set[Name] | None = None, + strict: bool = False, + in_vector_op: bool = False, + ) -> LLVMTerm: + type_env = type_env or {} + env = env or {} + allowed_func_calls = allowed_func_calls or set() + + validation_ctx = CPUValidationContext( + allowed_func_calls=allowed_func_calls, + type_env=type_env, + env_names={sanitize_name(n) for n in env.keys()}, + is_top_level=True, + strict=strict, + in_vector_op=in_vector_op, + ) + for step in self.get_validation_steps(): + step.validate(term, validation_ctx) + + return self._lower_term(term, expected_type, type_env, env, allowed_func_calls, in_vector_op=in_vector_op) + + def get_signature(self, llvm_type: LLVMType) -> tuple[List[LLVMType], LLVMType]: + arg_types: list[LLVMType] = [] + curr = llvm_type + while True: + if isinstance(curr, LLVMFunctionType): + arg_types.extend(curr.arg_types) + curr = curr.return_type + elif isinstance(curr, LLVMPointerType) and isinstance(curr.element_type, LLVMFunctionType): + curr = curr.element_type + else: + break + return arg_types, curr + + def _get_vector_base_type(self, vector_type: LLVMType) -> LLVMType: + if isinstance(vector_type, LLVMPointerType): + element = vector_type.element_type + if not isinstance(element, (LLVMCharType, LLVMPointerType)): + return element + if isinstance(element, LLVMCharType): + return LLVMInt + return LLVMInt + + def _get_operator_type(self, op: str, expected: LLVMType | None) -> LLVMFunctionType: + is_float = False + if expected: + if isinstance(expected, LLVMFunctionType): + is_float = any(isinstance(ty, (LLVMFloatType, LLVMDoubleType)) for ty in expected.arg_types) + elif isinstance(expected, (LLVMFloatType, LLVMDoubleType)): + is_float = True + return 
get_builtin_op_type(op, is_float) + + def _cast_if_needed(self, val: LLVMTerm, target_ty: LLVMType) -> LLVMTerm: + if val.type == target_ty: + return val + from aeon.llvm.llvm_ast import LLVMCast + + return LLVMCast(target_ty, val) + + def _get_target_name(self, target: LLVMTerm) -> str: + if isinstance(target, LLVMVar): + return target.name.name + if isinstance(target, LLVMCall): + return self._get_target_name(target.target) + return "" + + def _is_inlinable_anf(self, name: Name, val: LLVMTerm) -> bool: + if not name.name.startswith("anf"): + return False + is_partial = isinstance(val, LLVMCall) and isinstance(val.type, LLVMFunctionType) + target = self._get_target_name(val) if isinstance(val, LLVMVar) else "" + is_op = isinstance(val, LLVMVar) and (target in BINARY_OPS or target in UNARY_OPS) + is_vec = isinstance(val, LLVMVar) and target in (VECTOR_OPERATIONS | {"Vector_set", "Vector_get"}) + return is_partial or is_op or is_vec + + def _lower_as_standalone( + self, + term: Term | LLVMTerm, + expected: LLVMType | None, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + in_vec: bool = False, + ) -> LLVMTerm: + if isinstance(term, LLVMTerm): + return term + if isinstance(term, Abstraction): + return self._lower_function(term, expected, type_env, env, allowed, in_vec) + lowered = self._lower_term(term, expected, type_env, env, allowed, in_vector_op=in_vec) + if isinstance(lowered, LLVMCall) and isinstance(lowered.type, LLVMFunctionType): + return self._create_wrapper_function(lowered) + return lowered + + def _lower_function( + self, + abs_term: Abstraction, + expected: LLVMType | None, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + in_vec: bool, + ) -> LLVMFunction: + arg_names: list[Name] = [] + curr: Term = abs_term + while isinstance(curr, Abstraction): + arg_names.append(curr.var_name) + curr = curr.body + + param_tys, ret_ty = self.get_signature( + expected.element_type + if 
isinstance(expected, LLVMPointerType) and isinstance(expected.element_type, LLVMFunctionType) + else expected or LLVMInt + ) + param_tys = (param_tys + [LLVMInt] * len(arg_names))[: len(arg_names)] + + new_type_env = type_env.copy() + resolved_tys = [] + for name, p_ty in zip(arg_names, param_tys): + actual_ty = _generic_ptr if isinstance(p_ty, LLVMVoidType) else p_ty + new_type_env[name] = actual_ty + resolved_tys.append(actual_ty) + + body = self._lower_term(curr, ret_ty, new_type_env, env, allowed, in_vector_op=in_vec) + return LLVMFunction(LLVMFunctionType(resolved_tys, ret_ty), arg_names, resolved_tys, body) + + def _create_wrapper_function(self, call: LLVMCall) -> LLVMFunction: + params, ret = self.get_signature(call.type) + names = [Name(f"wrapper_arg_{i}") for i in range(len(params))] + args = call.args + [LLVMVar(ty, n) for n, ty in zip(names, params)] + return LLVMFunction(call.type, names, params, LLVMCall(ret, call.target, args)) + + def _uncurry(self, app: Application) -> tuple[Term, List[Term]]: + args: list[Term] = [] + curr: Term = app + while isinstance(curr, Application): + args.append(curr.arg) + curr = curr.fun + args.reverse() + return curr, args + + def _lower_args( + self, + args: list[Term], + expected_params: list[LLVMType], + offset: int, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + in_vec: bool, + ) -> list[LLVMTerm]: + lowered_args = [] + for i, arg in enumerate(args): + idx = offset + i + exp = expected_params[idx] if idx < len(expected_params) else None + if isinstance(arg, Annotation): + exp = to_llvm_type(arg.type) + arg = arg.expr + lowered_args.append(self._lower_term(arg, exp, type_env, env, allowed, in_vector_op=in_vec)) + return lowered_args + + def _lower_vector_op( + self, + op: str, + args: list[Term | LLVMTerm], + expected: LLVMType | None, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + ) -> LLVMTerm: + def low_term(term, exp=None): + return 
self._lower_term(term, exp, type_env, env, allowed, in_vector_op=True) + + if op == "Vector_reduce": + kernel_term, init_term, vector_term, size_term = args + low_vec, low_init, low_size = low_term(vector_term), low_term(init_term), low_term(size_term, LLVMInt) + element_type = self._get_vector_base_type(low_vec.type) + vec_cast = self._cast_if_needed(low_vec, LLVMPointerType(element_type)) + kernel = self._lower_as_standalone( + kernel_term, + LLVMFunctionType([low_init.type, element_type], low_init.type), + type_env, + env, + allowed, + True, + ) + return LLVMVectorReduce(low_init.type, kernel, low_init, vec_cast, low_size) + + if op == "Vector_zipWith": + kernel_term, v1_term, v2_term, size_term = args + v1_low, v2_low, sz_low = low_term(v1_term), low_term(v2_term), low_term(size_term, LLVMInt) + res_el = expected.element_type if expected and isinstance(expected, LLVMPointerType) else LLVMInt + el1, el2 = self._get_vector_base_type(v1_low.type), self._get_vector_base_type(v2_low.type) + v1_cast, v2_cast = ( + self._cast_if_needed(v1_low, LLVMPointerType(el1)), + self._cast_if_needed(v2_low, LLVMPointerType(el2)), + ) + kernel = self._lower_as_standalone( + kernel_term, LLVMFunctionType([el1, el2], res_el), type_env, env, allowed, True + ) + assert isinstance(kernel.type, LLVMFunctionType) + return LLVMVectorZipWith(LLVMPointerType(kernel.type.return_type), kernel, v1_cast, v2_cast, sz_low) + + kernel_term, vector_term, size_term = args + v_low, sz_low = low_term(vector_term), low_term(size_term, LLVMInt) + element_type = self._get_vector_base_type(v_low.type) + vec_cast = self._cast_if_needed(v_low, LLVMPointerType(element_type)) + + if op in ("Vector_filter", "Vector_count"): + res_el = LLVMBool + elif expected and isinstance(expected, LLVMPointerType): + res_el = expected.element_type + else: + k_lowered = self._lower_as_standalone(kernel_term, None, type_env, env, allowed, True) + res_el = k_lowered.type.return_type if isinstance(k_lowered.type, 
LLVMFunctionType) else LLVMInt + + k_params = [LLVMInt, element_type] if op == "Vector_imap" else [element_type] + kernel = self._lower_as_standalone( + kernel_term, LLVMFunctionType(k_params, res_el), type_env, env, allowed, True + ) + + if op == "Vector_filter": + return LLVMVectorFilter(vec_cast.type, kernel, vec_cast, sz_low) + if op == "Vector_count": + return LLVMVectorCount(LLVMInt, kernel, vec_cast, sz_low) + + assert isinstance(kernel.type, LLVMFunctionType) + res_vec_ty = LLVMPointerType(kernel.type.return_type) + return ( + LLVMVectorMap(res_vec_ty, kernel, vec_cast, sz_low) + if op == "Vector_map" + else LLVMVectorIMap(res_vec_ty, kernel, vec_cast, sz_low) + ) + + def _lower_term( + self, + term: Term | LLVMTerm, + expected: LLVMType | None = None, + type_env: Dict[Name, LLVMType] = None, + env: Dict[Name, LLVMTerm] = None, + allowed: set[Name] = None, + in_vector_op: bool = False, + ) -> LLVMTerm: + if isinstance(term, LLVMTerm): + return term + type_env, env, allowed = type_env or {}, env or {}, allowed or set() + + def recurse(t, exp=None, vec=in_vector_op): + return self._lower_term(t, exp, type_env, env, allowed, in_vector_op=vec) + + match term: + case Literal(val, ty): + return LLVMLiteral(to_llvm_type(ty), val) + case Var(name): + return self._lower_var(name, expected, type_env, env) + case Annotation(e, ty) | TypeApplication(e, ty): + return recurse(e, to_llvm_type(ty)) + case TypeAbstraction(_, _, body): + return recurse(body, expected) + case Abstraction(_, _): + return self._lower_as_standalone(term, expected, type_env, env, allowed, in_vector_op) + case Application(_, _): + return self._lower_app(term, expected, type_env, env, allowed, in_vector_op) + case Let(name, val, body): + return self._lower_let(name, val, body, expected, type_env, env, allowed, in_vector_op) + case Rec(name, ty, val, body): + return self._lower_rec(name, ty, val, body, expected, type_env, env, allowed, in_vector_op) + case If(cond, then_t, else_t): + low_cond, 
low_then, low_else = recurse(cond), recurse(then_t, expected), recurse(else_t, expected) + return LLVMIf(low_then.type, low_cond, low_then, low_else) + case _: + raise LLVMLoweringError(f"could not lower term {term}") + + def _lower_var( + self, name: Name, expected: LLVMType | None, type_env: Dict[Name, LLVMType], env: Dict[Name, LLVMTerm] + ) -> LLVMTerm: + if name.name in BINARY_OPS or name.name in UNARY_OPS: + return LLVMVar(self._get_operator_type(name.name, expected), name) + + if name.name in BUILTIN_FUNCTION_TYPES: + ty = BUILTIN_FUNCTION_TYPES[name.name] + if expected and isinstance(expected, LLVMFunctionType) and len(expected.arg_types) == len(ty.arg_types): + ty = expected + elif expected and name.name == "Vector_new" and isinstance(expected, LLVMPointerType): + ty = LLVMFunctionType([], expected) + return LLVMVar(ty, name) + + if name in env: + return env[name] + for en, term in env.items(): + if en.name == name.name: + return term + + var_ty = type_env.get(name) or expected or LLVMInt + return LLVMVar(var_ty, name) + + def _lower_app( + self, + t: Application, + expected: LLVMType | None, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + in_vec: bool, + ) -> LLVMTerm: + base, args = self._uncurry(t) + lowered_base = self._lower_term(base, None, type_env, env, allowed, in_vector_op=in_vec) + if not lowered_base: + raise LLVMBackendError(f"could not lower base {base}") + + target, prev_args, eff_ty = self._extract_call_info(lowered_base) + params, ret = self.get_signature(eff_ty) + lookup = self._get_lookup_name(target) + + if lookup in BUILTIN_FUNCTION_TYPES or lookup in VECTOR_OPERATIONS: + return self._lower_builtin_call(lookup, target, prev_args, args, expected, type_env, env, allowed, in_vec) + + all_args = prev_args + self._lower_args(args, params, len(prev_args), type_env, env, allowed, in_vec) + return self._create_call_or_partial(target, all_args, params, ret) + + def _lower_builtin_call( + self, + name: str, 
+ target: LLVMTerm, + prev_args: list[LLVMTerm], + args: list[Term], + expected: LLVMType | None, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + in_vec: bool, + ) -> LLVMTerm: + params, ret = self.get_signature(target.type) + + is_full_v = self._is_full_vector_op(name, len(prev_args) + len(args)) + if (in_vec or name in VECTOR_OPERATIONS) and is_full_v: + return self._lower_vector_op(name, list(prev_args) + list(args), expected, type_env, env, allowed) + + all_args = prev_args + self._lower_args(args, params, len(prev_args), type_env, env, allowed, in_vec) + + if name in POLYMORPHIC_FUNCTIONS: + if name.startswith("Math"): + all_args = [self._cast_if_needed(a, p) for a, p in zip(all_args, params)] + return LLVMCall(ret, target, all_args) + + if name == "Vector_get" and len(all_args) == 2: + return self._lower_vector_get(all_args[0], all_args[1]) + + if name == "Vector_set" and len(all_args) == 3: + return self._lower_vector_set(all_args[0], all_args[1], all_args[2]) + + return self._create_call_or_partial(target, all_args, params, ret) + + def _extract_call_info(self, lowered: LLVMTerm) -> tuple[LLVMTerm, list[LLVMTerm], LLVMType]: + if isinstance(lowered, LLVMCall) and isinstance(lowered.type, LLVMFunctionType): + return lowered.target, lowered.args, lowered.type + return lowered, [], lowered.type + + def _get_lookup_name(self, target: LLVMTerm) -> str: + name = self._get_target_name(target) + return name.rsplit("_", 1)[0] if name.rsplit("_", 1)[-1].isdigit() else name + + def _is_full_vector_op(self, op: str, total_args: int) -> bool: + if op not in VECTOR_OPERATIONS: + return False + threshold = 4 if op in ("Vector_reduce", "Vector_zipWith") else 3 + return total_args >= threshold + + def _lower_vector_get(self, vec: LLVMTerm, idx: LLVMTerm) -> LLVMLoad: + el = self._get_vector_base_type(vec.type) + ptr_ty = LLVMPointerType(el) + vec_ptr = self._cast_if_needed(vec, ptr_ty) + return LLVMLoad(el, LLVMGetElementPtr(ptr_ty, 
vec_ptr, [idx])) + + def _lower_vector_set(self, vec: LLVMTerm, idx: LLVMTerm, val: LLVMTerm) -> LLVMStore: + el = self._get_vector_base_type(vec.type) + ptr_ty = LLVMPointerType(el) + vec_ptr = self._cast_if_needed(vec, ptr_ty) + val_cast = self._cast_if_needed(val, el) + return LLVMStore(vec.type, val_cast, LLVMGetElementPtr(ptr_ty, vec_ptr, [idx])) + + def _create_call_or_partial( + self, target: LLVMTerm, args: list[LLVMTerm], params: list[LLVMType], ret: LLVMType + ) -> LLVMCall: + if len(args) < len(params): + return LLVMCall(LLVMFunctionType(params[len(args) :], ret), target, args) + return LLVMCall(ret, target, args) + + def _lower_let( + self, + name: Name, + val: Term, + body: Term, + expected: LLVMType | None, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + in_vec: bool, + ) -> LLVMTerm: + lowered_val = self._lower_term(val, None, type_env, env, allowed, in_vector_op=in_vec) + if isinstance(lowered_val, LLVMFunction): + lowered_val.name = name + new_ty_env = {**type_env, name: lowered_val.type if lowered_val else LLVMInt} + new_env = env.copy() + if lowered_val and self._is_inlinable_anf(name, lowered_val): + new_env[name] = lowered_val + else: + new_env[name] = LLVMVar(new_ty_env[name], name) + lowered_body = self._lower_term(body, expected, new_ty_env, new_env, allowed, in_vector_op=in_vec) + return ( + lowered_body + if self._is_inlinable_anf(name, lowered_val) + else LLVMLet(lowered_body.type, name, lowered_val, lowered_body) + ) + + def _lower_rec( + self, + name: Name, + ty: Type, + val: Term, + body: Term, + expected: LLVMType | None, + type_env: Dict[Name, LLVMType], + env: Dict[Name, LLVMTerm], + allowed: set[Name], + in_vec: bool, + ) -> LLVMLet: + func_type = to_llvm_type(ty) + params, ret_ty = self.get_signature(func_type) + flat_func_type = LLVMFunctionType(params, ret_ty) + new_ty_env, new_env = {**type_env, name: flat_func_type}, {**env, name: LLVMVar(flat_func_type, name)} + lowered_val = 
self._lower_term(val, flat_func_type, new_ty_env, new_env, allowed, in_vector_op=in_vec) + if isinstance(lowered_val, LLVMFunction): + lowered_val.name = name + lowered_body = self._lower_term(body, expected, new_ty_env, new_env, allowed, in_vector_op=in_vec) + return LLVMLet(lowered_body.type, name, lowered_val, lowered_body) diff --git a/aeon/llvm/cpu/pipeline.py b/aeon/llvm/cpu/pipeline.py new file mode 100644 index 00000000..7dfe8e13 --- /dev/null +++ b/aeon/llvm/cpu/pipeline.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Any, Dict + +from loguru import logger + +from aeon.core.terms import Term, Let, Rec +from aeon.core.types import Type +from aeon.llvm.core import ( + LLVMExecutionEngine, + LLVMIRGenerator, + LLVMLowerer, + LLVMBackendError, + LLVMPipeline, +) +from aeon.llvm.llvm_ast import LLVMFunction, LLVMTerm +from aeon.llvm.utils import sanitize_name, to_llvm_type +from aeon.utils.name import Name + + +class CPULLVMPipeline(LLVMPipeline): + def __init__( + self, + executor: LLVMExecutionEngine, + generator: LLVMIRGenerator, + lowerer: LLVMLowerer, + metadata: Dict[Name, Any] | None = None, + debug: bool = False, + ): + self.executor = executor + self.generator = generator + self.lowerer = lowerer + self.metadata = metadata or {} + self.debug = debug + self.compiled_functions: Dict[Name, LLVMTerm] = {} + self.name_to_id_cache: Dict[str, Name] = {} + self.type_environment: Dict[Name, Type] = {} + self.llvm_ir: str = "" + + def compile(self, program: Term): + discovered_targets = self._find_compilation_targets(program) + + for target_id, target_body in discovered_targets.items(): + target_type = self.type_environment[target_id] + target_llvm_type = to_llvm_type(target_type) + + llvm_ast = self.lowerer.lower( + target_body, + expected_type=target_llvm_type, + type_env={fid: to_llvm_type(ty) for fid, ty in self.type_environment.items()}, + env={fid: value for fid, value in self.compiled_functions.items()}, + ) + + if 
isinstance(llvm_ast, LLVMFunction): + llvm_ast.name = target_id + + self.compiled_functions[target_id] = llvm_ast + + self.llvm_ir = self.generator.generate_ir(list(self.compiled_functions.values())) + if self.debug: + logger.debug(self.llvm_ir) + + def _find_compilation_targets(self, term: Term) -> Dict[Name, Term]: + discovery_targets = {} + current_term = term + + while isinstance(current_term, (Let, Rec)): + if isinstance(current_term, Rec): + target_id = current_term.var_name + target_name = target_id.name + + self.type_environment[target_id] = current_term.var_type + self.name_to_id_cache[target_name] = target_id + + should_compile = False + if not self.metadata: + should_compile = True + else: + for meta_key, meta_value in self.metadata.items(): + key_string = meta_key.name if isinstance(meta_key, Name) else str(meta_key) + if key_string == target_name and meta_value.get("llvm"): + should_compile = True + break + + if should_compile: + discovery_targets[target_id] = current_term.var_value + + current_term = current_term.body + + return discovery_targets + + def get_curried_function(self, function_id: Name): + if function_id not in self.compiled_functions: + function_id = self.name_to_id_cache.get(function_id.name) + + if function_id is None or function_id not in self.compiled_functions: + return None + + target_type = self.type_environment.get(function_id) + if not target_type: + return None + + target_llvm_type = to_llvm_type(target_type) + param_types, return_type = self.lowerer.get_signature(target_llvm_type) + + def invoke_wrapper(accumulated_args: list[Any]): + if len(accumulated_args) == len(param_types): + assert function_id is not None + return self.invoke(function_id, accumulated_args) + else: + return lambda next_arg: invoke_wrapper(accumulated_args + [next_arg]) + + return invoke_wrapper([]) + + def invoke(self, name_id: Name, arguments: list[Any]): + target_type = self.type_environment.get(name_id) + if not target_type: + raise 
LLVMBackendError(f"type for {name_id} not found") + + target_llvm_type = to_llvm_type(target_type) + param_types, return_type = self.lowerer.get_signature(target_llvm_type) + + return self.executor.execute( + self.llvm_ir, + sanitize_name(name_id), + arguments, + param_types, + return_type, + ) diff --git a/aeon/llvm/decorators/gpu.py b/aeon/llvm/decorators/gpu.py new file mode 100644 index 00000000..af8f0a50 --- /dev/null +++ b/aeon/llvm/decorators/gpu.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Any + +from aeon.decorators.api import Metadata, metadata_update +from aeon.sugar.program import Definition, STerm, SLiteral + + +def gpu( + args: list[STerm], + fun: Definition, + metadata: Metadata, +) -> tuple[Definition, list[Definition], Metadata]: + gpu_info: dict[str, Any] = { + "gpu": True, + "gpu_device": "cuda", + "gpu_debug": False, + "gpu_cache": False, + "gpu_block_size": 1, + "gpu_thread_count": 1, + } + + arg_keys = ["gpu_device", "gpu_debug", "gpu_cache", "gpu_block_size", "gpu_thread_count"] + + for key, arg in zip(arg_keys, args): + if isinstance(arg, SLiteral): + gpu_info[key] = arg.value + + metadata = metadata_update(metadata, fun, gpu_info) + return fun, [], metadata diff --git a/aeon/llvm/decorators/llvm.py b/aeon/llvm/decorators/llvm.py new file mode 100644 index 00000000..5a7efb81 --- /dev/null +++ b/aeon/llvm/decorators/llvm.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Any + +from aeon.decorators.api import Metadata, metadata_update +from aeon.sugar.program import Definition, STerm, SLiteral + + +def llvm( + args: list[STerm], + fun: Definition, + metadata: Metadata, +) -> tuple[Definition, list[Definition], Metadata]: + llvm_args: dict[str, Any] = {"llvm": True, "llvm_debug": False, "llvm_cache": False} + + arg_keys = ["llvm_debug", "llvm_cache"] + + for key, arg in zip(arg_keys, args): + if isinstance(arg, SLiteral): + llvm_args[key] = arg.value + + metadata = 
metadata_update(metadata, fun, llvm_args) + return fun, [], metadata diff --git a/aeon/llvm/llvm_ast.py b/aeon/llvm/llvm_ast.py new file mode 100644 index 00000000..d320fd75 --- /dev/null +++ b/aeon/llvm/llvm_ast.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import IntEnum +from typing import Any +from aeon.utils.name import Name + + +@dataclass(frozen=True) +class LLVMType: + def is_pointer(self) -> bool: + return isinstance(self, LLVMPointerType) + + def is_function(self) -> bool: + return isinstance(self, LLVMFunctionType) + + +@dataclass(frozen=True) +class LLVMIntType(LLVMType): + bits: int = 32 + + def __str__(self): + return f"i{self.bits}" + + +@dataclass(frozen=True) +class LLVMFloatType(LLVMType): + def __str__(self): + return "float" + + +@dataclass(frozen=True) +class LLVMDoubleType(LLVMType): + def __str__(self): + return "double" + + +@dataclass(frozen=True) +class LLVMBoolType(LLVMType): + def __str__(self): + return "i1" + + +@dataclass(frozen=True) +class LLVMCharType(LLVMType): + def __str__(self): + return "i8" + + +@dataclass(frozen=True) +class LLVMVoidType(LLVMType): + def __str__(self): + return "void" + + +class LLVMAddressSpace(IntEnum): + GENERIC = 0 + GLOBAL = 1 + SHARED = 3 + CONSTANT = 4 + LOCAL = 5 + + +@dataclass(frozen=True) +class LLVMPointerType(LLVMType): + element_type: LLVMType + address_space: LLVMAddressSpace = LLVMAddressSpace.GENERIC + + def __str__(self): + return f"{self.element_type}*" + + +@dataclass(frozen=True) +class LLVMArrayType(LLVMType): + element_type: LLVMType + size: int | None = None + + def __str__(self): + return f"[{self.size if self.size is not None else 0} x {self.element_type}]" + + +@dataclass(frozen=True) +class LLVMFunctionType(LLVMType): + arg_types: list[LLVMType] + return_type: LLVMType + + def __str__(self): + args = ", ".join(str(t) for t in self.arg_types) + return f"{self.return_type} ({args})" + + +LLVMInt = LLVMIntType(32) +LLVMLong = 
LLVMIntType(64) +LLVMFloat = LLVMFloatType() +LLVMDouble = LLVMDoubleType() +LLVMBool = LLVMBoolType() +LLVMChar = LLVMCharType() +LLVMVoid = LLVMVoidType() +LLVMVectorInt = LLVMPointerType(LLVMInt) +LLVMVectorFloat = LLVMPointerType(LLVMFloat) +LLVMVectorDouble = LLVMPointerType(LLVMDouble) + +VECTOR_OPERATIONS: frozenset[str] = frozenset( + [ + "Vector_map", + "Vector_reduce", + "Vector_imap", + "Vector_filter", + "Vector_zipWith", + "Vector_count", + ] +) + + +@dataclass +class LLVMTerm: + type: LLVMType + + +@dataclass +class LLVMLiteral(LLVMTerm): + value: Any + + def __str__(self): + return str(self.value) + + +@dataclass +class LLVMVar(LLVMTerm): + name: Name + + def __str__(self): + return self.name.name + + +@dataclass +class LLVMIf(LLVMTerm): + cond: LLVMTerm + then_t: LLVMTerm + else_t: LLVMTerm + + def __str__(self): + return f"if {self.cond} then {self.then_t} else {self.else_t}" + + +@dataclass +class LLVMLet(LLVMTerm): + var_name: Name + var_value: LLVMTerm + body: LLVMTerm + + def __str__(self): + return f"let {self.var_name.name} = {self.var_value} in {self.body}" + + +@dataclass +class LLVMFunction(LLVMTerm): + arg_names: list[Name] + arg_types: list[LLVMType] + body: LLVMTerm + name: Name | None = None + + def __str__(self): + args = ", ".join(f"{n.name}:{t}" for n, t in zip(self.arg_names, self.arg_types)) + return f"\\{args} -> {self.body}" + + +@dataclass +class LLVMCall(LLVMTerm): + target: LLVMTerm + args: list[LLVMTerm] + + def __str__(self): + args = ", ".join(str(a) for a in self.args) + return f"{self.target}({args})" + + +@dataclass +class LLVMCast(LLVMTerm): + val: LLVMTerm + + def __str__(self): + return f"cast {self.val} to {self.type}" + + +@dataclass +class LLVMGetElementPtr(LLVMTerm): + ptr: LLVMTerm + indices: list[LLVMTerm] + + def __str__(self): + indices = ", ".join(str(i) for i in self.indices) + return f"gep {self.ptr}, {indices}" + + +@dataclass +class LLVMLoad(LLVMTerm): + ptr: LLVMTerm + + def __str__(self): + return 
f"load {self.ptr}" + + +@dataclass +class LLVMStore(LLVMTerm): + value: LLVMTerm + ptr: LLVMTerm + + def __str__(self): + return f"store {self.value}, {self.ptr}" + + +@dataclass +class LLVMAlloc(LLVMTerm): + def __str__(self): + return f"alloca {self.type}" + + +@dataclass +class LLVMVectorOp(LLVMTerm): + f: LLVMTerm + v: LLVMTerm + size: LLVMTerm + + +@dataclass +class LLVMVectorMap(LLVMVectorOp): + def __str__(self): + return f"vector_map {self.f}, {self.v}, {self.size}" + + +@dataclass +class LLVMVectorReduce(LLVMTerm): + f: LLVMTerm + initial: LLVMTerm + v: LLVMTerm + size: LLVMTerm + + def __str__(self): + return f"vector_reduce {self.f}, {self.initial}, {self.v}, {self.size}" + + +@dataclass +class LLVMVectorIMap(LLVMVectorOp): + def __str__(self): + return f"vector_imap {self.f}, {self.v}, {self.size}" + + +@dataclass +class LLVMVectorFilter(LLVMVectorOp): + def __str__(self): + return f"vector_filter {self.f}, {self.v}, {self.size}" + + +@dataclass +class LLVMVectorZipWith(LLVMTerm): + f: LLVMTerm + v1: LLVMTerm + v2: LLVMTerm + size: LLVMTerm + + def __str__(self): + return f"vector_zipWith {self.f}, {self.v1}, {self.v2}, {self.size}" + + +@dataclass +class LLVMVectorCount(LLVMVectorOp): + def __str__(self): + return f"vector_count {self.f}, {self.v}, {self.size}" diff --git a/aeon/llvm/utils.py b/aeon/llvm/utils.py new file mode 100644 index 00000000..b3d3f516 --- /dev/null +++ b/aeon/llvm/utils.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from aeon.core.types import TypeConstructor, RefinedType, AbstractionType, Type +from aeon.utils.name import Name +from aeon.llvm.core import LLVMValidationError +from aeon.llvm.llvm_ast import ( + LLVMType, + LLVMInt, + LLVMFloat, + LLVMBool, + LLVMChar, + LLVMDouble, + LLVMVoid, + LLVMLong, + LLVMFunctionType, + LLVMVectorInt, + LLVMPointerType, +) + +SUPPORTED_TYPES = {"Int", "Float", "Bool", "Char", "Double", "Long", "Unit", "Vector", "String"} +BINARY_OPS = {"+", "-", "*", "/", "%", "==", "!=", "<", 
"<=", ">", ">=", "&&", "||"} +UNARY_OPS = {"!", "-"} + + +def validate_ops(op: str): + if not op.startswith("anf") and op not in BINARY_OPS and op not in UNARY_OPS: + raise LLVMValidationError(f"LLVM Backend does not support operation {op}") + + +def sanitize_name(name: Name) -> str: + res = name.name if name.id in (-1, 0) else f"{name.name}_{name.id}" + return res.translate(str.maketrans(".- ", "___")) + + +def validate_type(ty: Type): + match ty: + case RefinedType(_, it, _): + validate_type(it) + case AbstractionType(_, vt, rt): + validate_type(vt) + validate_type(rt) + case TypeConstructor(n, _) if n.name not in SUPPORTED_TYPES: + raise LLVMValidationError(f"LLVM Backend does not support type {n.name}") + case _: + pass + + +def get_builtin_op_type(op: str, is_f: bool = False) -> LLVMFunctionType: + t = LLVMFloat if is_f else LLVMInt + if op in {"==", "!=", "<", "<=", ">", ">="}: + return LLVMFunctionType([t, t], LLVMBool) + if op in {"&&", "||"}: + return LLVMFunctionType([LLVMBool, LLVMBool], LLVMBool) + if op == "!": + return LLVMFunctionType([LLVMBool], LLVMBool) + if op in BINARY_OPS: + return LLVMFunctionType([t, t], t) + if op in UNARY_OPS: + return LLVMFunctionType([t], t) + raise LLVMValidationError(f"Unknown operator {op}") + + +def to_llvm_type(ty: Type) -> LLVMType: + match ty: + case RefinedType(_, it, _): + return to_llvm_type(it) + case AbstractionType(_, vt, rt): + args, curr = [to_llvm_type(vt)], rt + while isinstance(curr, AbstractionType): + args.append(to_llvm_type(curr.var_type)) + curr = curr.type + return LLVMFunctionType(args, to_llvm_type(curr)) + case TypeConstructor(n, args): + match n.name: + case "Int": + return LLVMInt + case "Float": + return LLVMFloat + case "Double": + return LLVMDouble + case "Long": + return LLVMLong + case "Bool": + return LLVMBool + case "Char": + return LLVMChar + case "Unit": + return LLVMVoid + case "Vector": + return LLVMPointerType(to_llvm_type(args[0])) if args else LLVMVectorInt + case _: + return 
LLVMInt + case _: + return LLVMInt diff --git a/aeon/prelude/prelude.py b/aeon/prelude/prelude.py index 073a10ab..c3ad43b8 100644 --- a/aeon/prelude/prelude.py +++ b/aeon/prelude/prelude.py @@ -24,7 +24,72 @@ def native_import(name): return importlib.import_module(name) -native_types: list[Name] = [Name("Unit", 0), Name("Bool", 0), Name("Int", 0), Name("Float", 0), Name("String", 0)] +native_types: list[Name] = [ + Name("Unit", 0), + Name("Bool", 0), + Name("Int", 0), + Name("Float", 0), + Name("String", 0), + Name("Tensor", 0), + Name("GpuConfig", 0), +] + + +def gpu_map(f): + def run(t, conf=None): + # TODO add gpu support + return [f(x) for x in t] + + return run + + +def gpu_imap(f): + def run(t, conf=None): + # TODO add gpu support + return [f(i) for i in range(len(t))] + + return run + + +def gpu_reduce(f): + def with_initial(i): + def run(t, conf=None): + import functools + + # TODO add gpu support + return functools.reduce(lambda x, y: f(x)(y), t, i) + + return run + + return with_initial + + +def gpu_filter(f): + def run(t, conf=None): + # TODO add gpu support + return [x for x in t if f(x)] + + return run + + +def gpu_dot(a): + def with_b(b): + # TODO add gpu support + return sum(x * y for x, y in zip(a, b)) + + return with_b + + +def gpu_run(k): + def with_config(c): + def with_input(t): + print(f"Executing kernel on CPU (for now) with config {c}...") + return k(t) + + return with_input + + return with_config + # TODO: polymorphic signatures prelude = [ @@ -45,6 +110,33 @@ def native_import(name): ("&&", "(x:Bool) -> (y:Bool) -> Bool", lambda x: lambda y: x and y), ("||", "(x:Bool) -> (y:Bool) -> Bool", lambda x: lambda y: x or y), ("!", "(x:Bool) -> Bool", lambda x: not x), + ("gpu_map", "forall a:B, forall b:B, (f:(x:a) -> b) -> (t:Tensor) -> Tensor", gpu_map), + ( + "gpu_imap", + "forall b:B, (f:(i:Int) -> b) -> (t:Tensor) -> Tensor", + gpu_imap, + ), + ("__index__", "forall a:B, (t:Tensor) -> (i:Int) -> a", lambda t: lambda i: t[i]), + ( + 
"gpu_reduce", + "forall a:B, (f:(acc:a) -> (x:a) -> a) -> (initial:a) -> (t:Tensor) -> a", + gpu_reduce, + ), + ( + "gpu_filter", + "forall a:B, (f:(x:a) -> Bool) -> (t:Tensor) -> Tensor", + gpu_filter, + ), + ( + "gpu_dot", + "(a:Tensor) -> (b:Tensor) -> Float", + gpu_dot, + ), + ( + "run_gpu", + "forall a:B, (kernel:(x:Tensor) -> a) -> (config:GpuConfig) -> (input:Tensor) -> a", + gpu_run, + ), ] typing_vars: dict[Name, SType] = {} diff --git a/aeon/sugar/ast_helpers.py b/aeon/sugar/ast_helpers.py index 35e74f4f..8630141f 100644 --- a/aeon/sugar/ast_helpers.py +++ b/aeon/sugar/ast_helpers.py @@ -15,6 +15,8 @@ def mk_binop(fresh: Callable[[], str], op: Name, a1: STerm, a2: STerm) -> STerm: st_int = STypeConstructor(Name("Int", 0)) st_float = STypeConstructor(Name("Float", 0)) st_string = STypeConstructor(Name("String", 0)) +st_tensor = STypeConstructor(Name("Tensor", 0)) +st_gpu_config = STypeConstructor(Name("GpuConfig", 0)) true = SLiteral(True, st_bool) false = SLiteral(False, st_bool) diff --git a/aeon/sugar/stypes.py b/aeon/sugar/stypes.py index 74330382..8c9da445 100644 --- a/aeon/sugar/stypes.py +++ b/aeon/sugar/stypes.py @@ -76,7 +76,7 @@ def __hash__(self): return hash(self.name) + sum(hash(c) for c in self.args) -builtin_types = ["Top", "Bool", "Int", "Float", "String", "Unit"] +builtin_types = ["Top", "Bool", "Int", "Float", "String", "Unit", "Tensor", "GpuConfig"] def get_type_vars(ty: SType) -> set[STypeVar]: diff --git a/libraries/Vector.ae b/libraries/Vector.ae new file mode 100644 index 00000000..00442a35 --- /dev/null +++ b/libraries/Vector.ae @@ -0,0 +1,27 @@ +type Vector a; + +def Vector_new : forall a:B, (Vector a) = native "[]"; + +def Vector_append : forall a:B, (v:(Vector a)) -> (x:a) -> (Vector a) = \v -> \x -> native "v + [x]"; + +def Vector_get : forall a:B, (v:(Vector a)) -> (i:Int) -> a = \v -> \i -> native "v[i]"; + +def Vector_set : forall a:B, (v:(Vector a)) -> (i:Int) -> (x:a) -> (Vector a) = \v -> \i -> \x -> native "v[:i] + [x] + 
v[i+1:]"; + +def Vector_map : forall a:B, forall b:B, (f:(x:a) -> b) -> (v:(Vector a)) -> (size:Int) -> (Vector b) = \f -> \v -> \size -> + native "[f(x) for x in v]"; + +def Vector_reduce : forall a:B, forall b:B, (f:(acc:a) -> (curr:b) -> a) -> (initial:a) -> (v:(Vector b)) -> (size:Int) -> a = \f -> \initial -> \v -> \size -> + native "__import__('functools').reduce(lambda acc, curr: f(acc)(curr), v, initial)"; + +def Vector_imap : forall a:B, forall b:B, (f:(i:Int) -> (x:a) -> b) -> (v:(Vector a)) -> (size:Int) -> (Vector b) = \f -> \v -> \size -> + native "[f(i)(x) for i, x in enumerate(v)]"; + +def Vector_filter : forall a:B, (f:(x:a) -> Bool) -> (v:(Vector a)) -> (size:Int) -> (Vector a) = \f -> \v -> \size -> + native "list(filter(f, v))"; + +def Vector_zipWith : forall a:B, forall b:B, forall c:B, (f:(x:a) -> (y:b) -> c) -> (v1:(Vector a)) -> (v2:(Vector b)) -> (size:Int) -> (Vector c) = \f -> \v1 -> \v2 -> \size -> + native "[f(x)(y) for x, y in zip(v1, v2)]"; + +def Vector_count : forall a:B, (f:(x:a) -> Bool) -> (v:(Vector a)) -> (size:Int) -> Int = \f -> \v -> \size -> + native "len(list(filter(f, v)))"; diff --git a/pyproject.toml b/pyproject.toml index 4c7d2842..0f6e40ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,8 @@ dependencies = [ 'textdistance', 'z3-solver >= 4', 'zstandard==0.25.0', - 'zss' + 'zss', + "llvmlite", ] [project.optional-dependencies] diff --git a/tests/llvm_decorators_test.py b/tests/llvm_decorators_test.py new file mode 100644 index 00000000..a7d2a7a0 --- /dev/null +++ b/tests/llvm_decorators_test.py @@ -0,0 +1,72 @@ +from aeon.sugar.desugar import desugar +from aeon.sugar.parser import parse_program + + +def test_gpu_decorator_defaults(): + source = """ + @gpu + def max(a:Int) (b:Int) : Int { if a > b then a else b } + """ + prog = parse_program(source) + desugared = desugar(prog) + + gpu_funcs = [k for k, v in desugared.metadata.items() if v.get("gpu")] + assert len(gpu_funcs) == 1 + + max_meta = 
desugared.metadata[gpu_funcs[0]] + assert max_meta["gpu_device"] == "cuda" + assert max_meta["gpu_debug"] is False + assert max_meta["gpu_cache"] is False + assert max_meta["gpu_block_size"] == 1 + assert max_meta["gpu_thread_count"] == 1 + + +def test_gpu_decorator_with_args(): + source = """ + @gpu("cuda", true, true, 256, 128) + def max(a:Int) (b:Int) : Int { if a > b then a else b } + """ + prog = parse_program(source) + desugared = desugar(prog) + + gpu_funcs = [k for k, v in desugared.metadata.items() if v.get("gpu")] + assert len(gpu_funcs) == 1 + + max_meta = desugared.metadata[gpu_funcs[0]] + assert max_meta["gpu_device"] == "cuda" + assert max_meta["gpu_debug"] is True + assert max_meta["gpu_cache"] is True + assert max_meta["gpu_block_size"] == 256 + assert max_meta["gpu_thread_count"] == 128 + + +def test_llvm_decorator_defaults(): + source = """ + @llvm + def max(a:Int) (b:Int) : Int { if a > b then a else b } + """ + prog = parse_program(source) + desugared = desugar(prog) + + llvm_funcs = [k for k, v in desugared.metadata.items() if v.get("llvm")] + assert len(llvm_funcs) == 1 + + max_meta = desugared.metadata[llvm_funcs[0]] + assert max_meta["llvm_debug"] is False + assert max_meta["llvm_cache"] is False + + +def test_llvm_decorator_with_args(): + source = """ + @llvm(true, true) + def max(a:Int) (b:Int) : Int { if a > b then a else b } + """ + prog = parse_program(source) + desugared = desugar(prog) + + llvm_funcs = [k for k, v in desugared.metadata.items() if v.get("llvm")] + assert len(llvm_funcs) == 1 + + max_meta = desugared.metadata[llvm_funcs[0]] + assert max_meta["llvm_debug"] is True + assert max_meta["llvm_cache"] is True diff --git a/tests/llvm_e2e_test.py b/tests/llvm_e2e_test.py new file mode 100644 index 00000000..1a670339 --- /dev/null +++ b/tests/llvm_e2e_test.py @@ -0,0 +1,176 @@ +import sys +from loguru import logger +from aeon.facade.driver import AeonDriver, AeonConfig +from aeon.logger.logger import setup_logger +from 
aeon.synthesis.uis.api import SilentSynthesisUI + + +def compile_and_run(source: str): + setup_logger() + logger.add(sys.stderr, level="DEBUG") + + cfg = AeonConfig(synthesizer="random_search", synthesis_ui=SilentSynthesisUI(), synthesis_budget=0, no_main=True) + driver = AeonDriver(cfg) + errors = driver.parse(aeon_code=source) + assert not errors + return driver.run() + + +def test_e2e_sum_floats(): + source = r""" + @llvm + def special_sum (x:Float) (y:Float) : Float { + let w : Float = 5.0; + let z : Float = 10.0; + x + y - z * w + } + + def main (i:Int) : Float { special_sum 5.0 7.0 } + """ + res = compile_and_run(source) + assert res == -38.0 + + +def test_e2e_recursive(): + source = r""" + @llvm + def count_divisors (target:Int) (candidate:Int) : Int { + if candidate <= 0 then + 0 + else + let remainder : Int = target % candidate; + if remainder == 0 then + 1 + count_divisors target (candidate - 1) + else + count_divisors target (candidate - 1) + } + + def main (i:Int) : Int { count_divisors 100 50 } + """ + res = compile_and_run(source) + assert res == 8 + + +def test_e2e_llvm_fibonacci(): + source = r""" + @llvm + def fib(n:Int) : Int { + if n <= 1 then n else fib (n-1) + fib (n-2) + } + + def main (i:Int) : Int { fib 10 } + """ + res = compile_and_run(source) + assert res == 55 + + +def test_e2e_llvm_matrix_sum(): + source = r""" + import "Vector.ae"; + + @llvm + def add(acc:Int) (curr:Int) : Int { acc + curr } + + @llvm + def sum_matrix(m:(Vector Int)) (s:Int) : Int { + Vector_reduce[Int][Int] add 0 m s + } + + def main (i:Int) : Int { sum_matrix (native "[1, 2, 3, 4]") 4 } + """ + res = compile_and_run(source) + assert res == 10 + + +def test_e2e_llvm_matrix_filter(): + source = r""" + import "Vector.ae"; + + @llvm + def filter_even(m:(Vector Int)) (s:Int) : (Vector Int) { + Vector_filter[Int] (\x:Int -> x % 2 == 0) m s + } + + + def main (i:Int) : Int { + let m : (Vector Int) = native "[1, 2, 3, 4, 5, 6]" in + let filtered : (Vector Int) = filter_even 
m 6 in + Vector_get[Int] filtered 0 + } + """ + res = compile_and_run(source) + assert res == 2 + + +def test_e2e_llvm_matrix_zip_with(): + source = r""" + import "Vector.ae"; + + @llvm + def vec_add(v1:(Vector Int)) (v2:(Vector Int)) (s:Int) : (Vector Int) { + Vector_zipWith[Int][Int][Int] (\x:Int -> \y:Int -> x + y) v1 v2 s + } + + + def main (i:Int) : Int { + let v1 : (Vector Int) = native "[1, 2, 3]" in + let v2 : (Vector Int) = native "[10, 20, 30]" in + let v3 : (Vector Int) = vec_add v1 v2 3 in + Vector_get[Int] v3 1 + } + """ + res = compile_and_run(source) + assert res == 22 + + +def test_e2e_llvm_matrix_count(): + source = r""" + import "Vector.ae"; + + @llvm + def count_gt_10(v:(Vector Int)) (s:Int) : Int { + Vector_count[Int] (\x:Int -> x > 10) v s + } + + def main (i:Int) : Int { + let v : (Vector Int) = native "[5, 15, 8, 25, 3]" in + count_gt_10 v 5 + } + """ + res = compile_and_run(source) + assert res == 2 + + +def test_e2e_llvm_matrix_map(): + source = r""" + import "Vector.ae"; + + @llvm + def vec_inc(v:(Vector Int)) (s:Int) : (Vector Int) { + Vector_map[Int][Int] (\x:Int -> x + 1) v s + } + + + def main (i:Int) : Int { + let v : (Vector Int) = native "[1, 2, 3, 4, 5]" in + let v2 : (Vector Int) = vec_inc v 5 in + Vector_get[Int] v2 2 + } + """ + res = compile_and_run(source) + assert res == 4 + + +def test_e2e_llvm_math_integration(): + source = r""" + import "Math.ae"; + + @llvm + def compute_circle_area(radius:Float) : Float { + Math_PI * Math_pow radius 2.0 + } + + def main (i:Int) : Float { compute_circle_area 5.0 } + """ + res = compile_and_run(source) + assert abs(res - 78.53981633974483) < 1e-5 diff --git a/tests/llvm_generator_test.py b/tests/llvm_generator_test.py new file mode 100644 index 00000000..eeb4b034 --- /dev/null +++ b/tests/llvm_generator_test.py @@ -0,0 +1,227 @@ +import pytest + +from aeon.llvm.cpu.converter import CPULLVMIRGenerator, LLVMIRGenerationError +from aeon.llvm.llvm_ast import ( + LLVMInt, + LLVMBool, + 
LLVMLiteral, + LLVMVar, + LLVMIf, + LLVMLet, + LLVMFunction, + LLVMCall, + LLVMFunctionType, +) +from aeon.utils.name import Name + + +def test_generate_literal(): + generator = CPULLVMIRGenerator() + + func_type = LLVMFunctionType(arg_types=[], return_type=LLVMInt) + func_ast = LLVMFunction(type=func_type, arg_names=[], arg_types=[], body=LLVMLiteral(type=LLVMInt, value=42)) + + kernel_ast = LLVMLet( + type=LLVMInt, var_name=Name("my_const"), var_value=func_ast, body=LLVMLiteral(type=LLVMInt, value=0) + ) + + ir_code = generator.generate_ir([kernel_ast]) + print(ir_code) + + assert 'define i32 @"my_const' in ir_code + assert "ret i32 42" in ir_code + + +def test_generate_if_else(): + generator = CPULLVMIRGenerator() + + cond = LLVMLiteral(type=LLVMBool, value=True) + then_t = LLVMLiteral(type=LLVMInt, value=10) + else_t = LLVMLiteral(type=LLVMInt, value=20) + if_ast = LLVMIf(type=LLVMInt, cond=cond, then_t=then_t, else_t=else_t) + + func_type = LLVMFunctionType(arg_types=[], return_type=LLVMInt) + func_ast = LLVMFunction(type=func_type, arg_names=[], arg_types=[], body=if_ast) + kernel_ast = LLVMLet( + type=LLVMInt, var_name=Name("my_branch"), var_value=func_ast, body=LLVMLiteral(type=LLVMInt, value=0) + ) + + ir_code = generator.generate_ir([kernel_ast]) + print(ir_code) + + assert 'define i32 @"my_branch' in ir_code + assert "br i1 1" in ir_code + assert "phi i32" in ir_code + assert "[10," in ir_code + assert "[20," in ir_code + + +def test_generate_local_let_shadowing(): + generator = CPULLVMIRGenerator() + + inner_let = LLVMLet( + type=LLVMInt, + var_name=Name("x"), + var_value=LLVMLiteral(type=LLVMInt, value=10), + body=LLVMVar(type=LLVMInt, name=Name("x")), + ) + outer_let = LLVMLet(type=LLVMInt, var_name=Name("x"), var_value=LLVMLiteral(type=LLVMInt, value=5), body=inner_let) + + func_type = LLVMFunctionType(arg_types=[], return_type=LLVMInt) + func_ast = LLVMFunction(type=func_type, arg_names=[], arg_types=[], body=outer_let) + kernel_ast = LLVMLet( + 
type=LLVMInt, var_name=Name("my_shadow"), var_value=func_ast, body=LLVMLiteral(type=LLVMInt, value=0) + ) + + ir_code = generator.generate_ir([kernel_ast]) + print(ir_code) + assert "ret i32 10" in ir_code + + +def test_generate_abstraction_and_call(): + generator = CPULLVMIRGenerator() + + func_type = LLVMFunctionType(arg_types=[LLVMInt, LLVMInt], return_type=LLVMInt) + func_ast = LLVMFunction( + type=func_type, + arg_names=[Name("x"), Name("y")], + arg_types=[LLVMInt, LLVMInt], + body=LLVMVar(type=LLVMInt, name=Name("x")), + ) + + caller_type = LLVMFunctionType(arg_types=[], return_type=LLVMInt) + call_ast = LLVMCall( + type=LLVMInt, + target=LLVMVar(type=func_type, name=Name("my_func")), + args=[LLVMLiteral(type=LLVMInt, value=42), LLVMLiteral(type=LLVMInt, value=10)], + ) + caller_ast = LLVMFunction(type=caller_type, arg_names=[], arg_types=[], body=call_ast) + + program_ast = LLVMLet( + type=LLVMInt, + var_name=Name("my_func"), + var_value=func_ast, + body=LLVMLet( + type=LLVMInt, var_name=Name("my_caller"), var_value=caller_ast, body=LLVMLiteral(type=LLVMInt, value=0) + ), + ) + + ir_code = generator.generate_ir([program_ast]) + print(ir_code) + + assert 'define i32 @"my_func"(i32 %"x", i32 %"y")' in ir_code + assert 'define i32 @"my_caller' in ir_code + assert 'call i32 @"my_func"(i32 42, i32 10)' in ir_code + + +def test_generate_sum_with_if(): + generator = CPULLVMIRGenerator() + + sum_type = LLVMFunctionType(arg_types=[LLVMInt, LLVMInt], return_type=LLVMInt) + sum_func = LLVMFunction( + type=sum_type, + arg_names=[Name("x"), Name("y")], + arg_types=[LLVMInt, LLVMInt], + body=LLVMCall( + type=LLVMInt, + target=LLVMVar(type=sum_type, name=Name("+")), + args=[LLVMVar(type=LLVMInt, name=Name("x")), LLVMVar(type=LLVMInt, name=Name("y"))], + ), + ) + + main_type = LLVMFunctionType(arg_types=[], return_type=LLVMInt) + + x_var = LLVMVar(type=LLVMInt, name=Name("x")) + cond = LLVMCall( + type=LLVMBool, + target=LLVMVar(type=LLVMFunctionType([LLVMInt, LLVMInt], 
LLVMBool), name=Name("<")), + args=[x_var, LLVMLiteral(type=LLVMInt, value=3)], + ) + y_val = LLVMIf( + type=LLVMInt, + cond=cond, + then_t=LLVMLiteral(type=LLVMInt, value=7), + else_t=LLVMLiteral(type=LLVMInt, value=10), + ) + + body = LLVMLet( + type=LLVMInt, + var_name=Name("x"), + var_value=LLVMLiteral(type=LLVMInt, value=5), + body=LLVMLet( + type=LLVMInt, + var_name=Name("y"), + var_value=y_val, + body=LLVMCall( + type=LLVMInt, + target=LLVMVar(type=sum_type, name=Name("sum")), + args=[LLVMVar(type=LLVMInt, name=Name("x")), LLVMVar(type=LLVMInt, name=Name("y"))], + ), + ), + ) + + main_func = LLVMFunction(type=main_type, arg_names=[], arg_types=[], body=body) + + program_ast = LLVMLet( + type=LLVMInt, + var_name=Name("sum"), + var_value=sum_func, + body=LLVMLet( + type=LLVMInt, + var_name=Name("main"), + var_value=main_func, + body=LLVMLiteral(type=LLVMInt, value=0), + ), + ) + + ir_code = generator.generate_ir([program_ast]) + print(ir_code) + + assert 'define i32 @"sum' in ir_code + assert 'define i32 @"main' in ir_code + assert "add i32" in ir_code + assert "icmp slt i32" in ir_code + + +def test_generate_unary_op(): + generator = CPULLVMIRGenerator() + + func_type = LLVMFunctionType(arg_types=[LLVMInt], return_type=LLVMInt) + func_ast = LLVMFunction( + type=func_type, + arg_names=[Name("x")], + arg_types=[LLVMInt], + body=LLVMCall( + type=LLVMInt, + target=LLVMVar(type=LLVMFunctionType([LLVMInt], LLVMInt), name=Name("-")), + args=[LLVMVar(type=LLVMInt, name=Name("x"))], + ), + ) + + kernel_ast = LLVMLet( + type=LLVMInt, + var_name=Name("my_neg"), + var_value=func_ast, + body=LLVMLiteral(type=LLVMInt, value=0), + ) + + ir_code = generator.generate_ir([kernel_ast]) + print(ir_code) + + assert 'define i32 @"my_neg' in ir_code + assert "sub i32 0," in ir_code + + +def test_undefined_variable_raises_error(): + generator = CPULLVMIRGenerator() + + bad_var = LLVMVar(type=LLVMInt, name=Name("not_found")) + + func_type = LLVMFunctionType(arg_types=[], 
return_type=LLVMInt) + func_ast = LLVMFunction(type=func_type, arg_names=[], arg_types=[], body=bad_var) + kernel_ast = LLVMLet( + type=LLVMInt, var_name=Name("bad_kernel"), var_value=func_ast, body=LLVMLiteral(type=LLVMInt, value=0) + ) + + with pytest.raises(LLVMIRGenerationError): + generator.generate_ir([kernel_ast]) diff --git a/tests/llvm_lowerer_test.py b/tests/llvm_lowerer_test.py new file mode 100644 index 00000000..69aeb870 --- /dev/null +++ b/tests/llvm_lowerer_test.py @@ -0,0 +1,92 @@ +from aeon.core.terms import Literal, Application, Var, Abstraction +from aeon.core.types import t_int +from aeon.llvm.cpu.lowerer import CPULLVMLowerer +from aeon.llvm.llvm_ast import LLVMFunctionType, LLVMInt, LLVMCall, LLVMFunction, LLVMLiteral +from aeon.utils.name import Name + + +def test_lower_literal(): + lowerer = CPULLVMLowerer() + lit = Literal(42, t_int) + llvm_lit = lowerer.lower(lit) + assert isinstance(llvm_lit, LLVMLiteral) + assert llvm_lit.value == 42 + assert llvm_lit.type == LLVMInt + + +def test_lower_var_op(): + lowerer = CPULLVMLowerer() + var_plus = Var(Name("+")) + llvm_plus = lowerer.lower(var_plus) + + assert isinstance(llvm_plus.type, LLVMFunctionType) + assert llvm_plus.type.arg_types == [LLVMInt, LLVMInt] + assert llvm_plus.type.return_type == LLVMInt + + +def test_lower_full_application(): + lowerer = CPULLVMLowerer() + plus = Var(Name("+")) + app = Application(Application(plus, Var(Name("x"))), Var(Name("y"))) + + type_env = {Name("x"): LLVMInt, Name("y"): LLVMInt} + llvm_call = lowerer.lower(app, type_env=type_env) + + assert isinstance(llvm_call, LLVMCall) + assert len(llvm_call.args) == 2 + assert llvm_call.type == LLVMInt + + +def test_lower_abstraction_full(): + lowerer = CPULLVMLowerer() + body = Application(Application(Var(Name("+")), Var(Name("x"))), Var(Name("y"))) + func = Abstraction(Name("x"), Abstraction(Name("y"), body)) + + expected_type = LLVMFunctionType([LLVMInt, LLVMInt], LLVMInt) + llvm_abs = lowerer.lower(func, 
expected_type=expected_type) + + assert isinstance(llvm_abs, LLVMFunction) + assert llvm_abs.arg_names == [Name("x"), Name("y")] + assert llvm_abs.arg_types == [LLVMInt, LLVMInt] + assert llvm_abs.type == expected_type + + +def test_lower_vector_get(): + lowerer = CPULLVMLowerer() + # Vector_get vec 0 + app = Application(Application(Var(Name("Vector_get")), Var(Name("vec"))), Literal(0, t_int)) + from aeon.llvm.llvm_ast import LLVMPointerType + + type_env = {Name("vec"): LLVMPointerType(LLVMInt)} + llvm_get = lowerer.lower(app, type_env=type_env) + assert llvm_get.type == LLVMInt + + +def test_lower_vector_map(): + lowerer = CPULLVMLowerer() + # Vector_map (\x -> x + 1) vec sz + kernel = Abstraction(Name("x"), Application(Application(Var(Name("+")), Var(Name("x"))), Literal(1, t_int))) + app = Application(Application(Application(Var(Name("Vector_map")), kernel), Var(Name("vec"))), Var(Name("sz"))) + from aeon.llvm.llvm_ast import LLVMPointerType + + type_env = {Name("vec"): LLVMPointerType(LLVMInt), Name("sz"): LLVMInt} + llvm_map = lowerer.lower(app, type_env=type_env) + assert isinstance(llvm_map.type, LLVMPointerType) + assert llvm_map.type.element_type == LLVMInt + + +def test_lower_math_pow(): + lowerer = CPULLVMLowerer() + # Math_pow 2 3.0 + from aeon.core.types import t_float + + app = Application(Application(Var(Name("Math_pow")), Literal(2, t_int)), Literal(3.0, t_float)) + llvm_pow = lowerer.lower(app) + from aeon.llvm.llvm_ast import LLVMDouble + + assert llvm_pow.type == LLVMDouble + # The first argument should be cast to double + from aeon.llvm.llvm_ast import LLVMCast + + assert isinstance(llvm_pow.args[0], LLVMCast) + assert llvm_pow.args[0].type == LLVMDouble diff --git a/tests/llvm_validate_test.py b/tests/llvm_validate_test.py new file mode 100644 index 00000000..a36367ea --- /dev/null +++ b/tests/llvm_validate_test.py @@ -0,0 +1,40 @@ +import pytest + +from aeon.core.terms import Literal, Application, Var, Let, Abstraction +from aeon.core.types 
import TypeConstructor +from aeon.llvm.core import LLVMValidationError +from aeon.llvm.cpu.lowerer import CPULLVMLowerer, CPUValidationContext +from aeon.utils.name import Name + + +def test_validate_valid_cpu(): + # let f = \x -> \y -> x + y in f + body = Application(Application(Var(Name("+")), Var(Name("x"))), Var(Name("y"))) + func = Abstraction(Name("x"), Abstraction(Name("y"), body)) + term = Let(Name("f"), func, Var(Name("f"))) + + lowerer = CPULLVMLowerer() + ctx = CPUValidationContext(allowed_func_calls={Name("f"), Name("+")}) + lowerer.validate(term, ctx) + + +def test_validate_invalid_type(): + t_unsupported = TypeConstructor(Name("UnsupportedType")) + term = Literal(1, t_unsupported) + + lowerer = CPULLVMLowerer() + with pytest.raises(LLVMValidationError): + lowerer.validate(term, CPUValidationContext()) + + +def test_validate_invalid_call_non_llvm(): + # let f = \x -> external_func x in f + body = Application(Var(Name("external_func")), Var(Name("x"))) + func = Abstraction(Name("x"), body) + term = Let(Name("f"), func, Var(Name("f"))) + + lowerer = CPULLVMLowerer() + with pytest.raises(LLVMValidationError): + # 'f' is allowed, but 'external_func' (used in body) is not + ctx = CPUValidationContext(allowed_func_calls={Name("f")}, strict=True) + lowerer.validate(term, ctx)