Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
cb84cac
add gpu-related types (Tensor + GpuConfig)
SousaTrashBin Mar 12, 2026
95ed8aa
add some gpu functions (need to still generate proper gpu implementat…
SousaTrashBin Mar 12, 2026
90fbc3e
add gpu decorator (@gpu)
SousaTrashBin Mar 12, 2026
e36e1f4
add some gpu helper functions such as Tensor_length
SousaTrashBin Mar 12, 2026
2905cc7
add some gpu tests (for now uses CPU but should work nonetheless)
SousaTrashBin Mar 12, 2026
3f66d0f
Merge branch 'master' into add_gpu_support
SousaTrashBin Mar 12, 2026
8a7530b
remove gpu tests, still need to verify if syntax will stay the same o…
SousaTrashBin Mar 13, 2026
a0330f8
add gpu subset code validation (including recursion checking)
SousaTrashBin Mar 13, 2026
23d0acb
add verification to types and ops (for now only supports builtins)
SousaTrashBin Mar 13, 2026
5a21d8b
add gpu subset ast representation as well as a way to convert between…
SousaTrashBin Mar 13, 2026
203a59f
not sure why this is needed (reminder to ask)
SousaTrashBin Mar 13, 2026
14f5673
add some gpu subset testing (still need to study a bit of llvm syntax…
SousaTrashBin Mar 13, 2026
374c975
eventually this will be the file that hosts the conversion between th…
SousaTrashBin Mar 13, 2026
5d7cf26
still need to add the generation of the kernel itself
SousaTrashBin Mar 13, 2026
9cd9713
fix, was using a raw str but should instead use the Name dataclass
SousaTrashBin Mar 13, 2026
5de3a25
forgot to setup auto ruff
SousaTrashBin Mar 13, 2026
0416512
add llvm decorator
SousaTrashBin Mar 18, 2026
49a8fcb
add llvm/gpu decorator injection (I feel like this would be cleaner/l…
SousaTrashBin Mar 18, 2026
4714a62
add tests for the llvm decorators (cpu + gpu)
SousaTrashBin Mar 18, 2026
afd9d98
fix type signature of decorator args
SousaTrashBin Mar 18, 2026
1de461b
add LLVM ast representation (not sure if the address space is importa…
SousaTrashBin Mar 18, 2026
0b08911
add some abstract classes to encapsulate the pipeline steps (validate…
SousaTrashBin Mar 18, 2026
96feffd
forgot to add the arg type for the abstraction
SousaTrashBin Mar 18, 2026
193d1a9
fix validate method signature; now receives a list of valid function …
SousaTrashBin Mar 18, 2026
ce27d67
add an initial CPULLVMLowerer validate implementation
SousaTrashBin Mar 18, 2026
69f9b2e
add some important util functions; in the future refined types should…
SousaTrashBin Mar 18, 2026
553860f
add some validation tests
SousaTrashBin Mar 18, 2026
ec90911
update, validation now is a collection of small independent steps (Ty…
SousaTrashBin Mar 18, 2026
3a5e6bc
remove unused files
SousaTrashBin Mar 18, 2026
e5f35fd
remove unused files
SousaTrashBin Mar 18, 2026
1c950c5
Merge remote-tracking branch 'origin/add_gpu_support' into add_gpu_su…
SousaTrashBin Mar 18, 2026
fa7a4eb
update some LLVM AST types and terms data, also add some str represen…
SousaTrashBin Mar 18, 2026
844ceef
update last test to represent another step of the verification pipeli…
SousaTrashBin Mar 18, 2026
5e4dd00
add another verification step (Full Application);
SousaTrashBin Mar 18, 2026
6029542
update from_type_to_llvm_type to support the new FunctionType (which …
SousaTrashBin Mar 18, 2026
66b28da
add simple llvm lower tests
SousaTrashBin Mar 18, 2026
13e3b45
fix some type signature bugs
SousaTrashBin Mar 18, 2026
b48d779
done some changes to ensure lowering works even with anf partial appl…
SousaTrashBin Mar 27, 2026
ebbfab6
fix method signature
SousaTrashBin Mar 27, 2026
7c57429
fix, builtin-op type can be either int or float
SousaTrashBin Mar 27, 2026
832f0ef
add llvmlite as a dependency (for the code gen)
SousaTrashBin Mar 27, 2026
73d1c2e
add llvm code gen implementation
SousaTrashBin Mar 27, 2026
5984180
add some generator/e2e generator tests for the llvm generation
SousaTrashBin Mar 27, 2026
0893266
comment for now, even though it's not valid llvmir code and it will n…
SousaTrashBin Mar 27, 2026
198762c
fix according to mypy
SousaTrashBin Mar 27, 2026
9c5818a
fix according to mypy
SousaTrashBin Mar 27, 2026
6c106f9
ruff linting
SousaTrashBin Mar 27, 2026
2d2a6a1
add Vector library, technically we could use List, but in this way, t…
SousaTrashBin Mar 31, 2026
6309b54
refactor vector, make Vector library parametric and revise function s…
SousaTrashBin Apr 1, 2026
36f39f6
fix small mistake (pow doesn't receive a type)
SousaTrashBin Apr 1, 2026
48d6389
add util function to sanitize a name for the llvm (doesn't support no…
SousaTrashBin Apr 1, 2026
378b712
extend `EvaluationContext` with `metadata` and `pipeline`, add LLVM-s…
SousaTrashBin Apr 1, 2026
b85fbf1
organize imports
SousaTrashBin Apr 1, 2026
92a5785
refactor `CPULLVMIRGenerator`, improve type handling, add vector oper…
SousaTrashBin Apr 1, 2026
28d55e9
refactor `CPULLVMIRGenerator`, rename helper methods for clarity, add…
SousaTrashBin Apr 1, 2026
d8e5347
add `LLVMPipeline` interface and extend `LLVMIRGenerator` with `gener…
SousaTrashBin Apr 1, 2026
ebaa926
integrate LLVM pipeline into `EvaluationContext`, extend execution fl…
SousaTrashBin Apr 1, 2026
5b9bdf3
extend `llvm ast` with vector operations, and add new LLVM term repre…
SousaTrashBin Apr 1, 2026
d662197
replace `LLVMAbstraction` with `LLVMFunction` and rename `generate_ke…
SousaTrashBin Apr 1, 2026
b478bf7
replace `LLVMAbstraction` with `LLVMFunction` and rename `generate_ke…
SousaTrashBin Apr 1, 2026
a59d041
simplify type and operation validation, improve naming clarity
SousaTrashBin Apr 1, 2026
60451d2
add `CPULLVMPipeline` implementation, integrate function compilation …
SousaTrashBin Apr 1, 2026
abf9951
add `CPULLVMExecutionEngine` implementation to handle LLVM IR executi…
SousaTrashBin Apr 1, 2026
6f08a19
add `LLVMCast` term, improve vector operation handling in IR generati…
SousaTrashBin Apr 1, 2026
69147f3
extend LLVM lowering with vector operation handling and built-in func…
SousaTrashBin Apr 1, 2026
93540a5
some fixes according to mypy and ruff
SousaTrashBin Apr 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions aeon/backend/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,35 @@
from aeon.core.terms import Term
from aeon.core.terms import Var
from aeon.utils.name import Name
from aeon.decorators.api import Metadata
from aeon.llvm.core import LLVMPipeline

real_eval = eval


class EvaluationContext:
variables: dict[Name, Any]

def __init__(self, prev: dict[Name, Any] | None = None):
metadata: Metadata | None
pipeline: LLVMPipeline | None

def __init__(
self,
prev: dict[Name, Any] | None = None,
metadata: Metadata | None = None,
pipeline: LLVMPipeline | None = None,
):
if prev:
self.variables = {k: v for (k, v) in prev.items()}
else:
self.variables = {}
self.metadata = metadata
self.pipeline = pipeline

def with_var(self, name: Name, value: Any):
assert isinstance(name, Name)
v = self.variables.copy()
v.update({name: value})
return EvaluationContext(v)
return EvaluationContext(v, metadata=self.metadata, pipeline=self.pipeline)

def get(self, name: Name):
return self.variables[name]
Expand Down Expand Up @@ -76,6 +87,23 @@ def eval(t: Term, ctx: EvaluationContext = EvaluationContext()) -> Any:
case Let(var_name, var_value, body):
return eval(body, ctx.with_var(var_name, eval(var_value, ctx)))
case Rec(var_name, _, var_value, body):
found_llvm = False
if ctx.pipeline and ctx.metadata:
name_str = var_name.name
for k, v in ctx.metadata.items():
k_name = k.name if isinstance(k, Name) else str(k)
if k_name == name_str and v.get("llvm"):
found_llvm = True
break

if found_llvm:
try:
v = ctx.pipeline.get_curried_function(var_name)
if v is not None:
return eval(body, ctx.with_var(var_name, v))
except Exception:
pass

if isinstance(var_value, Abstraction):
fun = var_value

Expand All @@ -87,7 +115,7 @@ def v(x):

else:
v = eval(var_value, ctx)
return eval(t.body, ctx.with_var(t.var_name, v))
return eval(body, ctx.with_var(var_name, v))
case If(cond, then, otherwise):
c = eval(cond, ctx)
if c:
Expand Down
4 changes: 3 additions & 1 deletion aeon/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,10 @@ def __hash__(self) -> int:
t_int = TypeConstructor(Name("Int", 0), [])
t_float = TypeConstructor(Name("Float", 0), [])
t_string = TypeConstructor(Name("String", 0), [])
t_tensor = TypeConstructor(Name("Tensor", 0), [])
t_gpu_config = TypeConstructor(Name("GpuConfig", 0), [])

builtin_core_types = [t_unit, t_bool, t_int, t_float, t_string]
builtin_core_types = [t_unit, t_bool, t_int, t_float, t_string, t_tensor, t_gpu_config]

top = Top()

Expand Down
9 changes: 6 additions & 3 deletions aeon/decorators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ def fun(...) { ... }
eventual complementary definitions.
"""

from aeon.decorators.api import DecoratorType
from aeon.decorators.api import Metadata
from aeon.decorators.api import Metadata, DecoratorType
from aeon.llvm.decorators.gpu import gpu
from aeon.llvm.decorators.llvm import llvm
from aeon.sugar.program import Definition
from aeon.synthesis.decorators import (
minimize_int,
Expand Down Expand Up @@ -40,6 +41,8 @@ def fun(...) { ... }
"error_fitness": error_fitness,
"disable_control_flow": disable_control_flow,
"prompt": prompt,
"llvm": llvm,
"gpu": gpu,
"csv_data": csv_data,
"csv_file": csv_file,
}
Expand All @@ -53,7 +56,7 @@ def apply_decorators(fun: Definition, metadata: Metadata) -> tuple[Definition, l
for decorator in fun.decorators:
dname = decorator.name.name
if dname not in decorators_environment:
raise Exception(f"Unknown decorator named {dname}, in function {fun.name}.")
raise Exception(f"Unknown decorator named {dname}, in function {fun.name.pretty()}.")
decorator_processor = decorators_environment[dname]
(fun, extra, metadata) = decorator_processor(decorator.macro_args, fun, metadata)
total_extra.extend(extra)
Expand Down
18 changes: 15 additions & 3 deletions aeon/facade/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
from aeon.utils.name import Name
from aeon.utils.pprint import pretty_print_node
from aeon.utils.time_utils import RecordTime
from aeon.llvm.cpu.pipeline import CPULLVMPipeline
from aeon.llvm.cpu.executor import CPULLVMExecutionEngine
from aeon.llvm.cpu.converter import CPULLVMIRGenerator
from aeon.llvm.cpu.lowerer import CPULLVMLowerer


def read_file(filename: str) -> str:
Expand Down Expand Up @@ -107,17 +111,25 @@ def parse(self, filename: str = None, aeon_code: str = None) -> Iterable[AeonErr
return type_errors

with RecordTime("Preparing execution env"):
evaluation_ctx = EvaluationContext(evaluation_vars)
executor = CPULLVMExecutionEngine()
generator = CPULLVMIRGenerator()
lowerer = CPULLVMLowerer()
pipeline = CPULLVMPipeline(executor, generator, lowerer, metadata=metadata)
evaluation_ctx = EvaluationContext(evaluation_vars, metadata=metadata, pipeline=pipeline)

self.metadata = metadata
self.core = core_ast_anf
self.typing_ctx = typing_ctx
self.evaluation_ctx = evaluation_ctx

with RecordTime("LLVM compilation"):
pipeline.compile(self.core)

return []

def run(self) -> None:
def run(self) -> Any:
with RecordTime("Evaluation"):
eval(self.core, self.evaluation_ctx)
return eval(self.core, self.evaluation_ctx)

def has_synth(self) -> bool:
with RecordTime("DetectSynthesis"):
Expand Down
85 changes: 85 additions & 0 deletions aeon/llvm/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from aeon.core.terms import Term
from aeon.utils.name import Name
from aeon.llvm.llvm_ast import LLVMTerm, LLVMType


class LLVMBackendError(Exception):
pass


class LLVMValidationError(LLVMBackendError):
pass


@dataclass(frozen=True)
class ValidationContext:
pass


class ValidationStep(ABC):
@abstractmethod
def validate(self, t: Term, ctx: ValidationContext) -> None:
pass


class LLVMLowerer(ABC):
def validate(self, t: Term, ctx: ValidationContext) -> None:
for step in self.get_validation_steps():
step.validate(t, ctx)

@abstractmethod
def get_validation_steps(self) -> List[ValidationStep]:
pass

@abstractmethod
def lower(
self,
t: Term,
expected_type: LLVMType = None,
type_env: Dict[Name, LLVMType] = None,
env: Dict[Name, LLVMTerm] = None,
) -> LLVMTerm:
pass

@abstractmethod
def get_signature(self, ty: LLVMType) -> tuple[List[LLVMType], LLVMType]:
pass


class LLVMIRGenerator(ABC):
@abstractmethod
def generate_ir(self, definitions: List[LLVMTerm]) -> str:
pass


class LLVMOptimizer(ABC):
@abstractmethod
def optimize(self, llvm_ir: str, opt_level: int = 3) -> str:
pass


class LLVMExecutionEngine(ABC):
@abstractmethod
def execute(
self, llvm_ir: str, func_name: str, args: List[Any], arg_types: List[LLVMType], ret_type: LLVMType
) -> Any:
pass


class LLVMPipeline(ABC):
@abstractmethod
def compile(self, program: Term) -> None:
pass

@abstractmethod
def get_curried_function(self, name: Name) -> Any:
pass

@abstractmethod
def invoke(self, name_id: Name, arguments: List[Any]) -> Any:
pass
Loading
Loading