diff --git a/aeon/__main__.py b/aeon/__main__.py index a704fd4a..2d34342a 100644 --- a/aeon/__main__.py +++ b/aeon/__main__.py @@ -7,8 +7,7 @@ from aeon.backend.evaluator import EvaluationContext from aeon.backend.evaluator import eval from aeon.core.types import top -from aeon.decorators import Metadata -from aeon.frontend.anf_converter import ensure_anf +from aeon.decorators.api import Metadata from aeon.frontend.parser import parse_term from aeon.logger.logger import export_log from aeon.logger.logger import setup_logger @@ -83,7 +82,7 @@ def read_file(filename: str) -> str: return file.read() -def log_type_errors(errors: list[Exception | str]): +def log_type_errors(errors: list[Exception]): print("TYPECHECKER", "-------------------------------") print("TYPECHECKER", "+ Type Checking Error +") for error in errors: @@ -92,7 +91,7 @@ def log_type_errors(errors: list[Exception | str]): print("TYPECHECKER", "-------------------------------") -def main(): +def main() -> None: args = parse_arguments() logger = setup_logger() export_log(args.log, args.logfile, args.filename) @@ -123,20 +122,16 @@ def main(): logger.debug(core_ast) - with RecordTime("ANF conversion"): - core_ast_anf = ensure_anf(core_ast) - logger.debug(core_ast) - with RecordTime("TypeChecking"): - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) if type_errors: log_type_errors(type_errors) sys.exit(1) with RecordTime("DetectSynthesis"): incomplete_functions: list[tuple[ - str, list[str]]] = incomplete_functions_and_holes( - typing_ctx, core_ast_anf) + str, + list[str]]] = incomplete_functions_and_holes(typing_ctx, core_ast) if incomplete_functions: filename = args.filename if args.csv_synth else None @@ -149,7 +144,7 @@ def main(): synthesis_result = synthesize( typing_ctx, evaluation_ctx, - core_ast_anf, + core_ast, incomplete_functions, metadata, filename, diff --git a/aeon/core/instantiation.py b/aeon/core/instantiation.py index e32314ff..df279ec1 100644 --- a/aeon/core/instantiation.py +++ b/aeon/core/instantiation.py @@ -3,7 +3,7 @@ from aeon.core.liquid import LiquidVar from aeon.core.liquid_ops import mk_liquid_and from aeon.core.substitutions import substitution_in_liquid -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType from aeon.core.types import RefinedType from aeon.core.types import Type @@ -42,6 +42,15 @@ def rec(x): target.kind, type_substitution(target.body, alpha, beta), ) + elif isinstance(t, ExistentialType): + new_name = t.var_name + new_type = t.type + while new_name == alpha: + old_name = new_name + new_name = new_name + "_fresh" + new_type = type_substitution(new_type, old_name, TypeVar(new_name)) + + return ExistentialType(new_name, t.var_type, new_type) else: assert False diff --git a/aeon/core/substitutions.py b/aeon/core/substitutions.py index e259dc36..94e1a122 100644 --- a/aeon/core/substitutions.py +++ b/aeon/core/substitutions.py @@ -18,7 +18,7 @@ from aeon.core.terms import Rec from aeon.core.terms import Term from aeon.core.terms import Var -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType from aeon.core.types import Bottom from aeon.core.types import RefinedType @@ -89,12 +89,10 @@ def rec(x: Term): assert False -def substitution_in_liquid(t: LiquidTerm, rep: LiquidTerm, - name: str) -> LiquidTerm: +def substitution_in_liquid(t: LiquidTerm, rep: LiquidTerm, name: str) -> LiquidTerm: """substitutes name in the term t with the new replacement term rep.""" assert isinstance(rep, LiquidTerm) - if isinstance(t, (LiquidLiteralInt, LiquidLiteralBool, LiquidLiteralString, - LiquidLiteralFloat)): + if isinstance(t, (LiquidLiteralInt, LiquidLiteralBool, LiquidLiteralString, LiquidLiteralFloat)): return t elif isinstance(t, LiquidVar): if t.name == name: @@ -102,16 +100,14 @@ def substitution_in_liquid(t: LiquidTerm, rep: LiquidTerm, else: return t elif isinstance(t, LiquidApp): - return LiquidApp( - t.fun, [substitution_in_liquid(a, rep, name) for a in t.args]) + return LiquidApp(t.fun, [substitution_in_liquid(a, rep, name) for a in t.args]) elif isinstance(t, LiquidHole): if t.name == name: return rep else: return LiquidHole( t.name, - [(substitution_in_liquid(a, rep, name), t) - for (a, t) in t.argtypes], + [(substitution_in_liquid(a, rep, name), t) for (a, t) in t.argtypes], ) else: print(t, type(t)) @@ -128,45 +124,62 @@ def rec(t: Type) -> Type: renamed: Type - if isinstance(t, Top): - return t - elif isinstance(t, Bottom): - return t - elif isinstance(t, BaseType): - return t - elif isinstance(t, TypeVar): - return t - elif isinstance(t, AbstractionType): - if isinstance(rep, Var) and rep.name == t.var_name: - nname = t.var_name + "1" - renamed = AbstractionType( - nname, - t.var_type, - substitution_in_type(t.type, Var(nname), t.var_name), - ) - return substitution_in_type(renamed, rep, name) - elif name == t.var_name: + match t: + case Top(): return t - else: - return AbstractionType(t.var_name, rec(t.var_type), rec(t.type)) - elif isinstance(t, RefinedType): - if isinstance(rep, Var) and rep.name == t.name: - nname = t.name + "1" - renamed = RefinedType( - nname, - t.type, - substitution_in_liquid(t.refinement, LiquidVar(nname), t.name), - ) - return substitution_in_type(renamed, rep, name) - elif t.name == name: + case Bottom(): return t - else: - return RefinedType( - t.name, - t.type, - substitution_in_liquid(t.refinement, replacement, name), - ) - assert False + case BaseType(name=_): + return t + case TypeVar(name=_): + return t + case AbstractionType(var_name=var_name, var_type=var_type, type=ity): + if isinstance(rep, Var) and rep.name == var_name: + nname = var_name + "1" + renamed = AbstractionType( + nname, + var_type, + substitution_in_type(ity, Var(nname), var_name), + ) + return substitution_in_type(renamed, rep, name) + elif name == var_name: + return t + else: + return AbstractionType(var_name, rec(var_type), rec(ity)) + case RefinedType(name=ref_name, type=type, refinement=refinement): + # alpha renaming to avoid clashes + if isinstance(rep, Var) and rep.name == ref_name: + nname = ref_name + "1" + renamed = RefinedType( + nname, + type, + substitution_in_liquid(refinement, LiquidVar(nname), ref_name), + ) + return substitution_in_type(renamed, rep, name) + elif name == ref_name: + return t + else: + return RefinedType( + ref_name, + type, + substitution_in_liquid(refinement, replacement, name), + ) + case ExistentialType(var_name=var_name, var_type=var_type, type=ity): + # alpha renaming to avoid clashes + if isinstance(rep, Var) and rep.name == var_name: + nname = name + "1" + renamed = ExistentialType( + nname, + var_type, + substitution_in_type(ity, Var(nname), var_name), + ) + return substitution_in_type(renamed, rep, name) + if name == t.var_name: + return t + else: + return ExistentialType(var_name, var_type, substitution_in_type(ity, rep, name)) + case _: + assert False def substitution(t: Term, rep: Term, name: str) -> Term: @@ -223,16 +236,15 @@ def liquefy_app(app: Application) -> LiquidApp | None: elif isinstance(app.fun, Application): liquid_pseudo_fun = liquefy_app(app.fun) if liquid_pseudo_fun: - return LiquidApp(liquid_pseudo_fun.fun, - liquid_pseudo_fun.args + [arg]) + return LiquidApp(liquid_pseudo_fun.fun, liquid_pseudo_fun.args + [arg]) return None elif isinstance(app.fun, Let): return liquefy_app( Application( - substitution(app.fun.body, app.fun.var_value, - app.fun.var_name), + substitution(app.fun.body, app.fun.var_value, app.fun.var_name), app.arg, - ), ) + ), + ) assert False diff --git a/aeon/core/types.py b/aeon/core/types.py index 19ff2a39..34896dac 100644 --- a/aeon/core/types.py +++ b/aeon/core/types.py @@ -129,16 +129,12 @@ def __hash__(self) -> int: return hash(self.var_name) + hash(self.var_type) + hash(self.type) +@dataclass class RefinedType(Type): name: str - type: BaseType | TypeVar + type: BaseType | TypeVar | Bottom | Top refinement: LiquidTerm - def __init__(self, name: str, ty: BaseType | TypeVar, refinement: LiquidTerm): - self.name = name - self.type = ty - self.refinement = refinement - def __repr__(self): return f"{{ {self.name}:{self.type} | {self.refinement} }}" @@ -154,6 +150,16 @@ def __hash__(self) -> int: return hash(self.name) + hash(self.type) + hash(self.refinement) +@dataclass +class ExistentialType(Type): + var_name: str + var_type: Type + type: Type + + def __str__(self) -> str: + return f"∃{self.var_name}:{self.var_type}, {self.type}" + + @dataclass class TypePolymorphism(Type): name: str # alpha @@ -163,7 +169,8 @@ class TypePolymorphism(Type): def extract_parts( t: Type, -) -> tuple[str, BaseType | TypeVar, LiquidTerm]: +) -> tuple[str, BaseType | TypeVar | Top | Bottom, LiquidTerm]: + print(t) assert isinstance(t, BaseType) or isinstance(t, RefinedType) or isinstance(t, TypeVar) if isinstance(t, RefinedType): return (t.name, t.type, t.refinement) diff --git a/aeon/frontend/anf_converter.py b/aeon/frontend/anf_converter.py deleted file mode 100644 index ecab854f..00000000 --- a/aeon/frontend/anf_converter.py +++ /dev/null @@ -1,109 +0,0 @@ -from aeon.core.terms import ( - Abstraction, - Annotation, - Application, - If, - Let, - Literal, - Rec, - Term, - TypeAbstraction, - TypeApplication, - Var, -) - - -class ANFConverter: - """Recursive visitor that applies ANF transformation.""" - - def __init__(self, starting_counter: int = 0): - self.counter = starting_counter - - def fresh(self): - self.counter += 1 - return f"_anf_{self.counter}" - - def convert(self, t: Term): - """Converts term to ANF form.""" - - match t: - case If(cond=cond, then=then, otherwise=otherwise): - cond = self.convert(cond) - then = self.convert(then) - otherwise = self.convert(otherwise) - if isinstance(cond, Var) or isinstance(cond, Literal): - return If(cond, then, otherwise) - else: - v = self.fresh() - return self.convert(Let(v, cond, If(Var(v), then, otherwise))) - case Application(fun=fun, arg=arg): - fun = self.convert(fun) - - if isinstance(fun, Var) or isinstance(fun, Literal): - pass - elif isinstance(fun, Let): - return Let( - fun.var_name, - fun.var_value, - self.convert(Application(fun.body, arg)), - ) - else: - v = self.fresh() - return self.convert(Let(v, fun, Application(Var(v), arg))) - - arg = self.convert(arg) - if isinstance(arg, Var) or isinstance(arg, Literal): - return Application(fun, arg) - else: - v = self.fresh() - return self.convert(Let(v, arg, Application(fun, Var(v)))) - - case Let(var_name=name, var_value=value, body=body): - value = self.convert(value) - body = self.convert(body) - match value: - case Let(var_name=vname, var_value=vvalue, body=vbody): - assert name != vname - vvalue = self.convert(vvalue) - vbody = self.convert(vbody) - return Let( - vname, - vvalue, - self.convert(Let(name, vbody, body)), - ) - case Rec(var_name=vname, var_type=vtype, var_value=vvalue, body=vbody): - assert name != vname - vvalue = self.convert(vvalue) - vbody = self.convert(vbody) - return Rec( - vname, - vtype, - vvalue, - self.convert(Let(name, vbody, body)), - ) - case _: - return Let(name, value, body) - case Rec(var_name=name, var_type=type, var_value=value, body=body): - value = self.convert(value) - body = self.convert(body) - return Rec(name, type, value, body) - case Abstraction(var_name=name, body=body): - body = self.convert(body) - return Abstraction(var_name=name, body=body) - case Annotation(expr=expr, type=ty): - expr = self.convert(expr) - return Annotation(expr=expr, type=ty) - case TypeAbstraction(name=name, kind=kind, body=body): - body = self.convert(body) - return TypeAbstraction(name, kind, body) - case TypeApplication(body=body, type=type): - body = self.convert(body) - return TypeApplication(body, type) - case _: - return t - - -def ensure_anf(t: Term, starting_counter: int = 0) -> Term: - """Converts a term to ANF form.""" - - return ANFConverter(starting_counter=starting_counter).convert(t) diff --git a/aeon/synthesis_grammar/synthesizer.py b/aeon/synthesis_grammar/synthesizer.py index 8000a0bc..11186db9 100644 --- a/aeon/synthesis_grammar/synthesizer.py +++ b/aeon/synthesis_grammar/synthesizer.py @@ -47,7 +47,6 @@ from aeon.core.types import Type from aeon.core.types import top from aeon.decorators import Metadata -from aeon.frontend.anf_converter import ensure_anf from aeon.sugar.program import Definition from aeon.synthesis_grammar.grammar import ( gen_grammar_nodes, @@ -195,7 +194,6 @@ def evaluate_individual(individual: classType, start = time.time() first_hole_name = holes[0] individual_term = individual.get_core() # type: ignore - individual_term = ensure_anf(individual_term, 10000000) individual_type_check(ctx, program, first_hole_name, individual_term) results = [ diff --git a/aeon/typechecking/entailment.py b/aeon/typechecking/entailment.py index 4ae132cf..2fb05dfe 100644 --- a/aeon/typechecking/entailment.py +++ b/aeon/typechecking/entailment.py @@ -3,50 +3,50 @@ from aeon.core.liquid import LiquidVar from aeon.core.substitutions import substitution_in_liquid -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType from aeon.core.types import extract_parts -from aeon.core.types import Type from aeon.core.types import TypePolymorphism from aeon.typechecking.context import EmptyContext from aeon.typechecking.context import TypeBinder from aeon.typechecking.context import TypingContext from aeon.typechecking.context import UninterpretedBinder from aeon.typechecking.context import VariableBinder -from aeon.verification.helpers import show_constraint from aeon.verification.horn import solve from aeon.verification.vcs import Constraint from aeon.verification.vcs import Implication from aeon.verification.vcs import UninterpretedFunctionDeclaration -# from aeon.verification.smt import smt_valid - -def entailment(ctx: TypingContext, c: Constraint): - if isinstance(ctx, EmptyContext): - r = solve(c) - if not r: - show_constraint(c) # DEMO1 - # print(c) - return r - elif isinstance(ctx, VariableBinder): - if isinstance(ctx.type, AbstractionType): +def entailment(ctx: TypingContext, c: Constraint) -> bool: + match ctx: + case EmptyContext(): + r = solve(c) + return r + case VariableBinder(prev=prev, name=name, type=ty): + match ty: + case AbstractionType(var_name=_, var_type=_, type=_): + # Functions are not passed into SMT + return entailment(prev, c) + case TypePolymorphism(name=_, kind=_, body=_): + # TODO: TypePolymorphism is not passed to SMT. + # TODO: Consider using a custom Sort. + return entailment(prev, c) + case ExistentialType(var_name=vname, var_type=vtype, type=ity): + return entailment(VariableBinder(VariableBinder(prev, name, ity), vname, vtype), c) + case _: + (name, base, cond) = extract_parts(ty) + assert isinstance(base, BaseType) + ncond = substitution_in_liquid(cond, LiquidVar(ctx.name), name) + return entailment(ctx.prev, Implication(ctx.name, base, ncond, c)) + case TypeBinder(type_name=_, type_kind=_): + # TODO: Handle TypeBinder in entailment. + # TODO: Solution is to create a custom sort. return entailment(ctx.prev, c) - if isinstance(ctx.type, TypePolymorphism): - return entailment(ctx.prev, c) # TODO: check that this is not relevant - else: - ty: Type = ctx.type - (name, base, cond) = extract_parts(ty) - assert isinstance(base, BaseType) - ncond = substitution_in_liquid(cond, LiquidVar(ctx.name), name) - return entailment(ctx.prev, Implication(ctx.name, base, ncond, c)) - elif isinstance(ctx, TypeBinder): - print("TODO: Handle TypeBinder in entailment. The current solution is to ignore.") - return entailment(ctx.prev, c) # TODO - elif isinstance(ctx, UninterpretedBinder): - return entailment( - ctx.prev, - UninterpretedFunctionDeclaration(ctx.name, ctx.type, c), - ) - else: - assert False + case UninterpretedBinder(prev=prev, name=name, type=ty): + return entailment( + ctx.prev, + UninterpretedFunctionDeclaration(ctx.name, ctx.type, c), + ) + case _: + assert False diff --git a/aeon/typechecking/typeinfer.py b/aeon/typechecking/typeinfer.py index 91accf01..42efd831 100644 --- a/aeon/typechecking/typeinfer.py +++ b/aeon/typechecking/typeinfer.py @@ -1,9 +1,10 @@ from __future__ import annotations +from typing import Tuple from loguru import logger from aeon.core.instantiation import type_substitution -from aeon.core.liquid import LiquidApp, LiquidHole +from aeon.core.liquid import LiquidApp, LiquidHole, LiquidTerm from aeon.core.liquid import LiquidLiteralBool from aeon.core.liquid import LiquidLiteralFloat from aeon.core.liquid import LiquidLiteralInt @@ -24,7 +25,7 @@ from aeon.core.terms import TypeAbstraction from aeon.core.terms import TypeApplication from aeon.core.terms import Var -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, Bottom, ExistentialType, Top from aeon.core.types import BaseKind from aeon.core.types import BaseType from aeon.core.types import RefinedType @@ -50,7 +51,6 @@ from aeon.typechecking.entailment import entailment from aeon.verification.helpers import simplify_constraint from aeon.verification.horn import fresh -from aeon.verification.sub import ensure_refined from aeon.verification.sub import implication_constraint from aeon.verification.sub import sub from aeon.verification.vcs import Conjunction @@ -76,6 +76,42 @@ def __str__(self): return f"Constraint violated when checking if {self.t} : {self.ty}: \n {self.ks}" +def eq_ref(var_name: str, type_name: str) -> LiquidTerm: + return LiquidApp( + "==", + [ + LiquidVar(var_name), + LiquidVar(type_name), + ], + ) + + +def and_ref(cond1: LiquidTerm, cond2: LiquidTerm) -> LiquidTerm: + return LiquidApp("&&", [cond1, cond2]) + + +def refine_type(ctx: TypingContext, ty: Type, vname: str): + """The refine function is the selfication with support for existentials""" + match ty: + case BaseType(name=_) | Top() | Bottom(): + name = ctx.fresh_var() + return RefinedType(name, ty, eq_ref(name, vname)) + case RefinedType(name=name, type=ty, refinement=cond): + if name != vname: + return RefinedType(name, ty, and_ref(cond, eq_ref(name, vname))) + else: + return ty + case ExistentialType(var_name=var_name, var_type=var_type, type=ity): + return ExistentialType(var_name, var_type, refine_type(ctx, ity, vname)) + case AbstractionType(var_name=var_name, var_type=var_type, type=_): + return ty + case TypePolymorphism(name=name, kind=kind, body=body): + return TypePolymorphism(name, kind, refine_type(ctx, body, vname)) + case other_ty: + print(f"Failed to handle refine of {other_ty}") + assert False + + def argument_is_typevar(ty: Type): return ( isinstance(ty, TypeVar) @@ -87,6 +123,15 @@ def argument_is_typevar(ty: Type): ) +def extract_existential_binders(ty: Type) -> Tuple[Type, list[Tuple[str, Type]]]: + match ty: + case ExistentialType(var_name=var_name, var_type=var_type, type=type): + it, binders = extract_existential_binders(type) + return it, [(var_name, var_type)] + binders + case _: + return ty, [] + + def prim_litbool(t: bool) -> RefinedType: if t: return RefinedType("v", t_bool, LiquidVar("v")) @@ -219,73 +264,71 @@ def synth(ctx: TypingContext, t: Term) -> tuple[Constraint, Type]: elif isinstance(t, Var): if t.name in ops: return (ctrue, prim_op(t.name)) - ty = ctx.type_of(t.name) - if isinstance(ty, BaseType) or isinstance(ty, RefinedType): - ty = ensure_refined(ty) - # assert ty.name != t.name - if ty.name == t.name: - ty = renamed_refined_type(ty) - # Self - ty = RefinedType( - ty.name, - ty.type, - LiquidApp( - "&&", - [ - ty.refinement, - LiquidApp( - "==", - [ - LiquidVar(ty.name), - LiquidVar(t.name), - ], - ), - ], - ), - ) - if not ty: - raise CouldNotGenerateConstraintException( - f"Variable {t.name} not in context", - ) - return (ctrue, ty) + match ctx.type_of(t.name): + case None: + raise CouldNotGenerateConstraintException( + f"Variable {t.name} not in context", + ) + case ty: + return (ctrue, refine_type(ctx, ty, t.name)) + elif isinstance(t, Application): - (c, ty) = synth(ctx, t.fun) - if isinstance(ty, AbstractionType): - # This is the solution to handle polymorphic "==" in refinements. - if argument_is_typevar(ty.var_type): - (_, b, _) = extract_parts(ty.var_type) - assert isinstance(b, TypeVar) - (cp, at) = synth(ctx, t.arg) - if isinstance(at, RefinedType): - at = at.type # This is a hack before inference - return_type = substitute_vartype(ty.type, at, b.name) - else: - cp = check(ctx, t.arg, ty.var_type) - return_type = ty.type - t_subs = substitution_in_type(return_type, t.arg, ty.var_name) - c0 = Conjunction(c, cp) - # vs: list[str] = list(variables_free_in(c0)) - return (c0, t_subs) + (c1, ty1) = synth(ctx, t.fun) + (c2, ty2) = synth(ctx, t.arg) + + abstraction_type, binders1 = extract_existential_binders(ty1) + argument_type, binders2 = extract_existential_binders(ty2) + # TODO: assert that binders are non-overlapping + + match abstraction_type: + case AbstractionType(var_name=parameter_name, var_type=parameter_type, type=return_type): + pass + case _: + raise CouldNotGenerateConstraintException( + f"Application {t} is not a function.", + ) + c3: Constraint + if argument_is_typevar(parameter_type): + (_, b, _) = extract_parts(parameter_type) + assert isinstance(b, TypeVar) + parameter_type = substitute_vartype(parameter_type, ty2, b.name) + # This is an hack to handle ad-hoc polymorphism, so == works + if isinstance(ty2, RefinedType): + ty2 = ty2.type # This is a hack before inference + return_type = substitute_vartype(return_type, ty2, b.name) + c3 = ctrue else: - raise CouldNotGenerateConstraintException( - f"Application {t} is not a function.", - ) + c3 = sub(ty2, parameter_type) + new_name = ctx.fresh_var() + + return_type = substitution_in_type(return_type, Var(new_name), parameter_name) + nt = ExistentialType(var_name=new_name, var_type=argument_type, type=return_type) + + conj: Constraint = Conjunction(Conjunction(c1, c2), c3) + for aname, aty in binders1 + binders2: + nt = ExistentialType(aname, aty, nt) + conj = implication_constraint(aname, aty, conj) + return conj, nt + elif isinstance(t, Let): (c1, t1) = synth(ctx, t.var_value) nctx: TypingContext = ctx.with_var(t.var_name, t1) (c2, t2) = synth(nctx, t.body) term_vars = type_free_term_vars(t1) assert t.var_name not in term_vars - r = (Conjunction(c1, implication_constraint(t.var_name, t1, c2)), t2) + r = ( + Conjunction(c1, implication_constraint(t.var_name, t1, c2)), + ExistentialType(var_name=t.var_name, var_type=t1, type=t2), + ) return r elif isinstance(t, Rec): nrctx: TypingContext = ctx.with_var(t.var_name, t.var_type) c1 = check(nrctx, t.var_value, t.var_type) (c2, t2) = synth(nrctx, t.body) - c1 = implication_constraint(t.var_name, t.var_type, c1) c2 = implication_constraint(t.var_name, t.var_type, c2) - return Conjunction(c1, c2), t2 + + return Conjunction(c1, c2), ExistentialType(var_name=t.var_name, var_type=t.var_type, type=t2) elif isinstance(t, Annotation): ty = fresh(ctx, t.type) c = check(ctx, t.expr, ty) @@ -342,7 +385,6 @@ def check_(ctx: TypingContext, t: Term, ty: Type) -> Constraint: # patterm matching term -@wrap_checks # DEMO1 def check(ctx: TypingContext, t: Term, ty: Type) -> Constraint: if isinstance(t, Abstraction) and isinstance( ty, @@ -367,25 +409,35 @@ def check(ctx: TypingContext, t: Term, ty: Type) -> Constraint: c2 = implication_constraint(t.var_name, t1, c2) return Conjunction(c1, c2) elif isinstance(t, If): - y = ctx.fresh_var() + # TODO: ANF to Existentials broke here on liquefy. This should replace applications, and it's just translating the application! liq_cond = liquefy(t.cond) assert liq_cond is not None if not check_type(ctx, t.cond, t_bool): raise CouldNotGenerateConstraintException( "If condition not boolean", ) + + cond_name = ctx.fresh_var() + cond = LiquidVar(cond_name) + + y = ctx.fresh_var() + c0 = check(ctx, t.cond, t_bool) c1 = implication_constraint( y, - RefinedType("branch_", t_int, liq_cond), + RefinedType("branch_", t_int, cond), check(ctx, t.then, ty), ) c2 = implication_constraint( y, - RefinedType("branch_", t_int, LiquidApp("!", [liq_cond])), + RefinedType("branch_", t_int, LiquidApp("!", [cond])), check(ctx, t.otherwise, ty), ) - return Conjunction(c0, Conjunction(c1, c2)) + + constraint = Conjunction(c0, Conjunction(c1, c2)) + eq = LiquidApp("==", [LiquidVar(cond_name), liq_cond]) + return implication_constraint(cond_name, RefinedType(cond_name, t_bool, eq), constraint) + elif isinstance(t, TypeAbstraction) and isinstance(ty, TypePolymorphism): ty_right = type_substitution(ty, ty.name, TypeVar(t.name)) assert isinstance(ty_right, TypePolymorphism) @@ -403,30 +455,38 @@ def check_type(ctx: TypingContext, t: Term, ty: Type) -> bool: try: constraint = check(ctx, t, ty) return entailment(ctx, constraint) - except CouldNotGenerateConstraintException: + except CouldNotGenerateConstraintException as e: + logger.info(f"Could not generate constraint: f{e}") return False - except FailedConstraintException: + except FailedConstraintException as e: + logger.info(f"Could not prove constraint: f{e}") return False +class CouldNotProveTypingRelation(Exception): + def __init__(self, context: TypingContext, term: Term, type: Type): + self.context = context + self.term = term + self.type = type + + def __str__(self): + return f"Could not prove typing relation (Context: {self.context}) (Term: {self.term}) (Type: {self.type})." + + def check_type_errors( ctx: TypingContext, t: Term, ty: Type, -) -> list[Exception | str]: +) -> list[Exception]: """Checks whether t as type ty in ctx, but returns a list of errors.""" try: constraint = check(ctx, t, ty) + print(f"Constraint: {constraint}") r = entailment(ctx, constraint) if r: return [] else: - return [ - "Could not prove typing relation.", - f"Context: {ctx}", - f"Term: {t}", - f"Type: {ty}", - ] + return [CouldNotProveTypingRelation(ctx, t, ty)] except CouldNotGenerateConstraintException as e: return [e] except FailedConstraintException as e: @@ -434,6 +494,7 @@ def check_type_errors( def is_subtype(ctx: TypingContext, subt: Type, supt: Type): + assert not isinstance(supt, ExistentialType) if args_size_of_type(subt) != args_size_of_type(supt): return False if subt == supt: diff --git a/aeon/typechecking/well_formed.py b/aeon/typechecking/well_formed.py index da4f210a..ce3fd176 100644 --- a/aeon/typechecking/well_formed.py +++ b/aeon/typechecking/well_formed.py @@ -3,7 +3,7 @@ from aeon.core.liquid import LiquidLiteralBool from aeon.core.liquid import LiquidVar from aeon.core.substitutions import substitution_in_liquid -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, ExistentialType from aeon.core.types import BaseType from aeon.core.types import extract_parts from aeon.core.types import Kind @@ -45,7 +45,8 @@ def wellformed(ctx: TypingContext, t: Type, k: Kind = StarKind()) -> bool: wf_all = ( isinstance(t, TypePolymorphism) and k == StarKind() and wellformed(ctx.with_typevar(t.name, t.kind), t.body) ) - return wf_norefinement or wf_var or wf_base or wf_fun or wf_all + wf_existential = isinstance(t, ExistentialType) and wellformed(ctx.with_var(t.var_name, t.var_type), t.type, k) + return wf_norefinement or wf_var or wf_base or wf_fun or wf_all or wf_existential def inhabited(ctx: TypingContext, ty: Type) -> bool: diff --git a/aeon/verification/smt.py b/aeon/verification/smt.py index ca4e6fe7..d0806bf5 100644 --- a/aeon/verification/smt.py +++ b/aeon/verification/smt.py @@ -35,7 +35,7 @@ from aeon.core.liquid import LiquidTerm from aeon.core.liquid import LiquidVar from aeon.core.liquid_ops import mk_liquid_and -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, Bottom, ExistentialType, Top, TypeVar from aeon.core.types import BaseType from aeon.core.types import Type from aeon.core.types import t_bool @@ -74,7 +74,7 @@ @dataclass class CanonicConstraint: - binders: list[tuple[str, BaseType | AbstractionType]] + binders: list[tuple[str, BaseType | AbstractionType | TypeVar | Bottom | Top | ExistentialType]] pre: LiquidTerm pos: LiquidTerm @@ -94,9 +94,7 @@ def flatten(c: Constraint) -> Generator[CanonicConstraint, None, None]: pos=sub.pos, ) elif isinstance(c, LiquidConstraint): - yield CanonicConstraint(binders=[], - pre=LiquidLiteralBool(True), - pos=c.expr) + yield CanonicConstraint(binders=[], pre=LiquidLiteralBool(True), pos=c.expr) elif isinstance(c, UninterpretedFunctionDeclaration): for sub in flatten(c.seq): yield CanonicConstraint( @@ -109,16 +107,14 @@ def flatten(c: Constraint) -> Generator[CanonicConstraint, None, None]: s = Solver() -(s.set(timeout=200), ) +(s.set(timeout=200),) -def smt_valid(constraint: Constraint, - foralls: list[tuple[str, Any]] = []) -> bool: +def smt_valid(constraint: Constraint, foralls: list[tuple[str, Any]] = []) -> bool: """Verifies if a constraint is true using Z3.""" cons: list[CanonicConstraint] = list(flatten(constraint)) - forall_vars = [(f[0], make_variable(f[0], f[1])) for f in foralls - if isinstance(f[1], BaseType)] + forall_vars = [(f[0], make_variable(f[0], f[1])) for f in foralls if isinstance(f[1], BaseType)] for c in cons: s.push() smt_c = translate(c, extra=forall_vars) @@ -232,7 +228,8 @@ def translate( extra=list[tuple[str, Any]], ) -> BoolRef | bool: variables = [ - (name, make_variable(name, base)) for (name, base) in c.binders[::-1] + (name, make_variable(name, base)) + for (name, base) in c.binders[::-1] if isinstance(base, BaseType) or isinstance(base, AbstractionType) ] + extra e1 = translate_liq(c.pre, variables) diff --git a/aeon/verification/sub.py b/aeon/verification/sub.py index 796ab8b9..fe258129 100644 --- a/aeon/verification/sub.py +++ b/aeon/verification/sub.py @@ -1,13 +1,12 @@ from __future__ import annotations -from loguru import logger from aeon.core.liquid import LiquidLiteralBool from aeon.core.liquid import LiquidVar from aeon.core.substitutions import substitution_in_liquid from aeon.core.substitutions import substitution_in_type from aeon.core.terms import Var -from aeon.core.types import AbstractionType, TypeVar +from aeon.core.types import AbstractionType, ExistentialType, TypeVar from aeon.core.types import BaseType from aeon.core.types import Bottom from aeon.core.types import RefinedType @@ -25,33 +24,44 @@ def ensure_refined(t: Type) -> RefinedType: if isinstance(t, RefinedType): return t - elif isinstance(t, BaseType): + elif isinstance(t, BaseType) or isinstance(t, Top) or isinstance(t, Bottom) or isinstance(t, TypeVar): return RefinedType(f"singleton_{t}", t, LiquidLiteralBool(True)) assert False def implication_constraint(name: str, t: Type, c: Constraint) -> Constraint: - if isinstance(t, RefinedType): - ref_subs = substitution_in_liquid(t.refinement, LiquidVar(name), t.name) - # print(t.type, BaseType) - assert isinstance(t.type, BaseType) - return Implication(name, t.type, ref_subs, c) + match t: + case Bottom(): + return c + case Top(): + return c + case BaseType(name=_): + return Implication(name, t, LiquidLiteralBool(True), c) + case RefinedType(name=rname, type=ty, refinement=refinement): + ref_subs = substitution_in_liquid(refinement, LiquidVar(name), rname) + return Implication(name, ty, ref_subs, c) + case AbstractionType(var_name=var_name, var_type=var_type, type=ty): + return implication_constraint( + var_name, # TODO: Double-check if this needs to be a base type! + var_type, + implication_constraint(name, ty, c), + ) + case TypeVar(name=_): + # TODO: We should create a custom sort instead of Int. + return Implication(name, BaseType("Int"), LiquidLiteralBool(True), c) + case ExistentialType(var_name=var_name, var_type=var_type, type=ty): + return implication_constraint(var_name, var_type, implication_constraint(name, ty, c)) + case _: + print(f"{name} : {t} => {c} ({type(t)})") + assert False + + +def ensure_safe_type(t: Type) -> BaseType: + if isinstance(t, Top) or isinstance(t, Bottom): + return BaseType("Bool") elif isinstance(t, BaseType): - return Implication(name, t, LiquidLiteralBool(True), c) - elif isinstance(t, AbstractionType): - return implication_constraint( - t.var_name, - t.var_type, - implication_constraint(name, t.type, c), - ) # TODO: email Rahjit - elif isinstance(t, TypeVar): - # TODO: We are using Int here, but it could have been a singleton. - return Implication(name, BaseType("Int"), LiquidLiteralBool(True), c) - elif isinstance(t, Bottom): - return c - elif isinstance(t, Top): - return c - logger.debug(f"{name} : {t} => {c} ({type(t)})") + return t + print(f"Unsafe: {t}") assert False @@ -62,7 +72,12 @@ def sub(t1: Type, t2: Type) -> Constraint: t1 = ensure_refined(t1) if isinstance(t2, BaseType): t2 = ensure_refined(t2) - if isinstance(t1, RefinedType) and isinstance(t2, RefinedType): + if isinstance(t2, ExistentialType): + assert False + if isinstance(t1, ExistentialType): + c = sub(t1.type, t2) + return implication_constraint(t1.var_name, t1.var_type, c) + elif isinstance(t1, RefinedType) and isinstance(t2, RefinedType): if isinstance(t1.type, Bottom) or isinstance(t2.type, Top): return ctrue elif t1.type == t2.type: diff --git a/aeon/verification/vcs.py b/aeon/verification/vcs.py index 7e4867e1..4f36f1d6 100644 --- a/aeon/verification/vcs.py +++ b/aeon/verification/vcs.py @@ -9,7 +9,7 @@ from aeon.core.liquid import LiquidLiteralString from aeon.core.liquid import LiquidTerm from aeon.core.liquid import LiquidVar -from aeon.core.types import AbstractionType +from aeon.core.types import AbstractionType, Bottom, ExistentialType, Top, TypeVar from aeon.core.types import BaseType @@ -47,7 +47,7 @@ def __repr__(self): @dataclass class Implication(Constraint): name: str - base: BaseType + base: BaseType | ExistentialType | Bottom | Top | TypeVar pred: LiquidTerm seq: Constraint diff --git a/tests/end_to_end_test.py b/tests/end_to_end_test.py index b472f7f3..6bd0ba8c 100644 --- a/tests/end_to_end_test.py +++ b/tests/end_to_end_test.py @@ -3,7 +3,7 @@ from aeon.backend.evaluator import EvaluationContext from aeon.backend.evaluator import eval from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type from aeon.prelude.prelude import evaluation_vars @@ -17,12 +17,11 @@ def check_compile(source, ty, res): p = parse_term(source) - p = ensure_anf(p) assert check_type(ctx, p, ty) assert eval(p, ectx) == res -def test_anf(): +def test_multiple_applications(): source = r"""let f : (x:Int) -> (y:Int) -> Int = (\x -> (\y -> x)) in let r = f (f 1 2) (f 2 3) in r""" diff --git a/tests/frontend_test.py b/tests/frontend_test.py index 80e6d4ac..4819598d 100644 --- a/tests/frontend_test.py +++ b/tests/frontend_test.py @@ -19,14 +19,13 @@ from aeon.core.types import t_int from aeon.core.types import TypePolymorphism from aeon.core.types import TypeVar -from aeon.frontend.anf_converter import ensure_anf + from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type from aeon.utils.ast_helpers import false from aeon.utils.ast_helpers import i0 from aeon.utils.ast_helpers import i1 from aeon.utils.ast_helpers import i2 -from aeon.utils.ast_helpers import is_anf from aeon.utils.ast_helpers import mk_binop from aeon.utils.ast_helpers import true @@ -126,12 +125,6 @@ def test_operators(): assert parse_term("1 % 1") == mk_binop(lambda: "t", "%", i1, i1) -def test_precedence(): - t1 = parse_term("1 + 2 * 0") - at1 = ensure_anf(t1) - assert is_anf(at1) - - def test_let(): assert parse_term("let x = 1 in x") == Let("x", i1, Var("x")) diff --git a/tests/hole_test.py b/tests/hole_test.py index 85971591..37b901c5 100644 --- a/tests/hole_test.py +++ b/tests/hole_test.py @@ -1,5 +1,5 @@ from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.sugar.desugar import desugar, apply_decorators_in_program from aeon.sugar.parser import parse_program from aeon.synthesis_grammar.identification import incomplete_functions_and_holes @@ -10,9 +10,8 @@ def extract_target_functions(source): prog = parse_program(source) prog = apply_decorators_in_program(prog) core, ctx, _, _ = desugar(prog) - core_anf = ensure_anf(core) - check_type_errors(ctx, core_anf, top) - return incomplete_functions_and_holes(ctx, core_anf) + check_type_errors(ctx, core, top) + return incomplete_functions_and_holes(ctx, core) def test_hole_identification(): diff --git a/tests/infer_test.py b/tests/infer_test.py index 115a3f05..f04980c5 100644 --- a/tests/infer_test.py +++ b/tests/infer_test.py @@ -1,7 +1,7 @@ from __future__ import annotations from aeon.core.types import t_int -from aeon.frontend.anf_converter import ensure_anf + from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type from aeon.typechecking.context import EmptyContext @@ -14,7 +14,7 @@ def tt(e: str, t: str, vars: dict[str, str] = {}): ctx = build_context({k: parse_type(v) for (k, v) in vars.items()}) - term = ensure_anf(parse_term(e)) + term = parse_term(e) return check_type(ctx, term, parse_type(t)) @@ -89,6 +89,14 @@ def test_fifteen(): # Branches +def test_branch_eq_var(): + assert tt("x == 0", "Bool", {"x": "Int"}) + + +def test_branch_eq(): + assert tt("1 == 0", "{v:Bool | v == false}", {"x": "Int"}) + + def test_if(): assert tt("if x == 1 then 1 else 0", "Int", {"x": "Int"}) assert tt( diff --git a/tests/optimization_decorators_test.py b/tests/optimization_decorators_test.py index bb389712..e759be75 100644 --- a/tests/optimization_decorators_test.py +++ b/tests/optimization_decorators_test.py @@ -1,6 +1,6 @@ from aeon.core.terms import Term from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.sugar.desugar import desugar from aeon.sugar.parser import parse_program from aeon.sugar.program import Program @@ -11,9 +11,8 @@ def extract_core(source: str) -> Term: prog = parse_program(source) core, ctx, _, _ = desugar(prog) - core_anf = ensure_anf(core) - check_type_errors(ctx, core_anf, top) - return core_anf + check_type_errors(ctx, core, top) + return core def test_hole_minimize_int(): @@ -44,6 +43,7 @@ def main(args:Int) : Unit { metadata, ) = desugar(prog) - core_ast_anf = ensure_anf(core_ast) - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) + for te in type_errors: + print(te) assert len(type_errors) == 0 diff --git a/tests/smt_test.py b/tests/smt_test.py index 05006ea7..a1948efe 100644 --- a/tests/smt_test.py +++ b/tests/smt_test.py @@ -8,7 +8,7 @@ from aeon.core.types import BaseType from aeon.core.types import t_int from aeon.core.types import top -from aeon.frontend.anf_converter import ensure_anf + from aeon.sugar.desugar import desugar from aeon.sugar.parser import parse_program from aeon.sugar.program import Program @@ -21,9 +21,8 @@ def extract_core(source: str) -> Term: prog = parse_program(source) core, ctx, _, _ = desugar(prog) - core_anf = ensure_anf(core) - check_type_errors(ctx, core_anf, top) - return core_anf + check_type_errors(ctx, core, top) + return core example = Implication( @@ -81,8 +80,7 @@ def main (x:Int) : Unit { metadata, ) = desugar(prog) - core_ast_anf = ensure_anf(core_ast) - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) assert len(type_errors) == 0 @@ -106,6 +104,5 @@ def main (x:Int) : Unit { metadata, ) = desugar(prog) - core_ast_anf = ensure_anf(core_ast) - type_errors = check_type_errors(typing_ctx, core_ast_anf, top) + type_errors = check_type_errors(typing_ctx, core_ast, top) assert len(type_errors) == 0 diff --git a/tests/substitutions_test.py b/tests/substitutions_test.py index 96ed4161..dbd70666 100644 --- a/tests/substitutions_test.py +++ b/tests/substitutions_test.py @@ -2,6 +2,8 @@ from aeon.core.substitutions import substitution from aeon.core.substitutions import substitution_in_type +from aeon.core.terms import Var +from aeon.core.types import ExistentialType from aeon.frontend.parser import parse_term from aeon.frontend.parser import parse_type @@ -66,3 +68,11 @@ def test_substitution_autorename_ref(): assert substitution_in_type(ty, parse_term("y"), "z") == parse_type( r"(y1:Int) -> {x : Int | y1 > y}", ) + + +def test_substitution_type_exist(): + ty = ExistentialType(var_name="z", var_type=parse_type("Int"), type=parse_type(r"(y:Int) -> {x : Int | x > z}")) + subs = substitution_in_type(ty, Var("z"), "y") + + assert subs.var_name != "z" # alpha renaming + assert "3" not in str(subs) diff --git a/tests/synth_fitness_test.py b/tests/synth_fitness_test.py index 9fe6ad71..bc7010ee 100644 --- a/tests/synth_fitness_test.py +++ b/tests/synth_fitness_test.py @@ -2,9 +2,11 @@ from abc import ABC +import pytest + from aeon.core.terms import Term, Application, Literal, Var from aeon.core.types import top, BaseType -from aeon.frontend.anf_converter import ensure_anf + from aeon.logger.logger import setup_logger from aeon.sugar.desugar import desugar from aeon.sugar.parser import parse_program @@ -17,6 +19,7 @@ def mock_literal_individual(value: int): + class t_Int(ABC): pass @@ -31,25 +34,31 @@ def __init__(self, value: int): return literal_int_instance(value) # type: ignore +@pytest.mark.skip def test_fitness(): code = """def year : Int = 2023; def synth (i: Int): Int { (?hole: Int) * i} """ prog = parse_program(code) - p, ctx, ectx, _ = desugar(prog) - p = ensure_anf(p) + p, ctx, ectx, metadata = desugar(prog) check_type_errors(ctx, p, top) internal_minimize = Definition( name="__internal__minimize_int_synth_0", args=[], type=BaseType("Int"), - body=Application(Application(Var("synth"), Literal(7, BaseType("Int"))), Application(Var("-"), Var("synth"))), + body=Application( + Application(Var("synth"), Literal(7, BaseType("Int"))), + Application(Var("-"), Var("synth"))), ) - term = synthesize(ctx, ectx, p, [("synth", ["hole"])], {"synth": {"minimize_int": [internal_minimize]}}) + term = synthesize(ctx, ectx, p, [("synth", ["hole"])], + {"synth": { + "minimize_int": [internal_minimize] + }}) assert isinstance(term, Term) +@pytest.mark.skip def test_fitness2(): code = """def year : Int = 2023; @minimize_int( year - synth(7) ) @@ -57,7 +66,6 @@ def synth (i:Int) : Int {(?hole: Int) * i} """ prog = parse_program(code) p, ctx, ectx, metadata = desugar(prog) - p = ensure_anf(p) check_type_errors(ctx, p, top) term = synthesize(ctx, ectx, p, [("synth", ["hole"])], metadata) diff --git a/tests/wellformed_test.py b/tests/wellformed_test.py index a3fafcb3..54842391 100644 --- a/tests/wellformed_test.py +++ b/tests/wellformed_test.py @@ -1,6 +1,6 @@ from __future__ import annotations -from aeon.core.types import BaseKind +from aeon.core.types import BaseKind, ExistentialType from aeon.core.types import StarKind from aeon.core.types import t_bool from aeon.core.types import t_int @@ -61,3 +61,10 @@ def test_poly(): TypePolymorphism("a", StarKind(), TypeVar("a")), BaseKind(), ) + + +def test_wf_existential(): + assert wellformed( + empty, + TypePolymorphism("a", BaseKind(), ExistentialType(var_name="x", var_type=parse_type("Int"), type=TypeVar("a"))), + )