From 4226b2a7e83efd7cb07ccf620c1e653dff2e2215 Mon Sep 17 00:00:00 2001
From: Luke Maurer <luke.v.maurer@gmail.com>
Date: Wed, 31 Mar 2021 14:29:15 +0100
Subject: [PATCH] Replace `aliases` in `Aliases_of_canonical_element.t` with 3
 fields

`Aliases_of_canonical_element.t` carries around a map from `Name_mode.t`
to `Simple.Set.t` so that it can compute the earliest element in each
name mode. This is less efficient than simply having a field for each
name mode. Furthermore, we don't need the sets if we just track the
earliest for each name mode (along with its binding time).
---
 middle_end/flambda/types/env/aliases.ml | 388 +++++++++++++++---------
 1 file changed, 252 insertions(+), 136 deletions(-)

diff --git a/middle_end/flambda/types/env/aliases.ml b/middle_end/flambda/types/env/aliases.ml
index c0621a330afc..a11c5c75a6a2 100644
--- a/middle_end/flambda/types/env/aliases.ml
+++ b/middle_end/flambda/types/env/aliases.ml
@@ -19,163 +19,288 @@ module Aliases_of_canonical_element : sig
 
   val print : Format.formatter -> t -> unit
 
-  val invariant : t -> unit
+  val invariant
+     : t
+    -> binding_times_and_modes:(Binding_time.With_name_mode.t Simple.Map.t)
+    -> unit
 
   val empty : t
-  val is_empty : t -> bool
 
-  val add : t -> Simple.t -> Name_mode.t -> t
+  val add : t -> Simple.t -> Binding_time.With_name_mode.t -> t
 
-  val find_earliest_candidates
+  val earliest_alias_exn
      : t
-    -> min_name_mode:Name_mode.t
-    -> Simple.Set.t option
+    -> min_name_mode:Name_mode.t option
+    -> Simple.t
 
   val all : t -> Simple.Set.t
 
   val mem : t -> Simple.t -> bool
 
   val union : t -> t -> t
-  val inter : t -> t -> t
 
-  val rename : (Simple.t -> Simple.t) -> t -> t
+  val disjoint : t -> t -> bool
 
-  val merge : t -> t -> t
+  val rename : (Simple.t -> Simple.t) -> t -> t
 
   val move_variables_to_mode_in_types : t -> t
 end = struct
+  module Earliest_alias : sig
+    type t =
+      | Earliest of {
+          (* Calling this "name" because it's factually always a name, and it's
+             about to become a Name.t anyway *)
+          name : Simple.t;
+          binding_time : Binding_time.t;
+        }
+      | No_alias
+
+    val exists : t -> bool
+    val update : t -> Simple.t -> Binding_time.t -> t
+    val union : t -> t -> t
+    val map_name : t -> f:(Simple.t -> Simple.t) -> t
+    val print : Format.formatter -> t -> unit
+  end = struct
+    type t =
+      | Earliest of {
+          name : Simple.t;
+          binding_time : Binding_time.t;
+        }
+      | No_alias
+
+    let exists = function
+      | Earliest _ -> true
+      | No_alias -> false
+
+    let update t new_name binding_time =
+      match t with
+      | No_alias ->
+        Earliest { name = new_name; binding_time }
+      | Earliest { binding_time = old_binding_time; _ } ->
+        if Binding_time.strictly_earlier binding_time ~than:old_binding_time
+        then Earliest { name = new_name; binding_time }
+        else t
+
+    let union t1 t2 =
+      match t2 with
+      | No_alias -> t1
+      | Earliest { name; binding_time } ->
+        update t1 name binding_time
+
+    let map_name t ~f =
+      match t with
+      | Earliest e ->
+        let name = f e.name in
+        Earliest { e with name }
+      | No_alias -> No_alias
+
+    let print ppf = function
+      | Earliest { name; binding_time } ->
+        Format.fprintf ppf
+          "@[<hov 1>(%a@ \
+           @[<hov 1>(binding_time@ %a)@])@]"
+          Simple.print name
+          Binding_time.print binding_time
+      | No_alias ->
+        Format.pp_print_string ppf "<none>"
+  end
   type t = {
-    aliases : Simple.Set.t Name_mode.Map.t;
     all : Simple.Set.t;
+    earliest : Earliest_alias.t;
+    earliest_normal : Earliest_alias.t;
+    (* Earliest alias whose name mode >= phantom (that is, normal or phantom) *)
+    earliest_ge_phantom : Earliest_alias.t;
+    (* Earliest alias whose name mode >= in-types *)
+    earliest_ge_in_types : Earliest_alias.t;
   }
 
