From 743cd5bedbd406c9e36583913db18e511db55cbc Mon Sep 17 00:00:00 2001 From: Zesen Qian Date: Mon, 28 Apr 2025 18:10:30 +0100 Subject: [PATCH 1/3] refactor mode system flipping --- typing/allowance.ml | 4 + typing/allowance.mli | 8 + typing/jkind_axis.ml | 2 +- typing/jkind_axis.mli | 2 +- typing/mode.ml | 323 +++++++++++++++++++++++------------------ typing/mode_intf.mli | 85 +++++------ typing/solver.ml | 171 ++-------------------- typing/solver_intf.mli | 121 ++++----------- typing/typemode.ml | 6 +- 9 files changed, 275 insertions(+), 447 deletions(-) diff --git a/typing/allowance.ml b/typing/allowance.ml index f2855fc4976..1c30a2a8b28 100644 --- a/typing/allowance.ml +++ b/typing/allowance.ml @@ -22,6 +22,10 @@ type right_only = disallowed * allowed type both = allowed * allowed +type 'a pos = 'b * 'c constraint 'a = 'b * 'c + +type 'a neg = 'c * 'b constraint 'a = 'b * 'c + module type Allow_disallow = sig type ('a, 'b, 'd) sided constraint 'd = 'l * 'r diff --git a/typing/allowance.mli b/typing/allowance.mli index 2f20b54bc42..e8f7b573c08 100644 --- a/typing/allowance.mli +++ b/typing/allowance.mli @@ -47,6 +47,14 @@ type right_only = disallowed * allowed type both = allowed * allowed +(** Arrange the permissions appropriately for a positive lattice, by +doing nothing. *) +type 'a pos = 'b * 'c constraint 'a = 'b * 'c + +(** Arrange the permissions appropriately for a negative lattice, by + swapping left and right. *) +type 'a neg = 'c * 'b constraint 'a = 'b * 'c + module type Allow_disallow = sig type ('a, 'b, 'd) sided constraint 'd = 'l * 'r diff --git a/typing/jkind_axis.ml b/typing/jkind_axis.ml index 4153a484a85..ae2a43f8b3a 100644 --- a/typing/jkind_axis.ml +++ b/typing/jkind_axis.ml @@ -122,7 +122,7 @@ module Axis = struct end type 'a t = - | Modal : ('m, 'a, 'd) Mode.Alloc.axis -> 'a t + | Modal : ('a, _, _) Mode.Alloc.axis -> 'a t | Nonmodal : 'a Nonmodal.t -> 'a t type packed = Pack : 'a t -> packed [@@unboxed] diff --git a/typing/jkind_axis.mli b/typing/jkind_axis.mli index 1dfd540d388..6d1ba92b8fa 100644 --- a/typing/jkind_axis.mli +++ b/typing/jkind_axis.mli @@ -49,7 +49,7 @@ module Axis : sig (** Represents an axis of a jkind *) type 'a t = - | Modal : ('m, 'a, 'd) Mode.Alloc.axis -> 'a t + | Modal : ('a, _, _) Mode.Alloc.axis -> 'a t | Nonmodal : 'a Nonmodal.t -> 'a t type packed = Pack : 'a t -> packed [@@unboxed] diff --git a/typing/mode.ml b/typing/mode.ml index fad02023ccc..4b51864b9fd 100644 --- a/typing/mode.ml +++ b/typing/mode.ml @@ -974,14 +974,14 @@ module Lattices_mono = struct end type ('a, 'b, 'd) morph = - | Id : ('a, 'a, 'd) morph (** identity morphism *) + | Id : ('a, 'a, 'l * 'r) morph (** identity morphism *) | Meet_with : 'a -> ('a, 'a, 'l * 'r) morph (** Meet the input with the parameter *) - | Imply : 'a -> ('a, 'a, disallowed * 'd) morph + | Imply : 'a -> ('a, 'a, disallowed * 'r) morph (** The right adjoint of [Meet_with] *) | Join_with : 'a -> ('a, 'a, 'l * 'r) morph (** Join the input with the parameter *) - | Subtract : 'a -> ('a, 'a, 'd * disallowed) morph + | Subtract : 'a -> ('a, 'a, 'l * disallowed) morph (** The left adjoint of [Join_with] *) | Proj : 't obj * ('t, 'r_) Axis.t -> ('t, 'r_, 'l * 'r) morph (** Project from a product to an axis *) @@ -990,8 +990,8 @@ module Lattices_mono = struct | Min_with : ('t, 'r_) Axis.t -> ('r_, 't, 'l * disallowed) morph (** Combine an axis with minima along other axes *) | Map_comonadic : - ('a0, 'a1, 'd) morph - -> ('a0 comonadic_with, 'a1 comonadic_with, 'd) morph + ('a0, 'a1, 'l * 'r) morph + -> ('a0 comonadic_with, 'a1 comonadic_with, 'l * 'r) morph (** Lift an morphism on areality to a morphism on the comonadic fragment *) | Monadic_to_comonadic_min : (Monadic_op.t, 'a comonadic_with, 'l * disallowed) morph @@ -1015,8 +1015,11 @@ module Lattices_mono = struct (** Maps regional to global, identity otherwise *) | Global_to_regional : (Locality.t, Regionality.t, disallowed * 'r) morph (** Maps global to regional, local to local *) - | Compose : ('b, 'c, 'd) morph * ('a, 'b, 'd) morph -> ('a, 'c, 'd) morph - (** Compoistion of two morphisms *) + | Compose : + ('b, 'c, 'l * 'r) morph * ('a, 'b, 'l * 'r) morph + -> ('a, 'c, 'l * 'r) morph (** Compoistion of two morphisms *) + constraint 'd = _ * _ + [@@ocaml.warning "-62"] include Magic_allow_disallow (struct type ('a, 'b, 'd) sided = ('a, 'b, 'd) morph constraint 'd = 'l * 'r @@ -1152,7 +1155,7 @@ module Lattices_mono = struct | Yielding | Statefulness -> assert false - let rec src : type a b d. b obj -> (a, b, d) morph -> a obj = + let rec src : type a b l r. b obj -> (a, b, l * r) morph -> a obj = fun dst f -> match f with | Id -> dst @@ -1240,7 +1243,7 @@ module Lattices_mono = struct let eq_morph = Equal_morph.equal let rec print_morph : - type a b d. b obj -> Format.formatter -> (a, b, d) morph -> unit = + type a b l r. b obj -> Format.formatter -> (a, b, l * r) morph -> unit = fun dst ppf -> function | Id -> Format.fprintf ppf "id" | Join_with c -> Format.fprintf ppf "join(%a)" (print dst) c @@ -1356,7 +1359,7 @@ module Lattices_mono = struct let statefulness = visibility_to_statefulness m.visibility in { areality; linearity; portability; yielding; statefulness } - let rec apply : type a b d. b obj -> (a, b, d) morph -> a -> b = + let rec apply : type a b l r. b obj -> (a, b, l * r) morph -> a -> b = fun dst f a -> match f with | Compose (f, g) -> @@ -1388,8 +1391,11 @@ module Lattices_mono = struct (** Compose m0 after m1. Returns [Some f] if the composition can be represented by [f] instead of [Compose m0 m1]. [None] otherwise. *) let rec maybe_compose : - type a b c d. - c obj -> (b, c, d) morph -> (a, b, d) morph -> (a, c, d) morph option = + type a b c l r. + c obj -> + (b, c, l * r) morph -> + (a, b, l * r) morph -> + (a, c, l * r) morph option = fun dst m0 m1 -> let is_max c = le dst (max dst) c in let is_min c = le dst c (min dst) in @@ -1547,8 +1553,9 @@ module Lattices_mono = struct . and compose : - type a b c d. - c obj -> (b, c, d) morph -> (a, b, d) morph -> (a, c, d) morph = + type a b c l r. + c obj -> (b, c, l * r) morph -> (a, b, l * r) morph -> (a, c, l * r) morph + = fun dst f g -> match maybe_compose dst f g with Some m -> m | None -> Compose (f, g) @@ -1614,7 +1621,8 @@ module Lattices_mono = struct end module C = Lattices_mono -module S = Solvers_polarized (C) +module Solver = Solver_mono (C) +module S = Solver type monadic = C.monadic = { uniqueness : C.Uniqueness.t; @@ -1641,16 +1649,12 @@ let append_changes : (changes ref -> unit) ref = ref (fun _ -> assert false) let set_append_changes f = append_changes := f -type ('a, 'd) mode_monadic = ('a, 'd) S.Negative.mode - -type ('a, 'd) mode_comonadic = ('a, 'd) S.Positive.mode +type ('a, 'd) mode = ('a, 'd) S.mode (** Representing a single object *) module type Obj = sig type const - module Solver : S.Solver_polarized - val obj : const C.obj end @@ -1692,10 +1696,10 @@ let equate_from_submode' submode m0 m1 = | Ok () -> Ok ()) [@@inline] -module Common (Obj : Obj) = struct +module Comonadic_gen (Obj : Obj) = struct open Obj - type 'd t = (const, 'd) Solver.mode + type 'd t = (const, 'l * 'r) Solver.mode constraint 'd = 'l * 'r type l = (allowed * disallowed) t @@ -1761,18 +1765,90 @@ module Common (Obj : Obj) = struct end [@@inline] +module Monadic_gen (Obj : Obj) = struct + (* Monadic fragment is flipped *) + open Obj + + type 'd t = (const, 'r * 'l) Solver.mode constraint 'd = 'l * 'r + + type l = (allowed * disallowed) t + + type r = (disallowed * allowed) t + + type lr = (allowed * allowed) t + + type nonrec error = const error + + type equate_error = equate_step * error + + type (_, _, 'd) sided = 'd t + + let flip_error = function + | Ok _ as r -> r + | Error { left; right } -> Error { left = right; right = left } + + let disallow_right m = Solver.disallow_left m + + let disallow_left m = Solver.disallow_right m + + let allow_left m = Solver.allow_right m + + let allow_right m = Solver.allow_left m + + let newvar () = Solver.newvar obj + + let min = Solver.max obj + + let max = Solver.min obj + + let newvar_above m = Solver.newvar_below obj m + + let newvar_below m = Solver.newvar_above obj m + + let submode_log a b ~log = Solver.submode obj b a ~log |> flip_error + + let submode a b = try_with_log (submode_log a b) + + let join l = Solver.meet obj l + + let meet l = Solver.join obj l + + let submode_exn m0 m1 = assert (submode m0 m1 |> Result.is_ok) + + let equate a b = try_with_log (equate_from_submode submode_log a b) + + let equate_exn m0 m1 = assert (equate m0 m1 |> Result.is_ok) + + let print ?verbose () ppf m = Solver.print ?verbose obj ppf m + + let zap_to_ceil m = with_log (Solver.zap_to_floor obj m) + + let zap_to_floor m = with_log (Solver.zap_to_ceil obj m) + + let of_const : type l r. const -> (l * r) t = fun a -> Solver.of_const obj a + + module Guts = struct + let _get_floor m = Solver.get_ceil obj m + + let get_ceil m = Solver.get_floor obj m + + let _get_loose_floor m = Solver.get_loose_ceil obj m + + let _get_loose_ceil m = Solver.get_loose_floor obj m + end +end +[@@inline] + module Locality = struct module Const = C.Locality module Obj = struct type const = Const.t - module Solver = S.Positive - let obj = C.Locality end - include Common (Obj) + include Comonadic_gen (Obj) let global = of_const Global @@ -1801,12 +1877,10 @@ module Regionality = struct module Obj = struct type const = Const.t - module Solver = S.Positive - let obj = C.Regionality end - include Common (Obj) + include Comonadic_gen (Obj) let local = of_const Const.Local @@ -1825,12 +1899,10 @@ module Linearity = struct module Obj = struct type const = Const.t - module Solver = S.Positive - let obj : _ C.obj = C.Linearity end - include Common (Obj) + include Comonadic_gen (Obj) let many = of_const Many @@ -1847,12 +1919,10 @@ module Statefulness = struct module Obj = struct type const = Const.t - module Solver = S.Positive - let obj = C.Statefulness end - include Common (Obj) + include Comonadic_gen (Obj) let stateless = of_const Stateless @@ -1872,13 +1942,10 @@ module Visibility = struct module Obj = struct type const = Const.t - (* the negation of Visibility_op gives us the proper visibility *) - module Solver = S.Negative - let obj = C.Visibility_op end - include Common (Obj) + include Monadic_gen (Obj) let immutable = of_const Immutable @@ -1897,12 +1964,10 @@ module Portability = struct module Obj = struct type const = Const.t - module Solver = S.Positive - let obj : _ C.obj = C.Portability end - include Common (Obj) + include Comonadic_gen (Obj) let legacy = of_const Const.legacy @@ -1919,13 +1984,10 @@ module Uniqueness = struct module Obj = struct type const = Const.t - (* the negation of Uniqueness_op gives us the proper uniqueness *) - module Solver = S.Negative - let obj = C.Uniqueness_op end - include Common (Obj) + include Monadic_gen (Obj) let aliased = of_const Aliased @@ -1943,13 +2005,10 @@ module Contention = struct module Obj = struct type const = Const.t - (* the negation of Contention_op gives us the proper contention *) - module Solver = S.Negative - let obj = C.Contention_op end - include Common (Obj) + include Monadic_gen (Obj) let legacy = of_const Const.legacy @@ -1966,12 +2025,10 @@ module Yielding = struct module Obj = struct type const = Const.t - module Solver = S.Positive - let obj = C.Yielding end - include Common (Obj) + include Comonadic_gen (Obj) let yielding = of_const Yielding @@ -1985,33 +2042,29 @@ module Yielding = struct match global with true -> zap_to_floor | false -> zap_to_ceil end -let regional_to_local m = - S.Positive.via_monotone Locality.Obj.obj C.Regional_to_local m +let regional_to_local m = S.apply Locality.Obj.obj C.Regional_to_local m let locality_as_regionality m = - S.Positive.via_monotone Regionality.Obj.obj C.Locality_as_regionality m + S.apply Regionality.Obj.obj C.Locality_as_regionality m -let regional_to_global m = - S.Positive.via_monotone Locality.Obj.obj C.Regional_to_global m +let regional_to_global m = S.apply Locality.Obj.obj C.Regional_to_global m module type Areality = sig module Const : C.Areality - module Obj : Obj with type const = Const.t and module Solver = S.Positive + module Obj : Obj with type const = Const.t - val zap_to_legacy : (Const.t, allowed * 'r) Obj.Solver.mode -> Const.t + val zap_to_legacy : (Const.t, allowed * 'r) Solver.mode -> Const.t end module Comonadic_with (Areality : Areality) = struct module Obj = struct - type const = Areality.Obj.const C.comonadic_with - - module Solver = S.Positive + type const = Areality.Const.t C.comonadic_with let obj = C.comonadic_with_obj Areality.Obj.obj end - include Common (Obj) + include Comonadic_gen (Obj) type error = Error : (Obj.const, 'a) C.Axis.t * 'a Solver.error -> error @@ -2052,17 +2105,17 @@ module Comonadic_with (Areality : Areality) = struct | Statefulness -> (module Statefulness.Const) end - let proj ax m = Obj.Solver.via_monotone (proj_obj ax) (Proj (Obj.obj, ax)) m + let proj ax m = Solver.apply (proj_obj ax) (Proj (Obj.obj, ax)) m - let meet_const c m = Obj.Solver.via_monotone Obj.obj (Meet_with c) m + let meet_const c m = Solver.apply Obj.obj (Meet_with c) m - let join_const c m = Obj.Solver.via_monotone Obj.obj (Join_with c) m + let join_const c m = Solver.apply Obj.obj (Join_with c) m let min_with ax m = - Obj.Solver.via_monotone Obj.obj (Min_with ax) (Obj.Solver.disallow_right m) + Solver.apply Obj.obj (Min_with ax) (Solver.disallow_right m) let max_with ax m = - Obj.Solver.via_monotone Obj.obj (Max_with ax) (Obj.Solver.disallow_left m) + Solver.apply Obj.obj (Max_with ax) (Solver.disallow_left m) let join_with ax c m = join_const (C.min_with Obj.obj ax c) m @@ -2079,11 +2132,9 @@ module Comonadic_with (Areality : Areality) = struct let yielding = proj Yielding m |> Yielding.zap_to_legacy ~global in { areality; linearity; portability; yielding; statefulness } - let imply c m = - Obj.Solver.via_monotone Obj.obj (Imply c) (Obj.Solver.disallow_left m) + let imply c m = Solver.apply Obj.obj (Imply c) (Solver.disallow_left m) - let subtract c m = - Obj.Solver.via_monotone Obj.obj (Subtract c) (Obj.Solver.disallow_right m) + let subtract c m = Solver.apply Obj.obj (Subtract c) (Solver.disallow_right m) let legacy = of_const Const.legacy @@ -2152,14 +2203,10 @@ module Monadic = struct module Obj = struct type const = C.Monadic_op.t - (* Negative solver on the opposite of monadic should give the monadic - fragment with original ordering *) - module Solver = S.Negative - let obj = C.Monadic_op end - include Common (Obj) + include Monadic_gen (Obj) type error = Error : (Obj.const, 'a) C.Axis.t * 'a Solver.error -> error @@ -2197,31 +2244,27 @@ module Monadic = struct module Const_op = C.Monadic_op - let proj ax m = Obj.Solver.via_monotone (proj_obj ax) (Proj (Obj.obj, ax)) m + let proj ax m = Solver.apply (proj_obj ax) (Proj (Obj.obj, ax)) m - (* The monadic fragment is inverted. Most of the inversion logic is taken care - by [Solver_polarized], but some remain, such as the [Min_with] below which - is inverted from [Max_with]. *) + (* The monadic fragment is inverted. *) - let meet_const c m = Obj.Solver.via_monotone Obj.obj (Join_with c) m + let meet_const c m = Solver.apply Obj.obj (Join_with c) m - let join_const c m = Obj.Solver.via_monotone Obj.obj (Meet_with c) m + let join_const c m = Solver.apply Obj.obj (Meet_with c) m let max_with ax m = - Obj.Solver.via_monotone Obj.obj (Min_with ax) (Obj.Solver.disallow_left m) + Solver.apply Obj.obj (Min_with ax) (Solver.disallow_right m) let min_with ax m = - Obj.Solver.via_monotone Obj.obj (Max_with ax) (Obj.Solver.disallow_right m) + Solver.apply Obj.obj (Max_with ax) (Solver.disallow_left m) let join_with ax c m = join_const (C.max_with Obj.obj ax c) m let meet_with ax c m = meet_const (C.min_with Obj.obj ax c) m - let imply c m = - Obj.Solver.via_monotone Obj.obj (Subtract c) (Obj.Solver.disallow_left m) + let imply c m = Solver.apply Obj.obj (Subtract c) (Solver.disallow_right m) - let subtract c m = - Obj.Solver.via_monotone Obj.obj (Imply c) (Obj.Solver.disallow_right m) + let subtract c m = Solver.apply Obj.obj (Imply c) (Solver.disallow_left m) let zap_to_legacy m : Const.t = let uniqueness = proj Uniqueness m |> Uniqueness.zap_to_legacy in @@ -2295,22 +2338,18 @@ module Value_with (Areality : Areality) = struct type lr = (allowed * allowed) t - type ('m, 'a, 'd) axis = - | Monadic : - (Monadic.Const.t, 'a) Axis.t - -> (('a, 'd) mode_monadic, 'a, 'd) axis - | Comonadic : - (Comonadic.Const.t, 'a) Axis.t - -> (('a, 'd) mode_comonadic, 'a, 'd) axis + type ('a, 'd0, 'd1) axis = + | Monadic : (Monadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd neg) axis + | Comonadic : (Comonadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd pos) axis - type 'd axis_packed = P : ('m, 'a, 'd) axis -> 'd axis_packed + type axis_packed = P : (_, _, _) axis -> axis_packed - let print_axis (type m a d) ppf (axis : (m, a, d) axis) = + let print_axis (type a d0 d1) ppf (axis : (a, d0, d1) axis) = match axis with | Monadic ax -> Axis.print ppf ax | Comonadic ax -> Axis.print ppf ax - let lattice_of_axis (type m a d) (axis : (m, a, d) axis) : + let lattice_of_axis (type a d0 d1) (axis : (a, d0, d1) axis) : (module Lattice with type t = a) = match axis with | Comonadic ax -> Comonadic.Const.lattice_of_axis ax @@ -2325,7 +2364,7 @@ module Value_with (Areality : Areality) = struct P (Comonadic Statefulness); P (Monadic Visibility) ] - let proj_obj : type m a d. (m, a, d) axis -> a C.obj = function + let proj_obj : type a d0 d1. (a, d0, d1) axis -> a C.obj = function | Monadic ax -> Monadic.proj_obj ax | Comonadic ax -> Comonadic.proj_obj ax @@ -2565,29 +2604,29 @@ module Value_with (Areality : Areality) = struct let monadic = Monadic.min in merge { comonadic; monadic } - let print_axis : type m a d. (m, a, d) axis -> _ -> a -> unit = + let print_axis : type a. (a, _, _) axis -> _ -> a -> unit = fun ax ppf a -> let obj = proj_obj ax in C.print obj ppf a - let le_axis : type m a d. (m, a, d) axis -> a -> a -> bool = + let le_axis : type a d0 d1. (a, d0, d1) axis -> a -> a -> bool = fun ax m0 m1 -> match ax with | Comonadic ax -> Comonadic.le_axis ax m0 m1 | Monadic ax -> Monadic.le_axis ax m0 m1 - let min_axis : type m a d. (m, a, d) axis -> a = function + let min_axis : type a d0 d1. (a, d0, d1) axis -> a = function | Comonadic ax -> Comonadic.min_axis ax | Monadic ax -> Monadic.min_axis ax - let max_axis : type m a d. (m, a, d) axis -> a = function + let max_axis : type a d0 d1. (a, d0, d1) axis -> a = function | Comonadic ax -> Comonadic.max_axis ax | Monadic ax -> Monadic.max_axis ax - let is_max : type m a d. (m, a, d) axis -> a -> bool = + let is_max : type a d0 d1. (a, d0, d1) axis -> a -> bool = fun ax m -> le_axis ax (max_axis ax) m - let is_min : type m a d. (m, a, d) axis -> a -> bool = + let is_min : type a d0 d1. (a, d0, d1) axis -> a -> bool = fun ax m -> le_axis ax m (min_axis ax) let split = split @@ -2638,7 +2677,7 @@ module Value_with (Areality : Areality) = struct let monadic, b1 = Monadic.newvar_below monadic in { monadic; comonadic }, b0 || b1 - type error = Error : ('m, 'a, 'd) axis * 'a Solver.error -> error + type error = Error : ('a, _, _) axis * 'a Solver.error -> error type equate_error = equate_step * error @@ -2674,7 +2713,9 @@ module Value_with (Areality : Areality) = struct let proj_comonadic ax { comonadic; _ } = Comonadic.proj ax comonadic - let proj : type m a l r. (m, a, l * r) axis -> (l * r) t -> m = + let proj : + type a l0 l1 r0 r1. + (a, l0 * r0, l1 * r1) axis -> (l0 * r0) t -> (a, l1 * r1) mode = fun ax m -> match ax with | Monadic ax -> proj_monadic ax m @@ -2692,7 +2733,9 @@ module Value_with (Areality : Areality) = struct let monadic = Monadic.max |> Monadic.disallow_left |> Monadic.allow_right in { comonadic; monadic } - let max_with : type m a l r. (m, a, l * r) axis -> m -> (disallowed * r) t = + let max_with : + type a l0 l1 r0 r1. + (a, l0 * r0, l1 * r1) axis -> (a, l1 * r1) mode -> (disallowed * r0) t = fun ax m -> match ax with | Monadic ax -> max_with_monadic ax m @@ -2710,7 +2753,9 @@ module Value_with (Areality : Areality) = struct let monadic = Monadic.min |> Monadic.disallow_right |> Monadic.allow_left in { comonadic; monadic } - let min_with : type m a l r. (m, a, l * r) axis -> m -> (l * disallowed) t = + let min_with : + type a l0 l1 r0 r1. + (a, l0 * r0, l1 * r1) axis -> (a, l1 * r1) mode -> (l0 * disallowed) t = fun ax m -> match ax with | Monadic ax -> min_with_monadic ax m @@ -2724,8 +2769,9 @@ module Value_with (Areality : Areality) = struct let comonadic = Comonadic.join_with ax c comonadic in { comonadic; monadic } - let join_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t - = + let join_with : + type a l0 l1 r0 r1. + (a, l0 * r0, l1 * r1) axis -> a -> (l0 * r0) t -> (l0 * r0) t = fun ax c m -> match ax with | Monadic ax -> join_with_monadic ax c m @@ -2739,8 +2785,9 @@ module Value_with (Areality : Areality) = struct let comonadic = Comonadic.meet_with ax c comonadic in { comonadic; monadic } - let meet_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t - = + let meet_with : + type a l0 r0 l1 r1. + (a, l0 * r0, l1 * r1) axis -> a -> (l0 * r0) t -> (l0 * r0) t = fun ax c m -> match ax with | Monadic ax -> meet_with_monadic ax c m @@ -2769,12 +2816,10 @@ module Value_with (Areality : Areality) = struct { comonadic; monadic } let comonadic_to_monadic m = - S.Negative.via_antitone Monadic.Obj.obj - (Comonadic_to_monadic Comonadic.Obj.obj) m + S.apply Monadic.Obj.obj (Comonadic_to_monadic Comonadic.Obj.obj) m let monadic_to_comonadic_min m = - S.Positive.via_antitone Comonadic.Obj.obj Monadic_to_comonadic_min - (Monadic.disallow_left m) + S.apply Comonadic.Obj.obj Monadic_to_comonadic_min (Monadic.disallow_left m) let meet_const c { comonadic; monadic } = let c = split c in @@ -2877,8 +2922,7 @@ module Const = struct } module Axis = struct - let alloc_as_value : type d. d Alloc.axis_packed -> d Value.axis_packed = - function + let alloc_as_value : Alloc.axis_packed -> Value.axis_packed = function | P (Comonadic Areality) -> P (Comonadic Areality) | P (Comonadic Linearity) -> P (Comonadic Linearity) | P (Comonadic Portability) -> P (Comonadic Portability) @@ -2893,12 +2937,11 @@ module Const = struct end let comonadic_locality_as_regionality comonadic = - S.Positive.via_monotone Value.Comonadic.Obj.obj - (Map_comonadic Locality_as_regionality) comonadic + S.apply Value.Comonadic.Obj.obj (Map_comonadic Locality_as_regionality) + comonadic let comonadic_regional_to_local comonadic = - S.Positive.via_monotone Alloc.Comonadic.Obj.obj - (Map_comonadic Regional_to_local) comonadic + S.apply Alloc.Comonadic.Obj.obj (Map_comonadic Regional_to_local) comonadic let alloc_as_value m = let { comonadic; monadic } = m in @@ -2908,8 +2951,7 @@ let alloc_as_value m = let alloc_to_value_l2r m = let { comonadic; monadic } = Alloc.disallow_right m in let comonadic = - S.Positive.via_monotone Value.Comonadic.Obj.obj - (Map_comonadic Local_to_regional) comonadic + S.apply Value.Comonadic.Obj.obj (Map_comonadic Local_to_regional) comonadic in { comonadic; monadic } @@ -2917,8 +2959,7 @@ let value_to_alloc_r2g : type l r. (l * r) Value.t -> (l * r) Alloc.t = fun m -> let { comonadic; monadic } = m in let comonadic = - S.Positive.via_monotone Alloc.Comonadic.Obj.obj - (Map_comonadic Regional_to_global) comonadic + S.apply Alloc.Comonadic.Obj.obj (Map_comonadic Regional_to_global) comonadic in { comonadic; monadic } @@ -2928,11 +2969,11 @@ let value_to_alloc_r2l m = { comonadic; monadic } module Modality = struct - type ('m, 'a) raw = - | Meet_with : 'a -> (('a, 'l * 'r) mode_comonadic, 'a) raw - | Join_with : 'a -> (('a, 'l * 'r) mode_monadic, 'a) raw + type 'a raw = + | Meet_with : 'a -> 'a raw + | Join_with : 'a -> 'a raw - type t = Atom : ('m, 'a, _) Value.axis * ('m, 'a) raw -> t + type t = Atom : ('a, _, _) Value.axis * 'a raw -> t let is_id (Atom (ax, a)) = match a with @@ -2955,8 +2996,7 @@ module Modality = struct type 'a axis = (Mode.Const.t, 'a) Axis.t - type error = - | Error : 'a axis * (('a, _) mode_monadic, 'a) raw Solver.error -> error + type error = Error : 'a axis * 'a raw Solver.error -> error module Const = struct type t = Join_const of Mode.Const.t @@ -2979,8 +3019,7 @@ module Modality = struct Error (Error (ax, { left = Join_with left; right = Join_with right })) - let compose : - type a l r. a axis -> ((a, l * r) mode_monadic, a) raw -> t -> t = + let compose : type a. a axis -> a raw -> t -> t = fun ax a t -> match a, t with | Join_with c0, Join_const c -> @@ -3112,8 +3151,7 @@ module Modality = struct type 'a axis = (Mode.Const.t, 'a) Axis.t - type error = - | Error : 'a axis * (('a, _) mode_comonadic, 'a) raw Solver.error -> error + type error = Error : 'a axis * 'a raw Solver.error -> error module Const = struct type t = Meet_const of Mode.Const.t @@ -3136,8 +3174,7 @@ module Modality = struct Error (Error (ax, { left = Meet_with left; right = Meet_with right })) - let compose : - type a l r. a axis -> ((a, l * r) mode_comonadic, a) raw -> t -> t = + let compose : type a. a axis -> a raw -> t -> t = fun ax a t -> match a, t with | Meet_with c0, Meet_const c -> @@ -3270,8 +3307,7 @@ module Modality = struct end module Value = struct - type error = - | Error : ('m, 'a, _) Value.axis * ('m, 'a) raw Solver.error -> error + type error = Error : ('a, _, _) Value.axis * 'a raw Solver.error -> error type equate_error = equate_step * error @@ -3322,7 +3358,8 @@ module Modality = struct let to_list { monadic; comonadic } = Comonadic.to_list comonadic @ Monadic.to_list monadic - let proj (type m a d) (ax : (m, a, d) Value.axis) { monadic; comonadic } = + let proj (type a d0 d1) (ax : (a, d0, d1) Value.axis) + { monadic; comonadic } = match ax with | Monadic ax -> Monadic.proj ax monadic | Comonadic ax -> Comonadic.proj ax comonadic diff --git a/typing/mode_intf.mli b/typing/mode_intf.mli index 8bc5506822f..a425fb3e4e0 100644 --- a/typing/mode_intf.mli +++ b/typing/mode_intf.mli @@ -108,9 +108,7 @@ module type S = sig type nonrec equate_step = equate_step - type ('a, 'd) mode_monadic constraint 'd = 'l * 'r - - type ('a, 'd) mode_comonadic constraint 'd = 'l * 'r + type ('a, 'd) mode constraint 'd = 'l * 'r type ('a, 'b) monadic_comonadic = { monadic : 'a; @@ -132,7 +130,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_comonadic + and type 'd t = (Const.t, 'd pos) mode val global : lr @@ -175,7 +173,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_comonadic + and type 'd t = (Const.t, 'd pos) mode val global : lr @@ -199,7 +197,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_comonadic + and type 'd t = (Const.t, 'd pos) mode val many : lr @@ -221,7 +219,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_comonadic + and type 'd t = (Const.t, 'd pos) mode end module Uniqueness : sig @@ -241,7 +239,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_monadic + and type 'd t = (Const.t, 'd neg) mode val aliased : lr @@ -268,7 +266,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_monadic + and type 'd t = (Const.t, 'd neg) mode end module Yielding : sig @@ -286,7 +284,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_comonadic + and type 'd t = (Const.t, 'd pos) mode val yielding : lr @@ -309,7 +307,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_comonadic + and type 'd t = (Const.t, 'd pos) mode val stateless : lr @@ -336,7 +334,7 @@ module type S = sig Common with module Const := Const and type error := error - and type 'd t = (Const.t, 'd) mode_monadic + and type 'd t = (Const.t, 'd neg) mode val immutable : lr @@ -416,25 +414,21 @@ module type S = sig val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t end - (** Represents a mode axis in this product whose constant is ['a], and - whose variable is ['m] given the allowness ['d]. *) - type ('m, 'a, 'd) axis = - | Monadic : - (Monadic.Const.t, 'a) Axis.t - -> (('a, 'd) mode_monadic, 'a, 'd) axis - | Comonadic : - (Comonadic.Const.t, 'a) Axis.t - -> (('a, 'd) mode_comonadic, 'a, 'd) axis + (** Represents a mode axis in this product whose constant is ['a], and whose + allowance is ['d1] given the product's allowance ['d0]. *) + type ('a, 'd0, 'd1) axis = + | Monadic : (Monadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd neg) axis + | Comonadic : (Comonadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd pos) axis - type 'd axis_packed = P : ('m, 'a, 'd) axis -> 'd axis_packed + type axis_packed = P : (_, _, _) axis -> axis_packed - val print_axis : Format.formatter -> ('m, 'a, 'd) axis -> unit + val print_axis : Format.formatter -> ('a, _, _) axis -> unit (** Gets the normal lattice for comonadic axes and the "op"ped lattice for monadic ones. *) - val lattice_of_axis : ('m, 'a, 'd) axis -> (module Lattice with type t = 'a) + val lattice_of_axis : ('a, _, _) axis -> (module Lattice with type t = 'a) - val all_axes : ('l * 'r) axis_packed list + val all_axes : axis_packed list type ('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h) modes = { areality : 'a; @@ -482,9 +476,9 @@ module type S = sig val print : Format.formatter -> t -> unit end - val is_max : ('m, 'a, 'd) axis -> 'a -> bool + val is_max : ('a, _, _) axis -> 'a -> bool - val is_min : ('m, 'a, 'd) axis -> 'a -> bool + val is_min : ('a, _, _) axis -> 'a -> bool val split : t -> (Monadic.Const.t, Comonadic.Const.t) monadic_comonadic @@ -501,10 +495,10 @@ module type S = sig val partial_apply : t -> t (** Prints a constant on any axis. *) - val print_axis : ('m, 'a, 'd) axis -> Format.formatter -> 'a -> unit + val print_axis : ('a, _, _) axis -> Format.formatter -> 'a -> unit end - type error = Error : ('m, 'a, 'd) axis * 'a Solver.error -> error + type error = Error : ('a, _, _) axis * 'a Solver.error -> error type 'd t = ('d Monadic.t, 'd Comonadic.t) monadic_comonadic @@ -519,15 +513,24 @@ module type S = sig include Allow_disallow with type (_, _, 'd) sided = 'd t list end - val proj : ('m, 'a, 'l * 'r) axis -> ('l * 'r) t -> 'm + val proj : + ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> ('l0 * 'r0) t -> ('a, 'l1 * 'r1) mode - val max_with : ('m, 'a, 'l * 'r) axis -> 'm -> (disallowed * 'r) t + val max_with : + ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> + ('a, 'l1 * 'r1) mode -> + (disallowed * 'r0) t - val min_with : ('m, 'a, 'l * 'r) axis -> 'm -> ('l * disallowed) t + val min_with : + ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> + ('a, 'l1 * 'r1) mode -> + ('l0 * disallowed) t - val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t + val meet_with : + ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> 'a -> ('l0 * 'r0) t -> ('l0 * 'r0) t - val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t + val join_with : + ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> 'a -> ('l0 * 'r0) t -> ('l0 * 'r0) t val zap_to_legacy : lr -> Const.t @@ -569,7 +572,7 @@ module type S = sig val alloc_as_value : Alloc.Const.t -> Value.Const.t module Axis : sig - val alloc_as_value : 'd Alloc.axis_packed -> 'd Value.axis_packed + val alloc_as_value : Alloc.axis_packed -> Value.axis_packed end val locality_as_regionality : Locality.Const.t -> Regionality.Const.t @@ -597,16 +600,16 @@ module type S = sig val value_to_alloc_r2g : ('l * 'r) Value.t -> ('l * 'r) Alloc.t module Modality : sig - type ('m, 'a) raw = - | Meet_with : 'a -> (('a, 'd) mode_comonadic, 'a) raw + type 'a raw = + | Meet_with : 'a -> 'a raw (** [Meet_with c] takes [x] and returns [meet c x]. [c] can be [max] in which case it's the identity modality. *) - | Join_with : 'a -> (('a, 'd) mode_monadic, 'a) raw + | Join_with : 'a -> 'a raw (** [Join_with c] takes [x] and returns [join c x]. [c] can be [min] in which case it's the identity modality. *) (** An atom modality is a [raw] accompanied by the axis it acts on. *) - type t = Atom : ('m, 'a, _) Value.axis * ('m, 'a) raw -> t + type t = Atom : ('a, _, _) Value.axis * 'a raw -> t (** Test if the given modality is the identity modality. *) val is_id : t -> bool @@ -621,7 +624,7 @@ module type S = sig type atom := t type error = - | Error : ('m, 'a, _) Value.axis * ('m, 'a) raw Solver.error -> error + | Error : ('a, _, _) Value.axis * 'a raw Solver.error -> error type nonrec equate_error = equate_step * error @@ -671,7 +674,7 @@ module type S = sig val of_list : atom list -> t (** Project out the [atom] for the given axis in the given modality. *) - val proj : ('m, 'a, 'd) Value.axis -> t -> atom + val proj : ('a, _, _) Value.axis -> t -> atom (** [equate t0 t1] checks that [t0 = t1]. Definition: [t0 = t1] iff [t0 <= t1] and [t1 <= t0]. *) diff --git a/typing/solver.ml b/typing/solver.ml index ff972766eed..b4691fadf4d 100644 --- a/typing/solver.ml +++ b/typing/solver.ml @@ -34,6 +34,8 @@ type 'a error = } module Solver_mono (C : Lattices_mono) = struct + type nonrec 'a error = 'a error + type any_morph = Any_morph : ('a, 'b, 'd) C.morph -> any_morph module VarMap = Map.Make (struct @@ -98,6 +100,8 @@ module Solver_mono (C : Lattices_mono) = struct and ('b, 'd) morphvar = | Amorphvar : 'a var * ('a, 'b, 'd) C.morph -> ('b, 'd) morphvar + constraint 'd = _ * _ + [@@ocaml.warning "-62"] let get_key (Amorphvar (v, m)) = v.id, Any_morph m @@ -134,6 +138,8 @@ module Solver_mono (C : Lattices_mono) = struct 'a * ('a, disallowed * 'r) morphvar VarMap.t -> ('a, disallowed * 'r) mode (** [Amodemeet a [mv0, mv1, ..]] represents [a meet mv0 meet mv1 meet ..]. *) + constraint 'd = _ * _ + [@@ocaml.warning "-62"] (** Prints a mode variable, including the set of variables below it (recursively). To handle cycles, [traversed] is the set of variables that @@ -169,7 +175,8 @@ module Solver_mono (C : Lattices_mono) = struct (var_map_to_list v.vlower) and print_morphvar : - type a d. ?traversed:VarSet.t -> a C.obj -> _ -> (a, d) morphvar -> _ = + type a l r. + ?traversed:VarSet.t -> a C.obj -> _ -> (a, l * r) morphvar -> _ = fun ?traversed dst ppf (Amorphvar (v, f)) -> let src = C.src dst f in Format.fprintf ppf "%a(%a)" (C.print_morph dst) f (print_var ?traversed src) @@ -260,7 +267,7 @@ module Solver_mono (C : Lattices_mono) = struct let max (type a) (obj : a C.obj) = Amode (C.max obj) - let of_const a = Amode a + let of_const _obj a = Amode a let apply_morphvar dst morph (Amorphvar (var, morph')) = Amorphvar (var, C.compose dst morph morph') @@ -690,163 +697,3 @@ module Solver_mono (C : Lattices_mono) = struct allow_left (Amodevar mu), true end [@@inline always] - -module Solvers_polarized (C : Lattices_mono) = struct - module S = Solver_mono (C) - - type changes = S.changes - - let empty_changes = S.empty_changes - - let undo_changes = S.undo_changes - - module type Solver_polarized = - Solver_polarized - with type ('a, 'b, 'd) morph := ('a, 'b, 'd) C.morph - and type 'a obj := 'a C.obj - and type 'a error := 'a error - and type changes := changes - - module rec Positive : - (Solver_polarized - with type 'd polarized = 'd pos - and type ('a, 'd) mode_op = ('a, 'd) Negative.mode) = struct - type 'd polarized = 'd pos - - type ('a, 'd) mode_op = ('a, 'd) Negative.mode - - type ('a, 'd) mode = ('a, 'd) S.mode constraint 'd = 'l * 'r - - include Magic_allow_disallow (S) - - let newvar = S.newvar - - let submode = S.submode - - let join = S.join - - let meet = S.meet - - let of_const _ = S.of_const - - let min = S.min - - let max = S.max - - let zap_to_floor = S.zap_to_floor - - let zap_to_ceil = S.zap_to_ceil - - let newvar_above = S.newvar_above - - let newvar_below = S.newvar_below - - let get_ceil = S.get_ceil - - let get_floor = S.get_floor - - let get_loose_ceil = S.get_loose_ceil - - let get_loose_floor = S.get_loose_floor - - let print ?(verbose = false) = S.print ~verbose - - let via_monotone = S.apply - - let via_antitone = S.apply - end - - and Negative : - (Solver_polarized - with type 'd polarized = 'd neg - and type ('a, 'd) mode_op = ('a, 'd) Positive.mode) = struct - type 'd polarized = 'd neg - - type ('a, 'd) mode_op = ('a, 'd) Positive.mode - - type ('a, 'd) mode = ('a, 'r * 'l) S.mode constraint 'd = 'l * 'r - - include Magic_allow_disallow (struct - type ('a, _, 'd) sided = ('a, 'd) mode - - let disallow_right = S.disallow_left - - let disallow_left = S.disallow_right - - let allow_right = S.allow_left - - let allow_left = S.allow_right - end) - - let newvar = S.newvar - - let submode obj m0 m1 ~log = - Result.map_error - (fun { left; right } -> { left = right; right = left }) - (S.submode obj m1 m0 ~log) - - let join = S.meet - - let meet = S.join - - let of_const _ = S.of_const - - let min = S.max - - let max = S.min - - let zap_to_floor = S.zap_to_ceil - - let zap_to_ceil = S.zap_to_floor - - let newvar_above = S.newvar_below - - let newvar_below = S.newvar_above - - let get_ceil = S.get_floor - - let get_floor = S.get_ceil - - let get_loose_ceil = S.get_loose_floor - - let get_loose_floor = S.get_loose_ceil - - let print ?(verbose = false) = S.print ~verbose - - let via_monotone = S.apply - - let via_antitone = S.apply - end - - (* Definitions to show that this solver works over a category. *) - module Category = struct - type 'a obj = 'a C.obj - - type ('a, 'b, 'd) morph = ('a, 'b, 'd) C.morph - - type ('a, 'd) mode = - | Positive of ('a, 'd pos) Positive.mode - | Negative of ('a, 'd neg) Negative.mode - - let apply_into_positive : - type a b l r. - b obj -> - (a, b, l * r) morph -> - (a, l * r) mode -> - (b, l * r) Positive.mode = - fun obj morph -> function - | Positive mode -> Positive.via_monotone obj morph mode - | Negative mode -> Positive.via_antitone obj morph mode - - let apply_into_negative : - type a b l r. - b obj -> - (a, b, l * r) morph -> - (a, l * r) mode -> - (b, r * l) Negative.mode = - fun obj morph -> function - | Positive mode -> Negative.via_antitone obj morph mode - | Negative mode -> Negative.via_monotone obj morph mode - end -end -[@@inline always] diff --git a/typing/solver_intf.mli b/typing/solver_intf.mli index 8eed401fab4..3ccc0d56014 100644 --- a/typing/solver_intf.mli +++ b/typing/solver_intf.mli @@ -61,7 +61,7 @@ module type Lattices_mono = sig - [disallowed], meaning the morphism cannot be on the left because it does not have right adjoint. Similar for ['r]. *) - type ('a, 'b, 'd) morph + type ('a, 'b, 'd) morph constraint 'd = 'l * 'r (* Due to the implementation in [solver.ml], a mode doesn't have sufficient information to infer the object it lives in, whether at compile-time or @@ -160,41 +160,40 @@ module type Lattices_mono = sig val print_morph : 'b obj -> Format.formatter -> ('a, 'b, 'd) morph -> unit end -(** Arrange the permissions appropriately for a positive lattice, by - doing nothing. *) -type 'a pos = 'b * 'c constraint 'a = 'b * 'c - -(** Arrange the permissions appropriately for a negative lattice, by - swapping left and right. *) -type 'a neg = 'c * 'b constraint 'a = 'b * 'c - -module type Solver_polarized = sig +module type Solver_mono = sig (* These first few types will be replaced with types from the Lattices_mono *) (** The morphism type from the [Lattices_mono] we're working with. See comments on [Lattices_mono.morph]. *) - type ('a, 'b, 'd) morph + type ('a, 'b, 'd) morph constraint 'd = 'l * 'r (** The object type from the [Lattices_mono] we're working with *) type 'a obj type 'a error - (** For a negative lattice, we reverse the direction of adjoints. We thus use - [neg] for [polarized] for negative lattices, which reverses ['l * 'r] to - ['r * 'l]. (Use [pos] for positive lattices.) *) - type 'd polarized constraint 'd = 'l * 'r + (* Backtracking facilities used by [types.ml] *) + (** Represents a sequence of state mutations caused by mode operations. All + mutating operations in this module take a [log:changes ref option] and + append to it all changes made, regardless of success or failure. It is + [option] only for performance reasons; the caller should never provide + [log:None]. The caller is responsible for taking care of the appended log: + they can either revert the changes using [undo_changes], or commit the + changes to the global log in [types.ml]. *) type changes + (** An empty sequence of changes. *) + val empty_changes : changes + + (** Undo the sequence of changes recorded. *) + val undo_changes : changes -> unit + (** A mode with carrier type ['a] and allowance ['d]. See Note [Allowance] in allowance.mli.*) type ('a, 'd) mode constraint 'd = 'l * 'r - (** The mode type for the opposite polarity. *) - type ('a, 'd) mode_op constraint 'd = 'l * 'r - include Allow_disallow with type ('a, _, 'd) sided = ('a, 'd) mode (** Returns the mode representing the given constant. *) @@ -269,23 +268,12 @@ module type Solver_polarized = sig val print : ?verbose:bool -> 'a obj -> Format.formatter -> ('a, 'l * 'r) mode -> unit - (** Apply a monotone morphism whose source and target modes are of the - polarity of this enclosing module. That is, [Positive.apply_monotone] - takes a positive mode to a positive mode. *) - val via_monotone : + (** Apply a monotone morphism. *) + val apply : 'b obj -> - ('a, 'b, ('l * 'r) polarized) morph -> + ('a, 'b, 'l * 'r) morph -> ('a, 'l * 'r) mode -> ('b, 'l * 'r) mode - - (** Apply an antitone morphism whose target mode is the mode defined in - this module and whose source mode is the dual mode. That is, - [Positive.apply_antitone] takes a negative mode to a positive one. *) - val via_antitone : - 'b obj -> - ('a, 'b, ('l * 'r) polarized) morph -> - ('a, 'r * 'l) mode_op -> - ('b, 'l * 'r) mode end module type S = sig @@ -306,68 +294,9 @@ module type S = sig (** Solver that supports polarized lattices; needed because some morphisms are antitone *) - module Solvers_polarized (C : Lattices_mono) : sig - (* Backtracking facilities used by [types.ml] *) - - (** Represents a sequence of state mutations caused by mode operations. All - mutating operations in this module take a [log:changes ref option] and - append to it all changes made, regardless of success or failure. It is - [option] only for performance reasons; the caller should never provide - [log:None]. The caller is responsible for taking care of the appended log: - they can either revert the changes using [undo_changes], or commit the - changes to the global log in [types.ml]. *) - type changes - - (** An empty sequence of changes. *) - val empty_changes : changes - - (** Undo the sequence of changes recorded. *) - val undo_changes : changes -> unit - - (* Construct a new category based on the original category [C]. Objects are - two copies of the objects in [C] of opposite polarity. The positive copy - is identical to the original lattice. The negative copy has its lattice - structure reversed. Morphism are four copies of the morphisms in [C], from - two copies of objects to two copies of objects. *) - - module type Solver_polarized = - Solver_polarized - with type ('a, 'b, 'd) morph := ('a, 'b, 'd) C.morph - and type 'a obj := 'a C.obj - and type 'a error := 'a error - and type changes := changes - - module rec Positive : - (Solver_polarized - with type 'd polarized = 'd pos - and type ('a, 'd) mode_op = ('a, 'd) Negative.mode) - - and Negative : - (Solver_polarized - with type 'd polarized = 'd neg - and type ('a, 'd) mode_op = ('a, 'd) Positive.mode) - - (* The following definitions show how this solver works over a category by - defining objects and morphisms. These definitions are not used in - practice. They are put into a module to make it easy to spot if we end up - using these in the future. *) - module Category : sig - type 'a obj = 'a C.obj - - type ('a, 'b, 'd) morph = ('a, 'b, 'd) C.morph - - type ('a, 'd) mode = - | Positive of ('a, 'd pos) Positive.mode - | Negative of ('a, 'd neg) Negative.mode - - val apply_into_positive : - 'b obj -> ('a, 'b, 'd) morph -> ('a, 'd) mode -> ('b, 'd) Positive.mode - - val apply_into_negative : - 'b obj -> - ('a, 'b, 'l * 'r) morph -> - ('a, 'l * 'r) mode -> - ('b, 'r * 'l) Negative.mode - end - end + module Solver_mono (C : Lattices_mono) : + Solver_mono + with type ('a, 'b, 'd) morph := ('a, 'b, 'd) C.morph + and type 'a obj := 'a C.obj + and type 'a error = 'a error end diff --git a/typing/typemode.ml b/typing/typemode.ml index fab67c13c35..4010e2ef95a 100644 --- a/typing/typemode.ml +++ b/typing/typemode.ml @@ -25,7 +25,7 @@ exception Error of Location.t * error module Axis_pair = struct type 'm t = - | Modal_axis_pair : ('m, 'a, 'd) Mode.Alloc.axis * 'a -> modal t + | Modal_axis_pair : ('a, _, _) Mode.Alloc.axis * 'a -> modal t | Any_axis_pair : 'a Axis.t * 'a -> maybe_nonmodal t | Everything_but_nullability : maybe_nonmodal t @@ -264,8 +264,8 @@ let default_mode_annots (annots : Alloc.Const.Option.t) = let transl_mode_annots annots : Alloc.Const.Option.t = let step modifiers_so_far annot = let { txt = - Modal_axis_pair (type m a d) - ((axis, mode) : (m, a, d) Mode.Alloc.axis * a); + Modal_axis_pair (type a d0 d1) + ((axis, mode) : (a, d0, d1) Mode.Alloc.axis * a); loc } = transl_annot ~annot_type:Mode ~required_mode_maturity:(Some Stable) From 506965c75360cc86b886ecdde3042c048a2c4ee1 Mon Sep 17 00:00:00 2001 From: Zesen Qian Date: Wed, 30 Apr 2025 09:59:30 +0100 Subject: [PATCH 2/3] move axis into module --- typing/jkind_axis.ml | 4 +- typing/jkind_axis.mli | 2 +- typing/mode.ml | 85 +++++++++++++++++++++++-------------------- typing/mode_intf.mli | 48 ++++++++++++------------ typing/typemode.ml | 4 +- 5 files changed, 76 insertions(+), 67 deletions(-) diff --git a/typing/jkind_axis.ml b/typing/jkind_axis.ml index ae2a43f8b3a..87225df4099 100644 --- a/typing/jkind_axis.ml +++ b/typing/jkind_axis.ml @@ -122,7 +122,7 @@ module Axis = struct end type 'a t = - | Modal : ('a, _, _) Mode.Alloc.axis -> 'a t + | Modal : ('a, _, _) Mode.Alloc.Axis.t -> 'a t | Nonmodal : 'a Nonmodal.t -> 'a t type packed = Pack : 'a t -> packed [@@unboxed] @@ -159,7 +159,7 @@ module Axis = struct Pack (Nonmodal Nullability) ] let name (type a) : a t -> string = function - | Modal axis -> Format.asprintf "%a" Mode.Alloc.print_axis axis + | Modal axis -> Format.asprintf "%a" Mode.Alloc.Axis.print axis | Nonmodal Externality -> "externality" | Nonmodal Nullability -> "nullability" diff --git a/typing/jkind_axis.mli b/typing/jkind_axis.mli index 6d1ba92b8fa..ac302f02cf2 100644 --- a/typing/jkind_axis.mli +++ b/typing/jkind_axis.mli @@ -49,7 +49,7 @@ module Axis : sig (** Represents an axis of a jkind *) type 'a t = - | Modal : ('a, _, _) Mode.Alloc.axis -> 'a t + | Modal : ('a, _, _) Mode.Alloc.Axis.t -> 'a t | Nonmodal : 'a Nonmodal.t -> 'a t type packed = Pack : 'a t -> packed [@@unboxed] diff --git a/typing/mode.ml b/typing/mode.ml index 4b51864b9fd..676f78c15ea 100644 --- a/typing/mode.ml +++ b/typing/mode.ml @@ -2066,7 +2066,9 @@ module Comonadic_with (Areality : Areality) = struct include Comonadic_gen (Obj) - type error = Error : (Obj.const, 'a) C.Axis.t * 'a Solver.error -> error + type 'a axis = (Obj.const, 'a) C.Axis.t + + type error = Error : 'a axis * 'a Solver.error -> error type equate_error = equate_step * error @@ -2208,7 +2210,9 @@ module Monadic = struct include Monadic_gen (Obj) - type error = Error : (Obj.const, 'a) C.Axis.t * 'a Solver.error -> error + type 'a axis = (Obj.const, 'a) C.Axis.t + + type error = Error : 'a axis * 'a Solver.error -> error type equate_error = equate_step * error @@ -2338,33 +2342,35 @@ module Value_with (Areality : Areality) = struct type lr = (allowed * allowed) t - type ('a, 'd0, 'd1) axis = - | Monadic : (Monadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd neg) axis - | Comonadic : (Comonadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd pos) axis + module Axis = struct + type ('a, 'd0, 'd1) t = + | Monadic : (Monadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd neg) t + | Comonadic : (Comonadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd pos) t - type axis_packed = P : (_, _, _) axis -> axis_packed + type packed = P : (_, _, _) t -> packed - let print_axis (type a d0 d1) ppf (axis : (a, d0, d1) axis) = - match axis with - | Monadic ax -> Axis.print ppf ax - | Comonadic ax -> Axis.print ppf ax + let print (type a d0 d1) ppf (t : (a, d0, d1) t) = + match t with + | Monadic ax -> Axis.print ppf ax + | Comonadic ax -> Axis.print ppf ax + + let all = + [ P (Comonadic Areality); + P (Monadic Uniqueness); + P (Comonadic Linearity); + P (Monadic Contention); + P (Comonadic Portability); + P (Comonadic Statefulness); + P (Monadic Visibility) ] + end - let lattice_of_axis (type a d0 d1) (axis : (a, d0, d1) axis) : + let lattice_of_axis (type a d0 d1) (axis : (a, d0, d1) Axis.t) : (module Lattice with type t = a) = match axis with | Comonadic ax -> Comonadic.Const.lattice_of_axis ax | Monadic ax -> Monadic.Const.lattice_of_axis ax - let all_axes = - [ P (Comonadic Areality); - P (Monadic Uniqueness); - P (Comonadic Linearity); - P (Monadic Contention); - P (Comonadic Portability); - P (Comonadic Statefulness); - P (Monadic Visibility) ] - - let proj_obj : type a d0 d1. (a, d0, d1) axis -> a C.obj = function + let proj_obj : type a d0 d1. (a, d0, d1) Axis.t -> a C.obj = function | Monadic ax -> Monadic.proj_obj ax | Comonadic ax -> Comonadic.proj_obj ax @@ -2604,29 +2610,29 @@ module Value_with (Areality : Areality) = struct let monadic = Monadic.min in merge { comonadic; monadic } - let print_axis : type a. (a, _, _) axis -> _ -> a -> unit = + let print_axis : type a. (a, _, _) Axis.t -> _ -> a -> unit = fun ax ppf a -> let obj = proj_obj ax in C.print obj ppf a - let le_axis : type a d0 d1. (a, d0, d1) axis -> a -> a -> bool = + let le_axis : type a d0 d1. (a, d0, d1) Axis.t -> a -> a -> bool = fun ax m0 m1 -> match ax with | Comonadic ax -> Comonadic.le_axis ax m0 m1 | Monadic ax -> Monadic.le_axis ax m0 m1 - let min_axis : type a d0 d1. (a, d0, d1) axis -> a = function + let min_axis : type a d0 d1. (a, d0, d1) Axis.t -> a = function | Comonadic ax -> Comonadic.min_axis ax | Monadic ax -> Monadic.min_axis ax - let max_axis : type a d0 d1. (a, d0, d1) axis -> a = function + let max_axis : type a d0 d1. (a, d0, d1) Axis.t -> a = function | Comonadic ax -> Comonadic.max_axis ax | Monadic ax -> Monadic.max_axis ax - let is_max : type a d0 d1. (a, d0, d1) axis -> a -> bool = + let is_max : type a d0 d1. (a, d0, d1) Axis.t -> a -> bool = fun ax m -> le_axis ax (max_axis ax) m - let is_min : type a d0 d1. (a, d0, d1) axis -> a -> bool = + let is_min : type a d0 d1. (a, d0, d1) Axis.t -> a -> bool = fun ax m -> le_axis ax m (min_axis ax) let split = split @@ -2677,7 +2683,7 @@ module Value_with (Areality : Areality) = struct let monadic, b1 = Monadic.newvar_below monadic in { monadic; comonadic }, b0 || b1 - type error = Error : ('a, _, _) axis * 'a Solver.error -> error + type error = Error : ('a, _, _) Axis.t * 'a Solver.error -> error type equate_error = equate_step * error @@ -2715,7 +2721,7 @@ module Value_with (Areality : Areality) = struct let proj : type a l0 l1 r0 r1. - (a, l0 * r0, l1 * r1) axis -> (l0 * r0) t -> (a, l1 * r1) mode = + (a, l0 * r0, l1 * r1) Axis.t -> (l0 * r0) t -> (a, l1 * r1) mode = fun ax m -> match ax with | Monadic ax -> proj_monadic ax m @@ -2735,7 +2741,7 @@ module Value_with (Areality : Areality) = struct let max_with : type a l0 l1 r0 r1. - (a, l0 * r0, l1 * r1) axis -> (a, l1 * r1) mode -> (disallowed * r0) t = + (a, l0 * r0, l1 * r1) Axis.t -> (a, l1 * r1) mode -> (disallowed * r0) t = fun ax m -> match ax with | Monadic ax -> max_with_monadic ax m @@ -2755,7 +2761,7 @@ module Value_with (Areality : Areality) = struct let min_with : type a l0 l1 r0 r1. - (a, l0 * r0, l1 * r1) axis -> (a, l1 * r1) mode -> (l0 * disallowed) t = + (a, l0 * r0, l1 * r1) Axis.t -> (a, l1 * r1) mode -> (l0 * disallowed) t = fun ax m -> match ax with | Monadic ax -> min_with_monadic ax m @@ -2771,7 +2777,7 @@ module Value_with (Areality : Areality) = struct let join_with : type a l0 l1 r0 r1. - (a, l0 * r0, l1 * r1) axis -> a -> (l0 * r0) t -> (l0 * r0) t = + (a, l0 * r0, l1 * r1) Axis.t -> a -> (l0 * r0) t -> (l0 * r0) t = fun ax c m -> match ax with | Monadic ax -> join_with_monadic ax c m @@ -2787,7 +2793,7 @@ module Value_with (Areality : Areality) = struct let meet_with : type a l0 r0 l1 r1. - (a, l0 * r0, l1 * r1) axis -> a -> (l0 * r0) t -> (l0 * r0) t = + (a, l0 * r0, l1 * r1) Axis.t -> a -> (l0 * r0) t -> (l0 * r0) t = fun ax c m -> match ax with | Monadic ax -> meet_with_monadic ax c m @@ -2922,7 +2928,7 @@ module Const = struct } module Axis = struct - let alloc_as_value : Alloc.axis_packed -> Value.axis_packed = function + let alloc_as_value : Alloc.Axis.packed -> Value.Axis.packed = function | P (Comonadic Areality) -> P (Comonadic Areality) | P (Comonadic Linearity) -> P (Comonadic Linearity) | P (Comonadic Portability) -> P (Comonadic Portability) @@ -2973,7 +2979,7 @@ module Modality = struct | Meet_with : 'a -> 'a raw | Join_with : 'a -> 'a raw - type t = Atom : ('a, _, _) Value.axis * 'a raw -> t + type t = Atom : ('a, _, _) Value.Axis.t * 'a raw -> t let is_id (Atom (ax, a)) = match a with @@ -2994,7 +3000,7 @@ module Modality = struct module Monadic = struct module Mode = Value.Monadic - type 'a axis = (Mode.Const.t, 'a) Axis.t + type 'a axis = 'a Mode.axis type error = Error : 'a axis * 'a raw Solver.error -> error @@ -3149,7 +3155,7 @@ module Modality = struct module Comonadic = struct module Mode = Value.Comonadic - type 'a axis = (Mode.Const.t, 'a) Axis.t + type 'a axis = 'a Mode.axis type error = Error : 'a axis * 'a raw Solver.error -> error @@ -3307,7 +3313,8 @@ module Modality = struct end module Value = struct - type error = Error : ('a, _, _) Value.axis * 'a raw Solver.error -> error + type error = + | Error : ('a, _, _) Value.Axis.t * 'a raw Solver.error -> error type equate_error = equate_step * error @@ -3358,7 +3365,7 @@ module Modality = struct let to_list { monadic; comonadic } = Comonadic.to_list comonadic @ Monadic.to_list monadic - let proj (type a d0 d1) (ax : (a, d0, d1) Value.axis) + let proj (type a d0 d1) (ax : (a, d0, d1) Value.Axis.t) { monadic; comonadic } = match ax with | Monadic ax -> Monadic.proj ax monadic diff --git a/typing/mode_intf.mli b/typing/mode_intf.mli index a425fb3e4e0..32de65b9cd7 100644 --- a/typing/mode_intf.mli +++ b/typing/mode_intf.mli @@ -414,21 +414,23 @@ module type S = sig val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t end - (** Represents a mode axis in this product whose constant is ['a], and whose - allowance is ['d1] given the product's allowance ['d0]. *) - type ('a, 'd0, 'd1) axis = - | Monadic : (Monadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd neg) axis - | Comonadic : (Comonadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd pos) axis + module Axis : sig + (** Represents a mode axis in this product whose constant is ['a], and whose + allowance is ['d1] given the product's allowance ['d0]. *) + type ('a, 'd0, 'd1) t = + | Monadic : (Monadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd neg) t + | Comonadic : (Comonadic.Const.t, 'a) Axis.t -> ('a, 'd, 'd pos) t + + type packed = P : (_, _, _) t -> packed - type axis_packed = P : (_, _, _) axis -> axis_packed + val print : Format.formatter -> ('a, _, _) t -> unit - val print_axis : Format.formatter -> ('a, _, _) axis -> unit + val all : packed list + end (** Gets the normal lattice for comonadic axes and the "op"ped lattice for monadic ones. *) - val lattice_of_axis : ('a, _, _) axis -> (module Lattice with type t = 'a) - - val all_axes : axis_packed list + val lattice_of_axis : ('a, _, _) Axis.t -> (module Lattice with type t = 'a) type ('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h) modes = { areality : 'a; @@ -476,9 +478,9 @@ module type S = sig val print : Format.formatter -> t -> unit end - val is_max : ('a, _, _) axis -> 'a -> bool + val is_max : ('a, _, _) Axis.t -> 'a -> bool - val is_min : ('a, _, _) axis -> 'a -> bool + val is_min : ('a, _, _) Axis.t -> 'a -> bool val split : t -> (Monadic.Const.t, Comonadic.Const.t) monadic_comonadic @@ -495,10 +497,10 @@ module type S = sig val partial_apply : t -> t (** Prints a constant on any axis. *) - val print_axis : ('a, _, _) axis -> Format.formatter -> 'a -> unit + val print_axis : ('a, _, _) Axis.t -> Format.formatter -> 'a -> unit end - type error = Error : ('a, _, _) axis * 'a Solver.error -> error + type error = Error : ('a, _, _) Axis.t * 'a Solver.error -> error type 'd t = ('d Monadic.t, 'd Comonadic.t) monadic_comonadic @@ -514,23 +516,23 @@ module type S = sig end val proj : - ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> ('l0 * 'r0) t -> ('a, 'l1 * 'r1) mode + ('a, 'l0 * 'r0, 'l1 * 'r1) Axis.t -> ('l0 * 'r0) t -> ('a, 'l1 * 'r1) mode val max_with : - ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> + ('a, 'l0 * 'r0, 'l1 * 'r1) Axis.t -> ('a, 'l1 * 'r1) mode -> (disallowed * 'r0) t val min_with : - ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> + ('a, 'l0 * 'r0, 'l1 * 'r1) Axis.t -> ('a, 'l1 * 'r1) mode -> ('l0 * disallowed) t val meet_with : - ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> 'a -> ('l0 * 'r0) t -> ('l0 * 'r0) t + ('a, 'l0 * 'r0, 'l1 * 'r1) Axis.t -> 'a -> ('l0 * 'r0) t -> ('l0 * 'r0) t val join_with : - ('a, 'l0 * 'r0, 'l1 * 'r1) axis -> 'a -> ('l0 * 'r0) t -> ('l0 * 'r0) t + ('a, 'l0 * 'r0, 'l1 * 'r1) Axis.t -> 'a -> ('l0 * 'r0) t -> ('l0 * 'r0) t val zap_to_legacy : lr -> Const.t @@ -572,7 +574,7 @@ module type S = sig val alloc_as_value : Alloc.Const.t -> Value.Const.t module Axis : sig - val alloc_as_value : Alloc.axis_packed -> Value.axis_packed + val alloc_as_value : Alloc.Axis.packed -> Value.Axis.packed end val locality_as_regionality : Locality.Const.t -> Regionality.Const.t @@ -609,7 +611,7 @@ module type S = sig in which case it's the identity modality. *) (** An atom modality is a [raw] accompanied by the axis it acts on. *) - type t = Atom : ('a, _, _) Value.axis * 'a raw -> t + type t = Atom : ('a, _, _) Value.Axis.t * 'a raw -> t (** Test if the given modality is the identity modality. *) val is_id : t -> bool @@ -624,7 +626,7 @@ module type S = sig type atom := t type error = - | Error : ('a, _, _) Value.axis * 'a raw Solver.error -> error + | Error : ('a, _, _) Value.Axis.t * 'a raw Solver.error -> error type nonrec equate_error = equate_step * error @@ -674,7 +676,7 @@ module type S = sig val of_list : atom list -> t (** Project out the [atom] for the given axis in the given modality. *) - val proj : ('a, _, _) Value.axis -> t -> atom + val proj : ('a, _, _) Value.Axis.t -> t -> atom (** [equate t0 t1] checks that [t0 = t1]. Definition: [t0 = t1] iff [t0 <= t1] and [t1 <= t0]. *) diff --git a/typing/typemode.ml b/typing/typemode.ml index 4010e2ef95a..7f77aa58258 100644 --- a/typing/typemode.ml +++ b/typing/typemode.ml @@ -25,7 +25,7 @@ exception Error of Location.t * error module Axis_pair = struct type 'm t = - | Modal_axis_pair : ('a, _, _) Mode.Alloc.axis * 'a -> modal t + | Modal_axis_pair : ('a, _, _) Mode.Alloc.Axis.t * 'a -> modal t | Any_axis_pair : 'a Axis.t * 'a -> maybe_nonmodal t | Everything_but_nullability : maybe_nonmodal t @@ -265,7 +265,7 @@ let transl_mode_annots annots : Alloc.Const.Option.t = let step modifiers_so_far annot = let { txt = Modal_axis_pair (type a d0 d1) - ((axis, mode) : (a, d0, d1) Mode.Alloc.axis * a); + ((axis, mode) : (a, d0, d1) Mode.Alloc.Axis.t * a); loc } = transl_annot ~annot_type:Mode ~required_mode_maturity:(Some Stable) From 9f251f0462a96fa3d05f0c95ed760722cd51f94d Mon Sep 17 00:00:00 2001 From: Zesen Qian Date: Tue, 29 Apr 2025 14:58:51 +0100 Subject: [PATCH 3/3] more clean up --- typing/includecore.ml | 4 +- typing/jkind_axis.ml | 2 +- typing/mode.ml | 106 ++++++++++++++-------------- typing/mode_intf.mli | 157 ++++++++++++++++++----------------------- typing/printtyp.ml | 3 +- typing/solver_intf.mli | 3 +- typing/untypeast.ml | 6 +- 7 files changed, 131 insertions(+), 150 deletions(-) diff --git a/typing/includecore.ml b/typing/includecore.ml index 31a9a417034..3901a6123aa 100644 --- a/typing/includecore.ml +++ b/typing/includecore.ml @@ -692,9 +692,9 @@ module Record_diffing = struct | Immutable, Mutable _ -> Some Second | Mutable m1, Mutable m2 -> let open Mode.Alloc.Comonadic.Const in - (if not (eq m1 legacy) then + (if not (Misc.Le_result.equal ~le m1 legacy) then Misc.fatal_errorf "Unexpected mutable(%a)" print m1); - (if not (eq m2 legacy) then + (if not (Misc.Le_result.equal ~le m2 legacy) then Misc.fatal_errorf "Unexpected mutable(%a)" print m2); None in diff --git a/typing/jkind_axis.ml b/typing/jkind_axis.ml index 87225df4099..cb6d30ac857 100644 --- a/typing/jkind_axis.ml +++ b/typing/jkind_axis.ml @@ -142,7 +142,7 @@ module Axis = struct let get (type a) : a t -> (module Axis_ops with type t = a) = function | Modal axis -> - (module Accent_lattice ((val Mode.Alloc.lattice_of_axis axis))) + (module Accent_lattice ((val Mode.Alloc.Const.lattice_of_axis axis))) | Nonmodal Externality -> (module Externality) | Nonmodal Nullability -> (module Nullability) diff --git a/typing/mode.ml b/typing/mode.ml index 676f78c15ea..a380b81e4e5 100644 --- a/typing/mode.ml +++ b/typing/mode.ml @@ -1753,6 +1753,14 @@ module Comonadic_gen (Obj : Obj) = struct let of_const : type l r. const -> (l * r) t = fun a -> Solver.of_const obj a + let meet_const c m = Solver.apply obj (Meet_with c) m + + let join_const c m = Solver.apply obj (Join_with c) m + + let imply c m = Solver.apply obj (Imply c) (Solver.disallow_left m) + + let subtract c m = Solver.apply obj (Subtract c) (Solver.disallow_right m) + module Guts = struct let get_floor m = Solver.get_floor obj m @@ -1827,6 +1835,14 @@ module Monadic_gen (Obj : Obj) = struct let of_const : type l r. const -> (l * r) t = fun a -> Solver.of_const obj a + let meet_const c m = Solver.apply Obj.obj (Join_with c) m + + let join_const c m = Solver.apply Obj.obj (Meet_with c) m + + let imply c m = Solver.apply obj (Subtract c) (Solver.disallow_right m) + + let subtract c m = Solver.apply obj (Imply c) (Solver.disallow_left m) + module Guts = struct let _get_floor m = Solver.get_ceil obj m @@ -2057,6 +2073,20 @@ module type Areality = sig val zap_to_legacy : (Const.t, allowed * 'r) Solver.mode -> Const.t end +module BiHeyting_Product (L : BiHeyting) = struct + include L + + type 'a axis = (t, 'a) Axis.t + + let min_with ax c = Axis.update ax c min + + let max_with ax c = Axis.update ax c max + + let min_axis ax = Axis.proj ax min + + let max_axis ax = Axis.proj ax max +end + module Comonadic_with (Areality : Areality) = struct module Obj = struct type const = Areality.Const.t C.comonadic_with @@ -2075,28 +2105,16 @@ module Comonadic_with (Areality : Areality) = struct let proj_obj ax = C.proj_obj ax Obj.obj module Const = struct - include C.Comonadic_with (Areality.Const) + include BiHeyting_Product (C.Comonadic_with (Areality.Const)) - let eq a b = le a b && le b a + let print_axis ax ppf a = + let obj = proj_obj ax in + C.print obj ppf a let le_axis ax a b = let obj = proj_obj ax in C.le obj a b - let min_axis ax = - let obj = proj_obj ax in - C.min obj - - let max_axis ax = - let obj = proj_obj ax in - C.max obj - - let max_with ax c = Axis.update ax c (C.max Obj.obj) - - let print_axis ax ppf a = - let obj = proj_obj ax in - C.print obj ppf a - let lattice_of_axis (type a) (axis : (t, a) Axis.t) : (module Lattice with type t = a) = match axis with @@ -2109,19 +2127,15 @@ module Comonadic_with (Areality : Areality) = struct let proj ax m = Solver.apply (proj_obj ax) (Proj (Obj.obj, ax)) m - let meet_const c m = Solver.apply Obj.obj (Meet_with c) m - - let join_const c m = Solver.apply Obj.obj (Join_with c) m - let min_with ax m = Solver.apply Obj.obj (Min_with ax) (Solver.disallow_right m) let max_with ax m = Solver.apply Obj.obj (Max_with ax) (Solver.disallow_left m) - let join_with ax c m = join_const (C.min_with Obj.obj ax c) m + let join_with ax c m = join_const (Const.min_with ax c) m - let meet_with ax c m = meet_const (C.max_with Obj.obj ax c) m + let meet_with ax c m = meet_const (Const.max_with ax c) m let zap_to_legacy m : Const.t = let areality = proj Areality m |> Areality.zap_to_legacy in @@ -2134,10 +2148,6 @@ module Comonadic_with (Areality : Areality) = struct let yielding = proj Yielding m |> Yielding.zap_to_legacy ~global in { areality; linearity; portability; yielding; statefulness } - let imply c m = Solver.apply Obj.obj (Imply c) (Solver.disallow_left m) - - let subtract c m = Solver.apply Obj.obj (Subtract c) (Solver.disallow_right m) - let legacy = of_const Const.legacy let axis_of_error (err : Obj.const Solver.error) : error = @@ -2219,20 +2229,11 @@ module Monadic = struct let proj_obj ax = C.proj_obj ax Obj.obj module Const = struct - include C.Monadic + include BiHeyting_Product (C.Monadic) - (* CR zqian: The flipping logic leaking to here is bad. Refactoring needed. *) - - (* Monadic fragment is flipped, so are the following definitions. *) - let min_with ax c = Axis.update ax c (C.max Obj.obj) - - let min_axis ax = - let obj = proj_obj ax in - C.max obj - - let max_axis ax = + let print_axis ax ppf a = let obj = proj_obj ax in - C.min obj + C.print obj ppf a let le_axis ax a b = let obj = proj_obj ax in @@ -2252,23 +2253,15 @@ module Monadic = struct (* The monadic fragment is inverted. *) - let meet_const c m = Solver.apply Obj.obj (Join_with c) m - - let join_const c m = Solver.apply Obj.obj (Meet_with c) m - let max_with ax m = Solver.apply Obj.obj (Min_with ax) (Solver.disallow_right m) let min_with ax m = Solver.apply Obj.obj (Max_with ax) (Solver.disallow_left m) - let join_with ax c m = join_const (C.max_with Obj.obj ax c) m - - let meet_with ax c m = meet_const (C.min_with Obj.obj ax c) m - - let imply c m = Solver.apply Obj.obj (Subtract c) (Solver.disallow_right m) + let join_with ax c m = join_const (Const.min_with ax c) m - let subtract c m = Solver.apply Obj.obj (Imply c) (Solver.disallow_left m) + let meet_with ax c m = meet_const (Const.max_with ax c) m let zap_to_legacy m : Const.t = let uniqueness = proj Uniqueness m |> Uniqueness.zap_to_legacy in @@ -2364,12 +2357,6 @@ module Value_with (Areality : Areality) = struct P (Monadic Visibility) ] end - let lattice_of_axis (type a d0 d1) (axis : (a, d0, d1) Axis.t) : - (module Lattice with type t = a) = - match axis with - | Comonadic ax -> Comonadic.Const.lattice_of_axis ax - | Monadic ax -> Monadic.Const.lattice_of_axis ax - let proj_obj : type a d0 d1. (a, d0, d1) Axis.t -> a C.obj = function | Monadic ax -> Monadic.proj_obj ax | Comonadic ax -> Comonadic.proj_obj ax @@ -2481,6 +2468,12 @@ module Value_with (Areality : Areality) = struct let comonadic = Comonadic.join m0.comonadic m1.comonadic in merge { monadic; comonadic } + let lattice_of_axis (type a d0 d1) (axis : (a, d0, d1) Axis.t) : + (module Lattice with type t = a) = + match axis with + | Comonadic ax -> Comonadic.lattice_of_axis ax + | Monadic ax -> Monadic.lattice_of_axis ax + module Option = struct type some = t @@ -2856,6 +2849,11 @@ module Value_with (Areality : Areality) = struct let comonadic = Comonadic.zap_to_ceil comonadic in merge { monadic; comonadic } + let zap_to_floor { comonadic; monadic } = + let monadic = Monadic.zap_to_floor monadic in + let comonadic = Comonadic.zap_to_floor comonadic in + merge { monadic; comonadic } + let zap_to_legacy { comonadic; monadic } = let monadic = Monadic.zap_to_legacy monadic in let comonadic = Comonadic.zap_to_legacy comonadic in diff --git a/typing/mode_intf.mli b/typing/mode_intf.mli index 32de65b9cd7..5913263664f 100644 --- a/typing/mode_intf.mli +++ b/typing/mode_intf.mli @@ -39,6 +39,24 @@ module type Lattice = sig val print : Format.formatter -> t -> unit end +module type Lattice_product = sig + include Lattice + + type 'a axis + + val min_axis : 'a axis -> 'a + + val max_axis : 'a axis -> 'a + + val min_with : 'a axis -> 'a -> t + + val max_with : 'a axis -> 'a -> t + + val le_axis : 'a axis -> 'a -> 'a -> bool + + val print_axis : 'a axis -> Format.formatter -> 'a -> unit +end + type equate_step = | Left_le_right | Right_le_left @@ -93,6 +111,35 @@ module type Common = sig val print : ?verbose:bool -> unit -> Format.formatter -> ('l * 'r) t -> unit val of_const : Const.t -> ('l * 'r) t + + val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t + + val subtract : Const.t -> ('l * 'r) t -> ('l * disallowed) t + + val join_const : Const.t -> ('l * 'r) t -> ('l * 'r) t + + val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t + + val zap_to_ceil : ('l * allowed) t -> Const.t + + val zap_to_floor : (allowed * 'r) t -> Const.t +end + +module type Common_axis = sig + module Const : Lattice + + include + Common with module Const := Const and type error = Const.t Solver.error +end + +module type Common_product = sig + type 'a axis + + module Const : Lattice_product with type 'a axis = 'a axis + + type error = Error : 'a axis * 'a Solver.error -> error + + include Common with type error := error and module Const := Const end module type S = sig @@ -124,22 +171,15 @@ module type S = sig include Lattice with type t := t end - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd pos) mode val global : lr val local : lr - val zap_to_floor : (allowed * 'r) t -> Const.t - - val zap_to_ceil : ('l * allowed) t -> Const.t - module Guts : sig (** This module exposes some functions that allow callers to inspect modes directly, which could be useful for error printing and dev tools (such as @@ -167,12 +207,9 @@ module type S = sig include Lattice with type t := t end - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd pos) mode val global : lr @@ -191,12 +228,9 @@ module type S = sig include Lattice with type t := t end - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd pos) mode val many : lr @@ -213,12 +247,9 @@ module type S = sig include Lattice with type t := t end - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd pos) mode end @@ -233,19 +264,14 @@ module type S = sig module Const_op : Lattice with type t = Const.t - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd neg) mode val aliased : lr val unique : lr - - val zap_to_ceil : ('l * allowed) t -> Const.t end module Contention : sig @@ -260,12 +286,9 @@ module type S = sig module Const_op : Lattice with type t = Const.t - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd neg) mode end @@ -278,12 +301,9 @@ module type S = sig include Lattice with type t := t end - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd pos) mode val yielding : lr @@ -301,12 +321,9 @@ module type S = sig include Lattice with type t := t end - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd pos) mode val stateless : lr @@ -328,12 +345,9 @@ module type S = sig module Const_op : Lattice with type t = Const.t - type error = Const.t Solver.error - include - Common + Common_axis with module Const := Const - and type error := error and type 'd t = (Const.t, 'd neg) mode val immutable : lr @@ -379,40 +393,18 @@ module type S = sig module Areality : Common module Monadic : sig - module Const : sig - include Lattice with type t = monadic - - val max_axis : (t, 'a) Axis.t -> 'a - - val min_axis : (t, 'a) Axis.t -> 'a - end - - module Const_op : Lattice with type t = monadic - - include Common with module Const := Const + include + Common_product + with type Const.t = monadic + and type 'a axis := (monadic, 'a) Axis.t - val join_const : Const.t -> ('l * 'r) t -> ('l * 'r) t + module Const_op : Lattice with type t = Const.t end - module Comonadic : sig - module Const : sig - include Lattice with type t = Areality.Const.t comonadic_with - - val eq : t -> t -> bool - - val print_axis : (t, 'a) Axis.t -> Format.formatter -> 'a -> unit - - val max_axis : (t, 'a) Axis.t -> 'a - - val min_axis : (t, 'a) Axis.t -> 'a - end - - type error = Error : (Const.t, 'a) Axis.t * 'a Solver.error -> error - - include Common with type error := error and module Const := Const - - val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t - end + module Comonadic : + Common_product + with type Const.t = Areality.Const.t comonadic_with + and type 'a axis := (Areality.Const.t comonadic_with, 'a) Axis.t module Axis : sig (** Represents a mode axis in this product whose constant is ['a], and whose @@ -428,10 +420,6 @@ module type S = sig val all : packed list end - (** Gets the normal lattice for comonadic axes and the "op"ped lattice for - monadic ones. *) - val lattice_of_axis : ('a, _, _) Axis.t -> (module Lattice with type t = 'a) - type ('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h) modes = { areality : 'a; linearity : 'b; @@ -457,6 +445,11 @@ module type S = sig Visibility.Const.t ) modes + (** Gets the normal lattice for comonadic axes and the "op"ped lattice for + monadic ones. *) + val lattice_of_axis : + ('a, _, _) Axis.t -> (module Lattice with type t = 'a) + module Option : sig type some = t @@ -536,18 +529,8 @@ module type S = sig val zap_to_legacy : lr -> Const.t - val zap_to_ceil : ('l * allowed) t -> Const.t - val comonadic_to_monadic : ('l * 'r) Comonadic.t -> ('r * 'l) Monadic.t - val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t - - val join_const : Const.t -> ('l * 'r) t -> ('l * 'r) t - - val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t - - val subtract : Const.t -> ('l * 'r) t -> ('l * disallowed) t - (* The following two are about the scenario where we partially apply a function [A -> B -> C] to [A] and get back [B -> C]. The mode of the three are constrained. *) diff --git a/typing/printtyp.ml b/typing/printtyp.ml index 2e2bb6de636..8394addb4cb 100644 --- a/typing/printtyp.ml +++ b/typing/printtyp.ml @@ -1833,7 +1833,8 @@ let tree_of_label l = match l.ld_mutable with | Mutable m -> let mut = - if Alloc.Comonadic.Const.eq m Alloc.Comonadic.Const.legacy then + let open Alloc.Comonadic.Const in + if Misc.Le_result.equal ~le m Alloc.Comonadic.Const.legacy then Om_mutable None else Om_mutable (Some "") diff --git a/typing/solver_intf.mli b/typing/solver_intf.mli index 3ccc0d56014..d8c970222cf 100644 --- a/typing/solver_intf.mli +++ b/typing/solver_intf.mli @@ -292,8 +292,7 @@ module type S = sig module Magic_equal (X : Equal) : Equal with type ('a, 'b, 'c) t = ('a, 'b, 'c) X.t - (** Solver that supports polarized lattices; needed because some morphisms - are antitone *) + (** Solver that supports lattices with monotone morphisms between them. *) module Solver_mono (C : Lattices_mono) : Solver_mono with type ('a, 'b, 'd) morph := ('a, 'b, 'd) C.morph diff --git a/typing/untypeast.ml b/typing/untypeast.ml index cf5c974118d..b5ccf59f5a5 100644 --- a/typing/untypeast.ml +++ b/typing/untypeast.ml @@ -274,11 +274,11 @@ let mutable_ (mut : Types.mutability) : mutable_flag = match mut with | Immutable -> Immutable | Mutable m -> - if Mode.Alloc.Comonadic.Const.eq m Mode.Alloc.Comonadic.Const.legacy then + let open Mode.Alloc.Comonadic.Const in + if Misc.Le_result.equal ~le:le m legacy then Mutable else - Misc.fatal_errorf "unexpected mutable(%a)" - Mode.Alloc.Comonadic.Const.print m + Misc.fatal_errorf "unexpected mutable(%a)" print m let label_declaration sub ld = let loc = sub.location sub ld.ld_loc in