diff --git a/kmir/pyproject.toml b/kmir/pyproject.toml index 6601457ea..2ad0cb394 100644 --- a/kmir/pyproject.toml +++ b/kmir/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "kmir" -version = "0.3.163" +version = "0.3.164" description = "" requires-python = "~=3.10" dependencies = [ diff --git a/kmir/src/kmir/__init__.py b/kmir/src/kmir/__init__.py index 6fff8bfbd..854795654 100644 --- a/kmir/src/kmir/__init__.py +++ b/kmir/src/kmir/__init__.py @@ -1,3 +1,3 @@ from typing import Final -VERSION: Final = '0.3.163' +VERSION: Final = '0.3.164' diff --git a/kmir/src/kmir/__main__.py b/kmir/src/kmir/__main__.py index 6abf198b2..9ba731824 100644 --- a/kmir/src/kmir/__main__.py +++ b/kmir/src/kmir/__main__.py @@ -69,7 +69,11 @@ def _kmir_gen_spec(opts: GenSpecOpts) -> None: kmir_kast, _ = parse_result apr_proof = kmir.apr_proof_from_kast( - str(opts.input_file.stem.replace('_', '-')), kmir_kast, start_symbol=opts.start_symbol, sort='KmirCell' + str(opts.input_file.stem.replace('_', '-')), + kmir_kast, + SMIRInfo.from_file(opts.input_file), + start_symbol=opts.start_symbol, + sort='KmirCell', ) claim = apr_proof.as_claim() @@ -110,7 +114,7 @@ def _kmir_view(opts: ViewOpts) -> None: proof = APRProof.read_proof_data(opts.proof_dir, opts.id) smir_info = None if opts.smir_info is not None: - smir_info = SMIRInfo(opts.smir_info) + smir_info = SMIRInfo.from_file(opts.smir_info) node_printer = KMIRAPRNodePrinter(kmir, proof, smir_info=smir_info, full_printer=False) viewer = APRProofViewer(proof, kmir, node_printer=node_printer) viewer.run() @@ -121,7 +125,7 @@ def _kmir_show(opts: ShowOpts) -> None: proof = APRProof.read_proof_data(opts.proof_dir, opts.id) smir_info = None if opts.smir_info is not None: - smir_info = SMIRInfo(opts.smir_info) + smir_info = SMIRInfo.from_file(opts.smir_info) node_printer = KMIRAPRNodePrinter(kmir, proof, smir_info=smir_info, full_printer=opts.full_printer) shower = APRProofShow(kmir, node_printer=node_printer) lines = shower.show(proof) diff --git a/kmir/src/kmir/kast.py b/kmir/src/kmir/kast.py new file mode 100644 index 000000000..2b3bdc77e --- /dev/null +++ b/kmir/src/kmir/kast.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pyk.kast.inner import KApply, KVariable, build_cons +from pyk.kast.prelude.collections import list_of +from pyk.kast.prelude.kint import leInt +from pyk.kast.prelude.ml import mlEqualsTrue +from pyk.kast.prelude.utils import token + +from .smir import ArrayT, Bool, EnumT, Int, RefT, StructT, TupleT, Uint, UnionT + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pyk.kast.inner import KInner + + from .smir import SMIRInfo + + +def int_var(var: KVariable, num_bytes: int, signed: bool) -> tuple[KInner, Iterable[KInner]]: + bit_width = num_bytes * 8 + var_max = ((1 << (bit_width - 1)) if signed else (1 << bit_width)) - 1 + var_min = -(1 << (bit_width - 1)) if signed else 0 + constraints = (mlEqualsTrue(leInt(var, token(var_max))), mlEqualsTrue(leInt(token(var_min), var))) + term = KApply('Value::Integer', (var, token(bit_width), token(signed))) + return term, constraints + + +def bool_var(var: KVariable) -> tuple[KInner, Iterable[KInner]]: + term = KApply('Value::BoolVal', (var,)) + return term, () + + +def mk_call_terminator(target: int, arg_count: int) -> KInner: + operands = [ + KApply( + 'Operand::Copy', + (KApply('place', (KApply('local', (token(i + 1),)), KApply('ProjectionElems::empty', ()))),), + ) + for i in range(arg_count) + ] + + args = build_cons(KApply('Operands::empty', ()), 'Operands::append', operands) + + return KApply( + '#execTerminator(_)_KMIR-CONTROL-FLOW_KItem_Terminator', + ( + KApply( + 'terminator(_,_)_BODY_Terminator_TerminatorKind_Span', + ( + KApply( + 'TerminatorKind::Call', + ( + KApply( + 'Operand::Constant', + ( + KApply( + 'constOperand(_,_,_)_BODY_ConstOperand_Span_MaybeUserTypeAnnotationIndex_MirConst', + ( + KApply('span', token(0)), + KApply('noUserTypeAnnotationIndex_BODY_MaybeUserTypeAnnotationIndex', ()), + KApply( + 'mirConst(_,_,_)_TYPES_MirConst_ConstantKind_Ty_MirConstId', + ( + KApply('ConstantKind::ZeroSized', ()), + KApply('ty', (token(target),)), + KApply('mirConstId(_)_TYPES_MirConstId_Int', (token(0),)), + ), + ), + ), + ), + ), + ), + args, + KApply( + 'place', + ( + KApply('local', (token(0),)), + KApply('ProjectionElems::empty', ()), + ), + ), + KApply('noBasicBlockIdx_BODY_MaybeBasicBlockIdx', ()), + KApply('UnwindAction::Continue', ()), + ), + ), + KApply('span', token(0)), + ), + ), + ), + ) + + +def symbolic_locals(smir_info: SMIRInfo, local_types: list[dict]) -> tuple[list[KInner], list[KInner]]: + locals, constraints = ArgGenerator(smir_info).run(local_types) + local0 = KApply('newLocal', (KApply('ty', (token(0),)), KApply('Mutability::Not', ()))) + return ([local0] + locals, constraints) + + +def _typed_value(value: KInner, ty: int, mutable: bool) -> KInner: + return KApply( + 'typedValue', + (value, KApply('ty', (token(ty),)), KApply('Mutability::Mut' if mutable else 'Mutability::Not', ())), + ) + + +class ArgGenerator: + smir_info: SMIRInfo + locals: list[KInner] + pointees: list[KInner] + constraints: list[KInner] + counter: int + ref_offset: int + + if TYPE_CHECKING: + from .smir import Ty + + def __init__(self, smir_info: SMIRInfo) -> None: + self.smir_info = smir_info + self.locals = [] + self.pointees = [] + self.constraints = [] + self.counter = 1 + self.ref_offset = 0 + + def run(self, local_types: list[dict]) -> tuple[list[KInner], list[KInner]]: + self.ref_offset = len(local_types) + 1 + for ty, mut in [(t['ty'], t['mutability']) for t in local_types]: + self._add_local(ty, mut == 'Mut') + return (self.locals + self.pointees, self.constraints) + + def _add_local(self, ty: Ty, mutable: bool) -> None: + value, constraints = self._symbolic_value(ty, mutable) + + self.locals.append(_typed_value(value, ty, mutable)) + self.constraints += constraints + + def _fresh_var(self, prefix: str) -> KVariable: + name = prefix + str(self.counter) + self.counter += 1 + return KVariable(name) + + def _symbolic_value(self, ty: Ty, mutable: bool) -> tuple[KInner, Iterable[KInner]]: + match self.smir_info.types.get(ty): + case Int(info): + return int_var(self._fresh_var('ARG_INT'), info.value, True) + + case Uint(info): + return int_var(self._fresh_var('ARG_UINT'), info.value, False) + + case Bool(): + return bool_var(self._fresh_var('ARG_BOOL')) + + case EnumT(_, _, discriminants): + variantVar = self._fresh_var('ARG_VARIDX') + # constraints for variant index being in range + max_variant = max(discriminants.keys()) + idx_range = [ + mlEqualsTrue(leInt(token(0), variantVar)), + mlEqualsTrue(leInt(variantVar, token(max_variant))), + ] + args = self._fresh_var('ENUM_ARGS') + return KApply('Value::Aggregate', (KApply('variantIdx', (variantVar,)), args)), idx_range + + case StructT(): + args = self._fresh_var('STRUCT_ARGS') + return KApply('Value::Aggregate', (KApply('variantIdx', (token(0),)), args)), [] + + case UnionT(): + args = self._fresh_var('ARG_UNION') + return KApply('Value::Aggregate', (KApply('variantIdx', (token(0),)), args)), [] + + case ArrayT(_, None): + elems = self._fresh_var('ARG_ARRAY') + return KApply('Value::Range', (elems,)), [] + + case ArrayT(element_type, size) if size is not None: + elem_vars: list[KInner] = [] + elem_constraints: list[KInner] = [] + for _ in range(size): + new_var, new_constraints = self._symbolic_value(element_type, mutable) + elem_vars.append(_typed_value(new_var, element_type, mutable)) + elem_constraints += new_constraints + return KApply('Value::Range', (list_of(elem_vars),)), elem_constraints + + case TupleT(components): + elem_vars = [] + elem_constraints = [] + for _ty in components: + new_var, new_constraints = self._symbolic_value(_ty, mutable) + elem_vars.append(_typed_value(new_var, _ty, mutable)) + elem_constraints += new_constraints + return ( + KApply('Value::Aggregate', (KApply('variantIdx', (token(0),)), list_of(elem_vars))), + elem_constraints, + ) + + case RefT(pointee_ty): + pointee_var, pointee_constraints = self._symbolic_value(pointee_ty, mutable) + ref = self.ref_offset + self.ref_offset += 1 + self.pointees.append(_typed_value(pointee_var, pointee_ty, mutable)) + return ( + KApply( + 'Value::Reference', + ( + token(0), + KApply('place', (KApply('local', (token(ref),)), KApply('ProjectionElems::empty', ()))), + KApply('Mutability::Mut', ()) if mutable else KApply('Mutability::Not', ()), + ), + ), + pointee_constraints, + ) + case _: + return self._fresh_var('ARG'), [] diff --git a/kmir/src/kmir/kdist/mir-semantics/body.md b/kmir/src/kmir/kdist/mir-semantics/body.md index 9d4259339..0603e8523 100644 --- a/kmir/src/kmir/kdist/mir-semantics/body.md +++ b/kmir/src/kmir/kdist/mir-semantics/body.md @@ -46,9 +46,9 @@ syntax MaybeOperand ::= someOperand(Operand) [group(mir-option)] syntax Operands ::= List {Operand, ""} [group(mir-list), symbol(Operands::append), terminator-symbol(Operands::empty)] -syntax Local ::= local(Int) [group(mir-int)] -syntax MaybeLocal ::= someLocal(Local) [group(mir-option)] - | "noLocal" [group(mir-option)] +syntax Local ::= local(Int) [group(mir-int) , symbol(local)] +syntax MaybeLocal ::= someLocal(Local) [group(mir-option), symbol(someLocal)] + | "noLocal" [group(mir-option), symbol(noLocal)] syntax ProjectionElem ::= "projectionElemDeref" [group(mir-enum), symbol(ProjectionElem::Deref)] | projectionElemField(FieldIdx, Ty) [group(mir-enum), symbol(ProjectionElem::Field)] @@ -60,9 +60,9 @@ syntax ProjectionElem ::= "projectionElemDeref" | projectionElemSubtype(Ty) [group(mir-enum), symbol(ProjectionElem::Subtype)] syntax ProjectionElems ::= List {ProjectionElem, ""} [group(mir-list), symbol(ProjectionElems::append), terminator-symbol(ProjectionElems::empty)] -syntax Place ::= place(local: Local, projection: ProjectionElems) [group(mir---local--projection)] -syntax MaybePlace ::= somePlace(Place) [group(mir-option)] - | "noPlace" [group(mir-option)] +syntax Place ::= place(local: Local, projection: ProjectionElems) [group(mir---local--projection), symbol(place)] +syntax MaybePlace ::= somePlace(Place) [group(mir-option), symbol(somePlace)] + | "noPlace" [group(mir-option), symbol(noPlace)] syntax Branch ::= branch(MIRInt, BasicBlockIdx) [group(mir)] syntax Branches ::= List {Branch, ""} [group(mir-list), symbol(Branches::append), terminator-symbol(Branches::empty)] diff --git a/kmir/src/kmir/kdist/mir-semantics/kmir.md b/kmir/src/kmir/kdist/mir-semantics/kmir.md index 236147175..3e4f59e71 100644 --- a/kmir/src/kmir/kdist/mir-semantics/kmir.md +++ b/kmir/src/kmir/kdist/mir-semantics/kmir.md @@ -59,7 +59,7 @@ The `Map` of types is static information used for decoding constants and allocat It maps `Ty` IDs to `TypeInfo` that can be supplied to decoding and casting functions as well as operations involving `Aggregate` values (related to `struct`s and `enum`s). ```k - syntax Map ::= #mkTypeMap ( Map, TypeMappings ) [function, total] + syntax Map ::= #mkTypeMap ( Map, TypeMappings ) [function, total, symbol("mkTypeMap")] rule #mkTypeMap(ACC, .TypeMappings) => ACC @@ -80,7 +80,7 @@ It maps `Ty` IDs to `TypeInfo` that can be supplied to decoding and casting func Another type-related `Map` is required to associate an `AdtDef` ID with its corresponding `Ty` ID for `struct`s and `enum`s when creating or using `Aggregate` values. ```k - syntax Map ::= #mkAdtMap ( Map , TypeMappings ) [function, total] + syntax Map ::= #mkAdtMap ( Map , TypeMappings ) [function, total, symbol("mkAdtMap")] // -------------------------------------------------------------- rule #mkAdtMap(ACC, .TypeMappings) => ACC @@ -108,7 +108,7 @@ they are callee in a `Call` terminator within an `Item`). The function _names_ and _ids_ are not relevant for calls and therefore dropped. ```k - syntax Map ::= #mkFunctionMap ( FunctionNames, MonoItems ) [ function, total ] + syntax Map ::= #mkFunctionMap ( FunctionNames, MonoItems ) [ function, total, symbol("mkFunctionMap") ] | #accumFunctions ( Map, Map, FunctionNames ) [ function, total ] | #accumItems ( Map, MonoItems ) [ function, total ] diff --git a/kmir/src/kmir/kdist/mir-semantics/rt/value.md b/kmir/src/kmir/kdist/mir-semantics/rt/value.md index 937faa980..7d1012412 100644 --- a/kmir/src/kmir/kdist/mir-semantics/rt/value.md +++ b/kmir/src/kmir/kdist/mir-semantics/rt/value.md @@ -22,17 +22,17 @@ High-level values can be - arrays and slices (with homogenous element types) ```k - syntax Value ::= Integer( Int, Int, Bool ) + syntax Value ::= Integer( Int, Int, Bool ) [symbol(Value::Integer)] // value, bit-width, signedness for un/signed int - | BoolVal( Bool ) + | BoolVal( Bool ) [symbol(Value::BoolVal)] // boolean - | Aggregate( VariantIdx , List ) + | Aggregate( VariantIdx , List ) [symbol(Value::Aggregate)] // heterogenous value list for tuples and structs (standard, tuple, or anonymous) - | Float( Float, Int ) + | Float( Float, Int ) [symbol(Value::Float)] // value, bit-width for f16-f128 - | Reference( Int , Place , Mutability ) + | Reference( Int , Place , Mutability ) [symbol(Value::Reference)] // stack depth (initially 0), place, borrow kind - | Range( List ) + | Range( List ) [symbol(Value::Range)] // homogenous values for array/slice // | Ptr( Address, MaybeValue ) // address, metadata for ref/ptr @@ -52,11 +52,11 @@ The local variables may be actual values (`typedValue`), uninitialised (`NewLoca // local storage of the stack frame syntax TypedLocal ::= TypedValue | MovedLocal | NewLocal - syntax TypedValue ::= typedValue ( Value , MaybeTy , Mutability ) + syntax TypedValue ::= typedValue ( Value , MaybeTy , Mutability ) [symbol(typedValue)] syntax MovedLocal ::= "Moved" - syntax NewLocal ::= newLocal ( Ty , Mutability ) + syntax NewLocal ::= newLocal ( Ty , Mutability ) [symbol(newLocal)] // the type of aggregates cannot be determined from the data provided when they // occur as `RValue`, therefore we have to make the `Ty` field optional here. diff --git a/kmir/src/kmir/kdist/mir-semantics/ty.md b/kmir/src/kmir/kdist/mir-semantics/ty.md index 691bd13bf..76ed5237b 100644 --- a/kmir/src/kmir/kdist/mir-semantics/ty.md +++ b/kmir/src/kmir/kdist/mir-semantics/ty.md @@ -71,7 +71,7 @@ syntax ParamDef ::= paramDef(Int) [group(mir-int)] // impo syntax RegionDef ::= regionDef(Int) [group(mir-int)] // imported from crate syntax TraitDef ::= traitDef(Int) [group(mir-int)] // imported from crate -syntax VariantIdx ::= variantIdx(Int) [group(mir-int)] +syntax VariantIdx ::= variantIdx(Int) [group(mir-int), symbol(variantIdx)] syntax DynKind ::= "dynKindDyn" [group(mir-enum), symbol(DynKind::Dyn)] | "dynKindDynStar" [group(mir-enum), symbol(DynKind::DynStar)] diff --git a/kmir/src/kmir/kmir.py b/kmir/src/kmir/kmir.py index b3a210a7f..71cee38a0 100644 --- a/kmir/src/kmir/kmir.py +++ b/kmir/src/kmir/kmir.py @@ -6,8 +6,9 @@ from pyk.cli.utils import bug_report_arg from pyk.cterm import CTerm, cterm_symbolic -from pyk.kast.inner import KApply, KInner, KSequence, KSort, KToken, Subst -from pyk.kast.manip import split_config_from +from pyk.kast.inner import KApply, KInner, KSequence, KSort, KToken, KVariable, Subst +from pyk.kast.manip import abstract_term_safely, split_config_from +from pyk.kast.prelude.collections import list_empty, list_of, map_empty from pyk.kast.prelude.string import stringToken from pyk.kcfg import KCFG from pyk.kcfg.explore import KCFGExplore @@ -19,8 +20,10 @@ from pyk.proof.show import APRProofNodePrinter from .cargo import cargo_get_smir_json +from .kast import mk_call_terminator, symbolic_locals from .kparse import KParse from .parse.parser import Parser +from .smir import SMIRInfo if TYPE_CHECKING: from collections.abc import Iterator @@ -31,7 +34,6 @@ from pyk.utils import BugReport from .options import ProveRSOpts - from .smir import SMIRInfo _LOGGER: Final = logging.getLogger(__name__) @@ -73,6 +75,33 @@ def make_init_config( init_config = subst.apply(self.definition.init_config(KSort(sort))) return init_config + def make_call_config( + self, parsed_smir: KApply, smir_info: SMIRInfo, start_symbol: str = 'main', sort: str = 'GeneratedTopCell' + ) -> tuple[KInner, list[KInner]]: + + if not start_symbol in smir_info.function_tys: + raise KeyError(f'{start_symbol} not found in program') + + _sym, _allocs, functions, items, types, _ = parsed_smir.args + + args_info = smir_info.function_arguments[start_symbol] + + locals, constraints = symbolic_locals(smir_info, args_info) + + subst = { + 'K_CELL': mk_call_terminator(smir_info.function_tys[start_symbol], len(args_info)), + 'STARTSYMBOL_CELL': KApply('symbol(_)_LIB_Symbol_String', (stringToken(start_symbol),)), + 'STACK_CELL': list_empty(), # FIXME see #560, problems matching a symbolic stack + 'LOCALS_CELL': list_of(locals), + 'FUNCTIONS_CELL': KApply('mkFunctionMap', (functions, items)), + 'TYPES_CELL': KApply('mkTypeMap', (map_empty(), types)), + 'ADTTOTY_CELL': KApply('mkAdtMap', (map_empty(), types)), + } + + config = self.definition.empty_config(KSort(sort)) + + return (Subst(subst).apply(config), constraints) + def run_parsed(self, parsed_smir: KInner, start_symbol: KInner | str = 'main', depth: int | None = None) -> Pattern: init_config = self.make_init_config(parsed_smir, start_symbol) init_kore = self.kast_to_kore(init_config, KSort('GeneratedTopCell')) @@ -80,21 +109,34 @@ def run_parsed(self, parsed_smir: KInner, start_symbol: KInner | str = 'main', d return result + def run_call( + self, parsed_smir: KApply, smir_json: SMIRInfo, start_symbol: str = 'main', depth: int | None = None + ) -> Pattern: + init_config, _ = self.make_call_config(parsed_smir, smir_json, start_symbol) + init_kore = self.kast_to_kore(init_config, KSort('GeneratedTopCell')) + result = self.run_pattern(init_kore, depth=depth) + + return result + def apr_proof_from_kast( self, id: str, kmir_kast: KInner, + smir_info: SMIRInfo, start_symbol: str = 'main', sort: str = 'GeneratedTopCell', proof_dir: Path | None = None, ) -> APRProof: - config = self.make_init_config(kmir_kast, start_symbol, sort=sort) - config_with_cell_vars, _ = split_config_from(config) - - lhs = CTerm(config) - - rhs_subst = Subst({'K_CELL': KMIR.Symbols.END_PROGRAM}) - rhs = CTerm(rhs_subst(config_with_cell_vars)) + assert type(kmir_kast) is KApply + lhs_config, constraints = self.make_call_config(kmir_kast, smir_info, start_symbol=start_symbol, sort=sort) + lhs = CTerm(lhs_config, constraints) + + var_config, var_subst = split_config_from(lhs_config) + _rhs_subst: dict[str, KInner] = { + v_name: abstract_term_safely(KVariable('_'), base_name=v_name) for v_name in var_subst + } + _rhs_subst['K_CELL'] = KSequence([KMIR.Symbols.END_PROGRAM]) + rhs = CTerm(Subst(_rhs_subst)(var_config)) kcfg = KCFG() init_node = kcfg.create_node(lhs) target_node = kcfg.create_node(rhs) @@ -117,7 +159,7 @@ def prove_rs(self, opts: ProveRSOpts) -> APRProof: kmir_kast, _ = parse_result assert isinstance(kmir_kast, KInner) apr_proof = self.apr_proof_from_kast( - label, kmir_kast, start_symbol=opts.start_symbol, proof_dir=opts.proof_dir + label, kmir_kast, SMIRInfo(smir_json), start_symbol=opts.start_symbol, proof_dir=opts.proof_dir ) if apr_proof.passed: return apr_proof diff --git a/kmir/src/kmir/smir.py b/kmir/src/kmir/smir.py index 704587bf2..c04595859 100644 --- a/kmir/src/kmir/smir.py +++ b/kmir/src/kmir/smir.py @@ -1,42 +1,43 @@ from __future__ import annotations import json -from functools import cached_property +from dataclasses import dataclass +from enum import Enum +from functools import cached_property, reduce from typing import TYPE_CHECKING, NewType if TYPE_CHECKING: from pathlib import Path - from typing import Any Ty = NewType('Ty', int) AdtDef = NewType('AdtDef', int) -# TODO: Named tuples w/ `from_dict` and helpers to create K terms - class SMIRInfo: _smir: dict - def __init__(self, smir_json_file: Path) -> None: - self._smir = json.loads(smir_json_file.read_text()) + def __init__(self, smir_json: dict) -> None: + self._smir = smir_json + + @staticmethod + def from_file(smir_json_file: Path) -> SMIRInfo: + return SMIRInfo(json.loads(smir_json_file.read_text())) @cached_property - def types(self) -> dict[Ty, Any]: - res = {} - for id, type in self._smir['types']: - res[Ty(id)] = type - return res + def types(self) -> dict[Ty, TypeMetadata]: + return {Ty(id): metadata_from_json(type) for id, type in self._smir['types']} @cached_property def adt_defs(self) -> dict[AdtDef, Ty]: res = {} for ty, typeinfo in self.types.items(): - if 'StructType' in typeinfo: - adt_def = typeinfo['StructType']['adt_def'] - res[AdtDef(adt_def)] = ty - if 'EnumType' in typeinfo: - adt_def = typeinfo['EnumType']['adt_def'] - res[AdtDef(adt_def)] = ty + match typeinfo: + case StructT(adt_def=adt_def): + res[AdtDef(adt_def)] = ty + case EnumT(adt_def=adt_def): + res[AdtDef(adt_def)] = ty + case UnionT(adt_def=adt_def): + res[AdtDef(adt_def)] = ty return res @cached_property @@ -44,7 +45,7 @@ def items(self) -> dict[str, dict]: return {_item['symbol_name']: _item for _item in self._smir['items']} @cached_property - def function_arguments(self) -> dict[str, dict]: + def function_arguments(self) -> dict[str, list[dict]]: res = {} for item in self._smir['items']: if not SMIRInfo._is_func(item): @@ -61,9 +62,185 @@ def function_arguments(self) -> dict[str, dict]: def function_symbols(self) -> dict[int, dict]: return {ty: sym for ty, sym, *_ in self._smir['functions'] if type(ty) is int} + @cached_property + def function_symbols_reverse(self) -> dict[str, int]: + return {sym['NormalSym']: ty for ty, sym in self.function_symbols.items() if 'NormalSym' in sym} + + @cached_property + def function_tys(self) -> dict[str, int]: + fun_syms = self.function_symbols_reverse + + res = {'main': -1} + for item in self._smir['items']: + if not SMIRInfo._is_func(item): + continue + + mono_item_fn = item['mono_item_kind']['MonoItemFn'] + name = mono_item_fn['name'] + sym = item['symbol_name'] + if not sym in fun_syms: + continue + + res[name] = fun_syms[sym] + return res + @staticmethod def _is_func(item: dict[str, dict]) -> bool: return 'MonoItemFn' in item['mono_item_kind'] - # TODO (does this go here?) - # def ty_as_kast(self, ty: Ty) -> KInner + +class IntTy(Enum): + I8 = 1 + I16 = 2 + I32 = 4 + I64 = 8 + I128 = 16 + Isize = 8 + + +class UintTy(Enum): + U8 = 1 + U16 = 2 + U32 = 4 + U64 = 8 + U128 = 16 + Usize = 8 + + +class FloatTy(Enum): + F16 = 2 + F32 = 4 + F64 = 8 + F128 = 16 + + +@dataclass +class TypeMetadata: ... + + +@dataclass +class PrimitiveType(TypeMetadata): ... + + +@dataclass +class Bool(PrimitiveType): ... + + +@dataclass +class Char(PrimitiveType): ... + + +@dataclass +class Str(PrimitiveType): ... + + +@dataclass +class Float(PrimitiveType): + info: FloatTy + + +@dataclass +class Int(PrimitiveType): + info: IntTy + + +@dataclass +class Uint(PrimitiveType): + info: UintTy + + +def _primty_from_json(typeinfo: str | dict) -> PrimitiveType: + if typeinfo == 'Bool': + return Bool() + elif typeinfo == 'Char': + return Char() + elif typeinfo == 'Str': + return Str() + + assert isinstance(typeinfo, dict) + if 'Uint' in typeinfo: + return Uint(UintTy.__members__[typeinfo['Uint']]) + if 'Int' in typeinfo: + return Int(IntTy.__members__[typeinfo['Int']]) + if 'Float' in typeinfo: + return Float(FloatTy.__members__[typeinfo['Float']]) + return NotImplemented + + +@dataclass +class EnumT(TypeMetadata): + name: str + adt_def: int + discriminants: dict + + +@dataclass +class StructT(TypeMetadata): + name: str + adt_def: int + + +@dataclass +class UnionT(TypeMetadata): + name: str + adt_def: int + + +@dataclass +class ArrayT(TypeMetadata): + element_type: Ty + length: int | None + + +@dataclass +class PtrT(TypeMetadata): + pointee_type: Ty + + +@dataclass +class RefT(TypeMetadata): + pointee_type: Ty + + +@dataclass +class TupleT(TypeMetadata): + components: list[Ty] + + +@dataclass +class FunT(TypeMetadata): + type_str: str + + +def metadata_from_json(typeinfo: dict) -> TypeMetadata: + if 'PrimitiveType' in typeinfo: + return _primty_from_json(typeinfo['PrimitiveType']) + elif 'EnumType' in typeinfo: + info = typeinfo['EnumType'] + discriminants = dict(info['discriminants']) + return EnumT(name=info['name'], adt_def=info['adt_def'], discriminants=discriminants) + elif 'StructType' in typeinfo: + return StructT(typeinfo['StructType']['name'], typeinfo['StructType']['adt_def']) + elif 'UnionType' in typeinfo: + return UnionT(typeinfo['UnionType']['name'], typeinfo['UnionType']['adt_def']) + elif 'ArrayType' in typeinfo: + info = typeinfo['ArrayType'] + assert isinstance(info, list) + length = None if info[1] is None else _decode(info[1]['kind']['Value'][1]['bytes']) + return ArrayT(info[0], length) + elif 'PtrType' in typeinfo: + return PtrT(typeinfo['PtrType']) + elif 'RefType' in typeinfo: + return RefT(typeinfo['RefType']) + elif 'TupleType' in typeinfo: + return TupleT(typeinfo['TupleType']['types']) + elif 'FunType' in typeinfo: + return FunT(typeinfo['FunType']) + + return NotImplemented + + +def _decode(bytes: list[int]) -> int: + # assume little-endian: reverse the bytes + bytes.reverse() + return reduce(lambda x, y: x * 256 + y, bytes) diff --git a/kmir/src/tests/integration/data/prove-rs/show/assert_eq_exp-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/assert_eq_exp-fail.expected index 7ecce8811..a7fc7dd0b 100644 --- a/kmir/src/tests/integration/data/prove-rs/show/assert_eq_exp-fail.expected +++ b/kmir/src/tests/integration/data/prove-rs/show/assert_eq_exp-fail.expected @@ -1,10 +1,9 @@ ┌─ 1 (root, init) -│ #init ( symbol ( "assert_eq_exp_fail" ) globalAllocEntry ( 0 , Memory ( allocati -│ function: main -│ span: 93 +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 │ -│ (88 steps) +│ (89 steps) └─ 3 (stuck, leaf) #readProjection ( typedValue ( Any , ty ( 25 ) , mutabilityNot ) , projectionEle function: main @@ -12,6 +11,6 @@ ┌─ 2 (root, leaf, target, terminal) -│ #EndProgram +│ #EndProgram ~> .K diff --git a/kmir/src/tests/integration/data/prove-rs/show/bitwise-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/bitwise-fail.expected index 2775549ca..9d22391df 100644 --- a/kmir/src/tests/integration/data/prove-rs/show/bitwise-fail.expected +++ b/kmir/src/tests/integration/data/prove-rs/show/bitwise-fail.expected @@ -1,16 +1,15 @@ ┌─ 1 (root, init) -│ #init ( symbol ( "bitwise_fail" ) globalAllocEntry ( 5 , Memory ( allocation ( . -│ function: main -│ span: 68 +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 │ -│ (43 steps) +│ (44 steps) └─ 3 (stuck, leaf) #selectBlock ( switchTargets ( ... branches: branch ( 0 , basicBlockIdx ( 2 ) ) function: main ┌─ 2 (root, leaf, target, terminal) -│ #EndProgram +│ #EndProgram ~> .K diff --git a/kmir/src/tests/integration/data/prove-rs/show/bitwise-not-shift-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/bitwise-not-shift-fail.expected index 6086a86ae..d5ce205ae 100644 --- a/kmir/src/tests/integration/data/prove-rs/show/bitwise-not-shift-fail.expected +++ b/kmir/src/tests/integration/data/prove-rs/show/bitwise-not-shift-fail.expected @@ -1,16 +1,15 @@ ┌─ 1 (root, init) -│ #init ( symbol ( "bitwise_not_shift_fail" ) globalAllocEntry ( 8 , Memory ( allo -│ function: main -│ span: 108 +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 │ -│ (254 steps) +│ (255 steps) └─ 3 (stuck, leaf) #readProjection ( typedValue ( Any , ty ( 25 ) , mutabilityNot ) , projectionEle span: 63 ┌─ 2 (root, leaf, target, terminal) -│ #EndProgram +│ #EndProgram ~> .K diff --git a/kmir/src/tests/integration/data/prove-rs/show/interior-mut-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/interior-mut-fail.expected index 870b12ac6..dde724c7a 100644 --- a/kmir/src/tests/integration/data/prove-rs/show/interior-mut-fail.expected +++ b/kmir/src/tests/integration/data/prove-rs/show/interior-mut-fail.expected @@ -1,16 +1,15 @@ ┌─ 1 (root, init) -│ #init ( symbol ( "interior_mut_fail" ) globalAllocEntry ( 2 , Memory ( allocatio -│ function: main -│ span: 288 +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 │ -│ (160 steps) +│ (161 steps) └─ 3 (stuck, leaf) #readProjection ( typedValue ( thunk ( #cast ( typedValue ( thunk ( rvalueAddres span: 91 ┌─ 2 (root, leaf, target, terminal) -│ #EndProgram +│ #EndProgram ~> .K diff --git a/kmir/src/tests/integration/data/prove-rs/show/interior-mut2-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/interior-mut2-fail.expected index 22f221ec0..30df85fb0 100644 --- a/kmir/src/tests/integration/data/prove-rs/show/interior-mut2-fail.expected +++ b/kmir/src/tests/integration/data/prove-rs/show/interior-mut2-fail.expected @@ -1,16 +1,15 @@ ┌─ 1 (root, init) -│ #init ( symbol ( "interior_mut2_fail" ) globalAllocEntry ( 1 , Memory ( allocati -│ function: main -│ span: 120 +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 │ -│ (151 steps) +│ (152 steps) └─ 3 (stuck, leaf) #readProjection ( typedValue ( thunk ( #cast ( typedValue ( thunk ( rvalueAddres span: 53 ┌─ 2 (root, leaf, target, terminal) -│ #EndProgram +│ #EndProgram ~> .K diff --git a/kmir/src/tests/integration/data/prove-rs/show/interior-mut3-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/interior-mut3-fail.expected index 149286aba..8b1619f29 100644 --- a/kmir/src/tests/integration/data/prove-rs/show/interior-mut3-fail.expected +++ b/kmir/src/tests/integration/data/prove-rs/show/interior-mut3-fail.expected @@ -1,10 +1,9 @@ ┌─ 1 (root, init) -│ #init ( symbol ( "interior_mut3_fail" ) globalAllocEntry ( 1 , Memory ( allocati -│ function: main -│ span: 67 +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 │ -│ (136 steps) +│ (137 steps) ├─ 3 │ #expect ( typedValue ( thunk ( #compute ( binOpEq , typedValue ( thunk ( #comput │ function: main @@ -35,6 +34,6 @@ ┌─ 2 (root, leaf, target, terminal) -│ #EndProgram +│ #EndProgram ~> .K diff --git a/kmir/src/tests/integration/data/prove-rs/show/local-raw-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/local-raw-fail.expected index 945e1c12f..c402d6441 100644 --- a/kmir/src/tests/integration/data/prove-rs/show/local-raw-fail.expected +++ b/kmir/src/tests/integration/data/prove-rs/show/local-raw-fail.expected @@ -1,10 +1,9 @@ ┌─ 1 (root, init) -│ #init ( symbol ( "local_raw_fail" ) .GlobalAllocs ListItem ( functionName ( ty ( -│ function: main -│ span: 51 +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 │ -│ (92 steps) +│ (93 steps) ├─ 3 │ #expect ( typedValue ( thunk ( #compute ( binOpEq , typedValue ( thunk ( #comput │ function: main @@ -35,6 +34,6 @@ ┌─ 2 (root, leaf, target, terminal) -│ #EndProgram +│ #EndProgram ~> .K diff --git a/kmir/src/tests/integration/data/prove-rs/show/symbolic-args-fail.expected b/kmir/src/tests/integration/data/prove-rs/show/symbolic-args-fail.expected new file mode 100644 index 000000000..1610e0e51 --- /dev/null +++ b/kmir/src/tests/integration/data/prove-rs/show/symbolic-args-fail.expected @@ -0,0 +1,39 @@ + +┌─ 1 (root, init) +│ #execTerminator ( terminator ( ... kind: terminatorKindCall ( ... func: operandC +│ span: 0 +│ +│ (45 steps) +├─ 3 (split) +│ #selectBlock ( switchTargets ( ... branches: branch ( 0 , basicBlockIdx ( 3 ) ) +┃ +┃ (branch) +┣━━┓ subst: .Subst +┃ ┃ constraint: +┃ ┃ ARG_BOOL3:Bool +┃ │ +┃ ├─ 4 +┃ │ #selectBlock ( switchTargets ( ... branches: branch ( 0 , basicBlockIdx ( 3 ) ) +┃ │ +┃ │ (65 steps) +┃ └─ 6 (stuck, leaf) +┃ #applyUnOp ( unOpPtrMetadata , typedValue ( Reference ( 0 , place ( ... local: l +┃ span: 59 +┃ +┗━━┓ subst: .Subst + ┃ constraint: + ┃ notBool ARG_BOOL3:Bool + │ + ├─ 5 + │ #selectBlock ( switchTargets ( ... branches: branch ( 0 , basicBlockIdx ( 3 ) ) + │ + │ (16 steps) + └─ 7 (stuck, leaf) + #applyUnOp ( unOpPtrMetadata , typedValue ( Reference ( 0 , place ( ... local: l + span: 59 + + +┌─ 2 (root, leaf, target, terminal) +│ #EndProgram ~> .K + + diff --git a/kmir/src/tests/integration/data/prove-rs/symbolic-args-fail.rs b/kmir/src/tests/integration/data/prove-rs/symbolic-args-fail.rs new file mode 100644 index 000000000..ffece24de --- /dev/null +++ b/kmir/src/tests/integration/data/prove-rs/symbolic-args-fail.rs @@ -0,0 +1,49 @@ +// @kmir prove-rs: eats_all_args + +#[allow(dead_code)] +struct MyStruct<'a>{ field: &'a MyEnum} + +#[allow(dead_code)] +enum MyEnum { + My1, + My2(i8), +} + +#[allow(dead_code)] +#[allow(unused)] +fn eats_all_args<'a>( + x1: i32, + x2: &mut u16, + x3: bool, + mut x4: MyStruct<'a>, + x5: MyEnum, + x6: &mut [u8], + x7: &[i8; 3], + x8: &mut [MyStruct<'a>; 2] +) -> () { + *x2 = x1 as u16; + if x3 { + x8[0] = x4; + } + match x5 { + _ => { + if x6.len() > 0 + { x6[0] = x7[0] as u8; } + } + } +} + +// we need a `main` function that calls eats_all_args +fn main() { + let e1 = MyEnum::My1; + let e2 = MyEnum::My2(0); + let e3 = MyEnum::My1; + let e4 = MyEnum::My2(0); + let mut x2 = 0; + let my1 = MyStruct{field: &e1}; + let my2 = MyStruct{field: &e2}; + let my3 = MyStruct{field: &e3}; + let mut a1 = [1, 2, 3]; + let a2 = [1, 2, 3]; + eats_all_args(1, &mut x2, true, my1, e4, &mut a1, &a2, &mut [my2, my3]); +} \ No newline at end of file diff --git a/kmir/src/tests/integration/test_integration.py b/kmir/src/tests/integration/test_integration.py index 89b8cab6d..c2d295efc 100644 --- a/kmir/src/tests/integration/test_integration.py +++ b/kmir/src/tests/integration/test_integration.py @@ -111,10 +111,10 @@ def test_schema_parse(test_dir: Path, kmir: KMIR) -> None: ] LOCAL_DECL_TESTS = [ - (2, KApply('local(_)_BODY_Local_Int', (KToken('2', KSort('Int')))), KSort('Local')), + (2, KApply('local', (KToken('2', KSort('Int')))), KSort('Local')), ( {'StorageLive': 2}, - KApply('StatementKind::StorageLive', (KApply('local(_)_BODY_Local_Int', (KToken('2', KSort('Int')))))), + KApply('StatementKind::StorageLive', (KApply('local', (KToken('2', KSort('Int')))))), KSort('StatementKind'), ), ('Not', KApply('Mutability::Not', ()), KSort('Mutability')), @@ -215,7 +215,7 @@ def test_schema_parse(test_dir: Path, kmir: KMIR) -> None: KApply( 'statement(_,_)_BODY_Statement_StatementKind_Span', ( - KApply('StatementKind::StorageLive', (KApply('local(_)_BODY_Local_Int', (KToken('42', KSort('Int')))))), + KApply('StatementKind::StorageLive', (KApply('local', (KToken('42', KSort('Int')))))), KApply('span', (KToken('1', KSort('Int')))), ), ), @@ -435,6 +435,7 @@ def test_prove(spec: Path, tmp_path: Path, kmir: KMIR) -> None: 'interior-mut3-fail', 'assert_eq_exp-fail', 'bitwise-not-shift-fail', + 'symbolic-args-fail', ] @@ -449,8 +450,20 @@ def test_prove_rs(rs_file: Path, kmir: KMIR, update_expected_output: bool) -> No prove_rs_opts = ProveRSOpts(rs_file) + # read start symbol(s) from the first line (default: [main] otherwise) + start_sym_prefix = '// @kmir prove-rs:' + with open(rs_file) as f: + headline = f.readline().strip('\n') + if headline.startswith(start_sym_prefix): + start_symbols = headline.removeprefix(start_sym_prefix).split() + else: + start_symbols = ['main'] + if should_show: - # always run `main` when kmir show is tested + # only run a single start symbol when kmir show is tested + assert len(start_symbols) == 1 + prove_rs_opts.start_symbol = start_symbols[0] + apr_proof = kmir.prove_rs(prove_rs_opts) if not should_fail: @@ -464,16 +477,6 @@ def test_prove_rs(rs_file: Path, kmir: KMIR, update_expected_output: bool) -> No show_res, PROVING_DIR / f'show/{rs_file.stem}.expected', update=update_expected_output ) else: - # read start symbol(s) from the first line (default: [main] otherwise) - start_sym_prefix = '// @kmir prove-rs:' - with open(rs_file) as f: - headline = f.readline().strip('\n') - - if headline.startswith(start_sym_prefix): - start_symbols = headline.removeprefix(start_sym_prefix).split() - else: - start_symbols = ['main'] - for start_symbol in start_symbols: prove_rs_opts.start_symbol = start_symbol apr_proof = kmir.prove_rs(prove_rs_opts) diff --git a/kmir/uv.lock b/kmir/uv.lock index 159fc09c6..69273ed48 100644 --- a/kmir/uv.lock +++ b/kmir/uv.lock @@ -491,7 +491,7 @@ wheels = [ [[package]] name = "kmir" -version = "0.3.163" +version = "0.3.164" source = { editable = "." } dependencies = [ { name = "kframework" }, diff --git a/package/version b/package/version index ee7ed8106..9fdf92793 100644 --- a/package/version +++ b/package/version @@ -1 +1 @@ -0.3.163 +0.3.164