-  let invariant _t = ()
-
-  let print ppf { aliases; all = _; } =
-    Name_mode.Map.print Simple.Set.print ppf aliases
-
   let empty = {
-    aliases = Name_mode.Map.empty;
     all = Simple.Set.empty;
+    earliest = No_alias;
+    earliest_normal = No_alias;
+    earliest_ge_phantom = No_alias;
+    earliest_ge_in_types = No_alias;
   }
 
-  let is_empty t = Simple.Set.is_empty t.all
-
-  let add t elt name_mode =
-    if Simple.Set.mem elt t.all then begin
+  let print ppf
+        { earliest; earliest_normal; earliest_ge_phantom; earliest_ge_in_types;
+          all } =
+    let pp_earliest field_name ppf (earliest : Earliest_alias.t) =
+      Format.fprintf ppf "@[<hov 1>@<0>%s(%s@ %a)@<0>%s@]"
+        (if Earliest_alias.exists earliest 
+         then Flambda_colours.normal ()
+         else Flambda_colours.elide ())
+        field_name
+        Earliest_alias.print earliest
+        (Flambda_colours.normal ())
+    in  
+    Format.fprintf ppf
+      "@[<hov 1>(\
+         %a@ %a@ %a@ %a@ \
+         @[<hov 1>(all@ %a)@])\
+         @]"
+      (pp_earliest "earliest") earliest
+      (pp_earliest "earliest_normal") earliest_normal
+      (pp_earliest "earliest_ge_phantom") earliest_ge_phantom
+      (pp_earliest "earliest_ge_in_types") earliest_ge_in_types
+      Simple.Set.print all
+      
+  let add t new_name binding_time_and_name_mode =
+    if Simple.Set.mem new_name t.all then begin
       Misc.fatal_errorf "%a already added to [Aliases_of_canonical_element]: \
-          %a"
-        Simple.print elt
+                         %a"
+        Simple.print new_name
         print t
     end;
-    let aliases =
-      Name_mode.Map.update name_mode
-        (function
-          | None -> Some (Simple.Set.singleton elt)
-          | Some elts ->
-            if !Clflags.flambda_invariant_checks then begin
-              assert (not (Simple.Set.mem elt elts))
-            end;
-            Some (Simple.Set.add elt elts))
-        t.aliases
+    let binding_time, name_mode =
+      Binding_time.With_name_mode.(
+        binding_time binding_time_and_name_mode,
+        name_mode binding_time_and_name_mode)
     in
-    let all = Simple.Set.add elt t.all in
-    { aliases;
-      all;
-    }
+    let update earliest =
+      Earliest_alias.update earliest new_name binding_time
+    in
+    let update_if_mode_ge mode earliest =
+      match Name_mode.compare_partial_order name_mode mode with
+      | Some c when c >= 0 -> update earliest
+      | _ -> earliest
+    in
+    let earliest = update t.earliest in
+    let earliest_normal =
+      update_if_mode_ge Name_mode.normal t.earliest_normal
+    in
+    let earliest_ge_phantom =
+      update_if_mode_ge Name_mode.phantom t.earliest_ge_phantom
+    in
+    let earliest_ge_in_types =
+      update_if_mode_ge Name_mode.in_types t.earliest_ge_in_types
+    in
+    let all = Simple.Set.add new_name t.all in
+    { earliest; earliest_normal; earliest_ge_phantom; earliest_ge_in_types;
+      all }
+
+  let find_earliest t ~(min_name_mode : Name_mode.t option) =
+    match min_name_mode with
+    | None -> t.earliest
+    | Some min_name_mode ->
+      begin match Name_mode.descr min_name_mode with
+      | Normal -> t.earliest_normal
+      | Phantom -> t.earliest_ge_phantom
+      | In_types -> t.earliest_ge_in_types
+      end
 
