11from __future__ import annotations
22
33import logging
4+ from itertools import count
45from typing import TYPE_CHECKING , NamedTuple
56
67from pyk .kast .inner import KApply , KSort , KVariable , Subst , build_cons
1112from pyk .kast .prelude .utils import token
1213
1314from .ty import ArrayT , BoolT , EnumT , IntT , PtrT , RefT , StructT , TupleT , Ty , UintT , UnionT
14- from .value import BoolValue , IntValue
15+ from .value import (
16+ NO_SIZE ,
17+ AggregateValue ,
18+ BoolValue ,
19+ DynamicSize ,
20+ IntValue ,
21+ Local ,
22+ Metadata ,
23+ Place ,
24+ PtrLocalValue ,
25+ RangeValue ,
26+ RefValue ,
27+ StaticSize ,
28+ )
1529
1630if TYPE_CHECKING :
17- from collections .abc import Iterable , Mapping , Sequence
31+ from collections .abc import Iterable , Iterator , Mapping , Sequence
1832 from random import Random
1933 from typing import Any , Final
2034
2943_LOGGER : Final = logging .getLogger (__name__ )
3044
3145
46+ RANDOM_MAX_ARRAY_LEN : Final = 32
47+
48+
3249LOCAL_0 : Final = KApply ('newLocal' , KApply ('ty' , token (0 )), KApply ('Mutability::Not' ))
3350
3451
@@ -474,25 +491,7 @@ def to_kast(self) -> KInner:
474491
475492
476493def _random_locals (random : Random , args : Sequence [_Local ], types : Mapping [Ty , TypeMetadata ]) -> list [KInner ]:
477- res : list [KInner ] = [LOCAL_0 ]
478- pointees : list [KInner ] = []
479-
480- next_ref = len (args ) + 1
481- for arg in args :
482- rvres = _random_value (
483- random = random ,
484- local = arg ,
485- types = types ,
486- next_ref = next_ref ,
487- )
488- res .append (rvres .value .to_kast ())
489- match rvres :
490- case PointerRes (pointee = pointee ):
491- pointees .append (pointee .to_kast ())
492- next_ref += 1
493-
494- res += pointees
495- return res
494+ return _RandomArgGen (random = random , args = args , types = types ).run ()
496495
497496
498497class SimpleRes (NamedTuple ):
@@ -501,55 +500,155 @@ class SimpleRes(NamedTuple):
501500
502501class ArrayRes (NamedTuple ):
503502 value : TypedValue
504- metadata : MetadataSize
503+ metadata_size : MetadataSize
505504
506505
507- class PointerRes (NamedTuple ):
508- value : TypedValue
509- pointee : TypedValue
506+ RandomValueRes = SimpleRes | ArrayRes
510507
511508
512- RandomValueRes = SimpleRes | ArrayRes | PointerRes
509+ class _RandomArgGen :
510+ _random : Random
511+ _args : Sequence [_Local ]
512+ _types : Mapping [Ty , TypeMetadata ]
513+ _pointees : list [TypedValue ]
514+ _ref : Iterator [int ]
513515
516+ def __init__ (self , * , random : Random , args : Sequence [_Local ], types : Mapping [Ty , TypeMetadata ]):
517+ self ._random = random
518+ self ._args = args
519+ self ._types = types
520+ self ._pointees = []
521+ self ._ref = count (len (args ) + 1 )
514522
515- def _random_value (
516- * ,
517- random : Random ,
518- local : _Local ,
519- types : Mapping [Ty , TypeMetadata ],
520- next_ref : int ,
521- ) -> RandomValueRes :
522- try :
523- type_info = types [local .ty ]
524- except KeyError as err :
525- raise ValueError (f'Unknown type: { local .ty } ' ) from err
526-
527- match type_info :
528- case BoolT ():
529- return SimpleRes (
530- TypedValue .from_local (
531- value = _random_bool_value (random ),
532- local = local ,
523+ def run (self ) -> list [KInner ]:
524+ res : list [KInner ] = [LOCAL_0 ]
525+ res .extend (self ._random_value (arg ).value .to_kast () for arg in self ._args )
526+ res .extend (pointee .to_kast () for pointee in self ._pointees )
527+ return res
528+
529+ def _random_value (self , local : _Local ) -> RandomValueRes :
530+ try :
531+ type_info = self ._types [local .ty ]
532+ except KeyError as err :
533+ raise ValueError (f'Unknown type: { local .ty } ' ) from err
534+
535+ match type_info :
536+ case BoolT ():
537+ return SimpleRes (
538+ TypedValue .from_local (
539+ value = self ._random_bool_value (),
540+ local = local ,
541+ )
533542 )
534- )
535- case IntT () | UintT ():
536- return SimpleRes (
537- TypedValue .from_local (
538- value = _random_int_value (random , type_info ),
539- local = local ,
540- ),
541- )
542- case _:
543- raise ValueError (f'Type unsupported for random value generator: { type_info } ' )
543+ case IntT () | UintT ():
544+ return SimpleRes (
545+ TypedValue .from_local (
546+ value = self ._random_int_value (type_info ),
547+ local = local ,
548+ ),
549+ )
550+ case EnumT (discriminants = discriminants , fields = fields ):
551+ return SimpleRes (
552+ TypedValue .from_local (
553+ value = self ._random_enum_value (mut = local .mut , discriminants = discriminants , fields = fields ),
554+ local = local ,
555+ ),
556+ )
557+ case StructT (fields = tys ) | TupleT (components = tys ):
558+ return SimpleRes (
559+ TypedValue .from_local (
560+ value = self ._random_struct_or_tuple_value (mut = local .mut , tys = tys ),
561+ local = local ,
562+ ),
563+ )
564+ case ArrayT (element_type = elem_ty , length = length ):
565+ value , metadata_size = self ._random_array_value (mut = local .mut , elem_ty = elem_ty , length = length )
566+ return ArrayRes (
567+ value = TypedValue .from_local (
568+ value = value ,
569+ local = local ,
570+ ),
571+ metadata_size = metadata_size ,
572+ )
573+ case PtrT () | RefT ():
574+ return SimpleRes (
575+ value = TypedValue .from_local (
576+ value = self ._random_ptr_value (mut = local .mut , type_info = type_info ),
577+ local = local ,
578+ ),
579+ )
580+ case _:
581+ raise ValueError (f'Type unsupported for random value generator: { type_info } ' )
544582
583+ def _random_bool_value (self ) -> BoolValue :
584+ return BoolValue (bool (self ._random .getrandbits (1 )))
545585
546- def _random_bool_value (random : Random ) -> BoolValue :
547- return BoolValue (bool (random .getrandbits (1 )))
586+ def _random_int_value (self , type_info : IntT | UintT ) -> IntValue :
587+ return IntValue (
588+ value = self ._random .randint (type_info .min , type_info .max ),
589+ nbits = type_info .nbits ,
590+ signed = isinstance (type_info , IntT ),
591+ )
548592
593+ def _random_enum_value (
594+ self ,
595+ * ,
596+ mut : bool ,
597+ discriminants : list [int ],
598+ fields : list [list [Ty ]],
599+ ) -> AggregateValue :
600+ variant_idx = self ._random .randrange (len (discriminants ))
601+ values = self ._random_fields (tys = fields [variant_idx ], mut = mut )
602+ return AggregateValue (variant_idx , values )
603+
604+ def _random_struct_or_tuple_value (self , * , mut : bool , tys : list [Ty ]) -> AggregateValue :
605+ return AggregateValue (0 , fields = self ._random_fields (tys = tys , mut = mut ))
606+
607+ def _random_fields (self , * , tys : list [Ty ], mut : bool ) -> tuple [Value , ...]:
608+ return tuple (self ._random_value (local = _Local (ty = ty , mut = mut )).value .value for ty in tys )
609+
610+ def _random_array_value (self , * , mut : bool , elem_ty : Ty , length : int | None ) -> tuple [RangeValue , MetadataSize ]:
611+ metadata_size : MetadataSize
612+ if length is None :
613+ length = self ._random .randint (0 , RANDOM_MAX_ARRAY_LEN )
614+ metadata_size = DynamicSize (length )
615+ else :
616+ metadata_size = StaticSize (length )
617+
618+ elems = tuple (self ._random_value (local = _Local (ty = elem_ty , mut = mut )).value .value for _ in range (length ))
619+ value = RangeValue (elems )
620+ return value , metadata_size
621+
622+ def _random_ptr_value (self , mut : bool , type_info : PtrT | RefT ) -> PtrLocalValue | RefValue :
623+ pointee_local = _Local (ty = type_info .pointee_type , mut = mut )
624+ pointee_res = self ._random_value (pointee_local )
625+ self ._pointees .append (pointee_res .value )
626+
627+ metadata_size : MetadataSize
628+ match pointee_res :
629+ case ArrayRes (metadata_size = metadata_size ):
630+ pass
631+ case _:
632+ metadata_size = NO_SIZE
549633
550- def _random_int_value (random : Random , type_info : IntT | UintT ) -> IntValue :
551- return IntValue (
552- value = random .randint (type_info .min , type_info .max ),
553- nbits = type_info .nbits ,
554- signed = isinstance (type_info , IntT ),
555- )
634+ metadata = Metadata (size = metadata_size , pointer_offset = 0 , origin_size = metadata_size )
635+
636+ ref = next (self ._ref )
637+
638+ match type_info :
639+ case PtrT ():
640+ return PtrLocalValue (
641+ stack_depth = 0 ,
642+ place = Place (local = Local (ref )),
643+ mut = mut ,
644+ metadata = metadata ,
645+ )
646+ case RefT ():
647+ return RefValue (
648+ stack_depth = 0 ,
649+ place = Place (local = Local (ref )),
650+ mut = mut ,
651+ metadata = metadata ,
652+ )
653+ case _:
654+ raise AssertionError ()
0 commit comments