-  let find_earliest_candidates t ~min_name_mode =
-    Name_mode.Map.fold (fun order aliases res_opt ->
-        match res_opt with
-        | Some _ -> res_opt
-        | None ->
-          begin match
-            Name_mode.compare_partial_order
-              order min_name_mode
-          with
-          | None -> None
-          | Some result ->
-            if result >= 0 then Some aliases else None
-          end)
-      t.aliases
-      None
+  let invariant t ~binding_times_and_modes =
+    let describe_field name_mode =
+      match name_mode with
+      | None -> "overall"
+      | Some name_mode ->
+        begin match Name_mode.descr name_mode with
+          | Normal -> "normal"
+          | Phantom -> "phantom (or normal)"
+          | In_types -> "in-types (or normal)"
+        end
+    in
+    let check name binding_time name_mode (earliest : Earliest_alias.t) =
+      match earliest with
+      | No_alias -> ()
+      | Earliest e as earliest_as_recorded ->
+        if Binding_time.compare binding_time e.binding_time < 0 then
+          Misc.fatal_errorf
+            "@[<hov 1>Earliest %s alias %a@ has binding time %a,@ \
+             earlier than %a@ in %a\
+             @]"
+           (describe_field name_mode)
+           Simple.print name
+           Binding_time.print binding_time
+           Earliest_alias.print earliest_as_recorded
+           print t
+    in
+    Simple.Set.iter (fun name ->
+      let binding_time_and_mode =
+        Simple.Map.find name binding_times_and_modes
+      in
+      let binding_time, name_mode =
+        Binding_time.With_name_mode.(
+          binding_time binding_time_and_mode,
+          name_mode binding_time_and_mode)
+      in  
+      check name binding_time None t.earliest;
+      let earliest_in_mode = find_earliest t ~min_name_mode:(Some name_mode) in
+      check name binding_time (Some name_mode) earliest_in_mode
+    ) t.all;
+    let check_earliest min_name_mode =
+      match find_earliest t ~min_name_mode with
+      | No_alias -> ()
+      | Earliest e as earliest ->
+        if not (Simple.Set.mem e.name t.all) then begin
+          Misc.fatal_errorf
+            "@[<v>Aliases_of_canonical_element: Earliest %s not in map@ \
+             @[<hov 1>Alias: %a@ Map: %a@]@]"
+            (describe_field min_name_mode)
+            Earliest_alias.print earliest
+            Simple.Set.print t.all
+        end
+    in
+    List.iter check_earliest
+      [ None; 
+        Some Name_mode.normal;
+        Some Name_mode.phantom;
+        Some Name_mode.in_types ]
 
   let mem t elt =
     Simple.Set.mem elt t.all
 
   let all t = t.all
 
+  let earliest_alias_exn t ~min_name_mode =
+    match find_earliest t ~min_name_mode with
+    | Earliest { name; binding_time = _ } -> name
+    | No_alias -> raise Not_found
+
   let union t1 t2 =
-    let aliases =
-      Name_mode.Map.union (fun _order elts1 elts2 ->
-          Some (Simple.Set.union elts1 elts2))
-        t1.aliases t2.aliases
-    in
-    let t =
-      { aliases;
-        all = Simple.Set.union t1.all t2.all;
-      }
+    let all = Simple.Set.union t1.all t2.all in
+    let earliest =
+      Earliest_alias.union t1.earliest t2.earliest
     in
-    invariant t;
-    t
-
-  let inter t1 t2 =
-    let aliases =
-      Name_mode.Map.merge (fun _order elts1 elts2 ->
-          match elts1, elts2 with
-          | None, None | Some _, None | None, Some _ -> None
-          | Some elts1, Some elts2 -> Some (Simple.Set.inter elts1 elts2))
-        t1.aliases t2.aliases
+    let earliest_normal =
+      Earliest_alias.union t1.earliest_normal t2.earliest_normal
     in
-    let t =
-      { aliases;
-        all = Simple.Set.inter t1.all t2.all;
-      }
+    let earliest_ge_phantom =
+      Earliest_alias.union t1.earliest_ge_phantom t2.earliest_ge_phantom
     in
-    invariant t;
-    t
-
-  let rename rename_simple { aliases; all } =
-    let aliases =
-      Name_mode.Map.map (fun elts -> Simple.Set.map rename_simple elts)
-        aliases
+    let earliest_ge_in_types =
+      Earliest_alias.union t1.earliest_ge_in_types t2.earliest_ge_in_types
     in
+    { all; earliest; earliest_normal; earliest_ge_phantom;
+      earliest_ge_in_types; }
+
+  let disjoint t1 t2 =
+    Simple.Set.intersection_is_empty t1.all t2.all
+
+  let update_all_earliest
+        { all; earliest; earliest_normal;
+          earliest_ge_phantom; earliest_ge_in_types; } ~f =
+    let earliest = f earliest in
+    let earliest_normal = f earliest_normal in
+    let earliest_ge_phantom = f earliest_ge_phantom in
+    let earliest_ge_in_types = f earliest_ge_in_types in
+    { all; earliest; earliest_normal; earliest_ge_phantom;
+      earliest_ge_in_types; }
+
+  let rename rename_simple
+        ({ all; earliest = _; earliest_normal = _;
+           earliest_ge_phantom = _; earliest_ge_in_types = _; } as t) =
     let all = Simple.Set.map rename_simple all in
-    { aliases; all }
-
-  let merge t1 t2 =
-    let aliases =
-      Name_mode.Map.union (fun _mode set1 set2 ->
-          Some (Simple.Set.union set1 set2))
-        t1.aliases
-        t2.aliases
-    in
-    let all = Simple.Set.union t1.all t2.all in
-    { aliases; all; }
-
-  let move_variables_to_mode_in_types { aliases; all; } =
-    let (no_vars_aliases, all_variables) =
-      Name_mode.Map.fold (fun mode aliases (no_vars_aliases, all_variables) ->
-          let (vars, non_vars) = Simple.Set.partition Simple.is_var aliases in
-          let no_vars_aliases =
-            if Simple.Set.is_empty non_vars then no_vars_aliases
-            else Name_mode.Map.add mode non_vars no_vars_aliases
-          in
-          no_vars_aliases, Simple.Set.union vars all_variables)
-        aliases
-        (Name_mode.Map.empty, Simple.Set.empty)
-    in
-    let aliases =
-      if Name_mode.Map.mem Name_mode.in_types no_vars_aliases
-      then Misc.fatal_errorf "move_variables_to_mode_in_types: \
-             The following non-vars have mode In_types:@ %a"
-             Simple.Set.print
-             (Name_mode.Map.find Name_mode.in_types no_vars_aliases)
-      else
-        if Simple.Set.is_empty all_variables then no_vars_aliases
-        else Name_mode.Map.add Name_mode.in_types all_variables no_vars_aliases
+    let t =
+      update_all_earliest t ~f:(Earliest_alias.map_name ~f:rename_simple)
     in
-    { aliases; all; }
+    { t with all; }
+
+  let move_variables_to_mode_in_types t =
+    update_all_earliest t ~f:(fun earliest ->
+      match earliest with
+      | Earliest { name; _ } when Simple.is_var name -> No_alias
+      | _ -> earliest)
 end
 
 type t = {
@@ -218,11 +343,16 @@ let name_mode t elt =
   Binding_time.With_name_mode.name_mode
     (Simple.Map.find elt t.binding_times_and_modes)
 
+let binding_time_and_name_mode t elt =
+  Simple.Map.find elt t.binding_times_and_modes
+
+
 let invariant t =
   if !Clflags.flambda_invariant_checks then begin
     let _all_aliases : Simple.Set.t =
       Simple.Map.fold (fun canonical_element aliases all_aliases ->
-          Aliases_of_canonical_element.invariant aliases;
+          Aliases_of_canonical_element.invariant aliases
+            ~binding_times_and_modes:t.binding_times_and_modes;
           let aliases = Aliases_of_canonical_element.all aliases in
           if not (Simple.Set.for_all (fun elt ->
             defined_earlier t canonical_element ~than:elt) aliases)
@@ -298,15 +428,14 @@ let add_alias_between_canonical_elements t ~canonical_element ~to_be_demoted =
     if !Clflags.flambda_invariant_checks then begin
       assert (not (Aliases_of_canonical_element.mem
         aliases_of_canonical_element to_be_demoted));
-      assert (Aliases_of_canonical_element.is_empty (
-        Aliases_of_canonical_element.inter
-          aliases_of_canonical_element aliases_of_to_be_demoted))
+      assert (Aliases_of_canonical_element.disjoint
+        aliases_of_canonical_element aliases_of_to_be_demoted)
     end;
     let aliases =
       Aliases_of_canonical_element.add
         (Aliases_of_canonical_element.union aliases_of_to_be_demoted
           aliases_of_canonical_element)
-        to_be_demoted (name_mode t to_be_demoted)
+        to_be_demoted (binding_time_and_name_mode t to_be_demoted)
     in
     let aliases_of_canonical_elements =
       t.aliases_of_canonical_elements
@@ -481,21 +610,8 @@ Format.eprintf "looking for canonical for %a, candidate canonical %a, min order
 *)
     let find_earliest () =
       let aliases = get_aliases_of_canonical_element t ~canonical_element in
-      match
-        Aliases_of_canonical_element.find_earliest_candidates aliases
-          ~min_name_mode
-      with
-      | Some at_earliest_mode ->
-        (* Aliases_of_canonical_element.find_earliest_candidates only returns
-           non-empty sets *)
-        assert (not (Simple.Set.is_empty at_earliest_mode));
-        Simple.Set.fold (fun elt min_elt ->
-            if defined_earlier t elt ~than:min_elt
-            then elt
-            else min_elt)
-          at_earliest_mode
-          (Simple.Set.min_elt at_earliest_mode)
-      | None -> raise Not_found
+        Aliases_of_canonical_element.earliest_alias_exn aliases
+          ~min_name_mode:(Some min_name_mode)
     in
     match
       Name_mode.compare_partial_order
@@ -577,7 +693,7 @@ let merge t1 t2 =
     (* Warning: here the keys of the map can come from other
        compilation units, so we cannot assume the keys are disjoint *)
     Simple.Map.union (fun _simple aliases1 aliases2 ->
-        Some (Aliases_of_canonical_element.merge aliases1 aliases2))
+        Some (Aliases_of_canonical_element.union aliases1 aliases2))
       t1.aliases_of_canonical_elements
       t2.aliases_of_canonical_elements
   in