Skip to content

flambda2-types: New n-way join algorithm #3538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 11, 2025
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ jobs:
config: --enable-middle-end=flambda2 --enable-frame-pointers --enable-runtime5 --enable-poll-insertion --enable-flambda-invariants
os: ubuntu-latest
build_ocamlparam: ''
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200'
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200,flambda2-join-algorithm=n-way'

- name: flambda2_o3_advanced_meet_frame_pointers_runtime5_debug
config: --enable-middle-end=flambda2 --enable-frame-pointers --enable-runtime5
os: ubuntu-latest
build_ocamlparam: ''
use_runtime: d
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200,cfg-invariants=1,cfg-eliminate-dead-trap-handlers=1'
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200,cfg-invariants=1,cfg-eliminate-dead-trap-handlers=1,flambda2-join-algorithm=n-way'

- name: flambda2_frame_pointers_oclassic_polling
config: --enable-middle-end=flambda2 --enable-frame-pointers --enable-poll-insertion --enable-flambda-invariants
Expand Down
27 changes: 27 additions & 0 deletions driver/flambda_backend_args.ml
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,15 @@ let mk_flambda2_advanced_meet f =
Printf.sprintf " Use an advanced meet algorithm (deprecated) (Flambda 2 only)"
;;

let mk_flambda2_join_algorithm f =
"-flambda2-join-algorithm", Arg.Symbol (["binary"; "n-way"; "checked"], f),
Printf.sprintf " Select the join algorithm to use (Flambda 2 only)\n \
\ Valid values are: \n\
\ \"binary\" is the legacy binary join;\n\
\ \"n-way\" is the new n-way join;\n\
\ \"checked\" runs both algorithms and compares them (use for \
debugging)."
;;

let mk_flambda2_join_points f =
"-flambda2-join-points", Arg.Unit f,
Expand Down Expand Up @@ -777,6 +786,7 @@ module type Flambda_backend_options = sig
val no_flambda2_result_types : unit -> unit
val flambda2_basic_meet : unit -> unit
val flambda2_advanced_meet : unit -> unit
val flambda2_join_algorithm : string -> unit
val flambda2_unbox_along_intra_function_control_flow : unit -> unit
val no_flambda2_unbox_along_intra_function_control_flow : unit -> unit
val flambda2_backend_cse_at_toplevel : unit -> unit
Expand Down Expand Up @@ -916,6 +926,7 @@ struct
F.no_flambda2_result_types;
mk_flambda2_basic_meet F.flambda2_basic_meet;
mk_flambda2_advanced_meet F.flambda2_advanced_meet;
mk_flambda2_join_algorithm F.flambda2_join_algorithm;
mk_flambda2_unbox_along_intra_function_control_flow
F.flambda2_unbox_along_intra_function_control_flow;
mk_no_flambda2_unbox_along_intra_function_control_flow
Expand Down Expand Up @@ -1126,6 +1137,15 @@ module Flambda_backend_options_impl = struct
Flambda2.function_result_types := Flambda_backend_flags.Set Flambda_backend_flags.Never
let flambda2_basic_meet () = ()
let flambda2_advanced_meet () = ()
let flambda2_join_algorithm algorithm =
match algorithm with
| "binary" ->
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.Binary
| "n-way" ->
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.N_way
| "checked" ->
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.Checked
| _ -> () (* This should not occur as we use Arg.Symbol *)
let flambda2_unbox_along_intra_function_control_flow =
set Flambda2.unbox_along_intra_function_control_flow
let no_flambda2_unbox_along_intra_function_control_flow =
Expand Down Expand Up @@ -1456,6 +1476,13 @@ module Extra_params = struct
| _ ->
Misc.fatal_error "Syntax: flambda2-meet_algorithm=basic|advanced");
true
| "flambda2-join-algorithm" ->
(match String.lowercase_ascii v with
| "binary" | "n-way" | "checked" as v ->
Flambda_backend_options_impl.flambda2_join_algorithm v
| _ ->
Misc.fatal_error "Syntax: flambda2-join-algorithm=binary|n-way|checked");
true
| "flambda2-unbox-along-intra-function-control-flow" ->
set Flambda2.unbox_along_intra_function_control_flow
| "flambda2-backend-cse-at-toplevel" ->
Expand Down
1 change: 1 addition & 0 deletions driver/flambda_backend_args.mli
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ module type Flambda_backend_options = sig
val no_flambda2_result_types : unit -> unit
val flambda2_basic_meet : unit -> unit
val flambda2_advanced_meet : unit -> unit
val flambda2_join_algorithm : string -> unit
val flambda2_unbox_along_intra_function_control_flow : unit -> unit
val no_flambda2_unbox_along_intra_function_control_flow : unit -> unit
val flambda2_backend_cse_at_toplevel : unit -> unit
Expand Down
5 changes: 5 additions & 0 deletions driver/flambda_backend_flags.ml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ let long_frames_threshold = ref max_long_frames_threshold (* -debug-long-frames-
let caml_apply_inline_fast_path = ref false (* -caml-apply-inline-fast-path *)

type function_result_types = Never | Functors_only | All_functions
type join_algorithm = Binary | N_way | Checked
type opt_level = Oclassic | O2 | O3
type 'a or_default = Set of 'a | Default

Expand Down Expand Up @@ -128,6 +129,7 @@ module Flambda2 = struct
let backend_cse_at_toplevel = false
let cse_depth = 2
let join_depth = 5
let join_algorithm = Binary
let function_result_types = Never
let enable_reaper = false
let unicode = true
Expand All @@ -141,6 +143,7 @@ module Flambda2 = struct
backend_cse_at_toplevel : bool;
cse_depth : int;
join_depth : int;
join_algorithm : join_algorithm;
function_result_types : function_result_types;
enable_reaper : bool;
unicode : bool;
Expand All @@ -154,6 +157,7 @@ module Flambda2 = struct
backend_cse_at_toplevel = Default.backend_cse_at_toplevel;
cse_depth = Default.cse_depth;
join_depth = Default.join_depth;
join_algorithm = Default.join_algorithm;
function_result_types = Default.function_result_types;
enable_reaper = Default.enable_reaper;
unicode = Default.unicode;
Expand Down Expand Up @@ -187,6 +191,7 @@ module Flambda2 = struct
let backend_cse_at_toplevel = ref Default
let cse_depth = ref Default
let join_depth = ref Default
let join_algorithm = ref Default
let unicode = ref Default
let kind_checks = ref Default
let function_result_types = ref Default
Expand Down
4 changes: 4 additions & 0 deletions driver/flambda_backend_flags.mli
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ val long_frames_threshold : int ref
val caml_apply_inline_fast_path : bool ref

type function_result_types = Never | Functors_only | All_functions
type join_algorithm = Binary | N_way | Checked
type opt_level = Oclassic | O2 | O3
type 'a or_default = Set of 'a | Default

Expand All @@ -109,6 +110,7 @@ module Flambda2 : sig
val backend_cse_at_toplevel : bool
val cse_depth : int
val join_depth : int
val join_algorithm : join_algorithm
val function_result_types : function_result_types
val enable_reaper : bool
val unicode : bool
Expand All @@ -125,6 +127,7 @@ module Flambda2 : sig
backend_cse_at_toplevel : bool;
cse_depth : int;
join_depth : int;
join_algorithm : join_algorithm;
function_result_types : function_result_types;
enable_reaper : bool;
unicode : bool;
Expand All @@ -141,6 +144,7 @@ module Flambda2 : sig
val backend_cse_at_toplevel : bool or_default ref
val cse_depth : int or_default ref
val join_depth : int or_default ref
val join_algorithm : join_algorithm or_default ref
val enable_reaper : bool or_default ref
val unicode : bool or_default ref
val kind_checks : bool or_default ref
Expand Down
153 changes: 152 additions & 1 deletion middle_end/flambda2/tests/meet_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,153 @@ let meet_variants_don't_lose_aliases () =
Format.eprintf "@[<hov 2>meet:@ %a@]@.@[<hov 2>env:@ %a@]@." T.print
tag_meet_ty TE.print tag_meet_env)

let test_join_with_extensions () =
let define ?(kind = K.value) env v =
let v' = Bound_var.create v Name_mode.normal in
TE.add_definition env (Bound_name.create_var v') kind
in
let env = create_env () in
let y = Variable.create "y" in
let x = Variable.create "x" in
let a = Variable.create "a" in
let b = Variable.create "b" in
let env = define env y in
let env = define env x in
let env = define ~kind:K.naked_immediate env a in
let env = define ~kind:K.naked_immediate env b in
let tag_0 = Tag.Scannable.zero in
let tag_1 = Option.get (Tag.Scannable.of_tag (Tag.create_exn 1)) in
let make ty =
T.variant
~const_ctors:(T.bottom K.naked_immediate)
~non_const_ctors:
(Tag.Scannable.Map.of_list
[ tag_0, (K.Block_shape.Scannable Value_only, [ty]);
tag_1, (K.Block_shape.Scannable Value_only, []) ])
Alloc_mode.For_types.heap
in
let env = TE.add_equation env (Name.var y) (make (T.unknown K.value)) in
let scope = TE.current_scope env in
let scoped_env = TE.increment_scope env in
let left_env =
TE.add_equation scoped_env (Name.var x)
(T.tagged_immediate_alias_to ~naked_immediate:a)
in
let right_env =
TE.add_equation scoped_env (Name.var x)
(T.tagged_immediate_alias_to ~naked_immediate:b)
in
let ty_a = make (T.tagged_immediate_alias_to ~naked_immediate:a) in
let ty_b = make (T.tagged_immediate_alias_to ~naked_immediate:b) in
let left_env = TE.add_equation left_env (Name.var y) ty_a in
let right_env =
match T.meet right_env ty_a ty_b with
| Ok (ty, right_env) -> TE.add_equation right_env (Name.var y) ty
| Bottom -> assert false
in
Format.eprintf "Left:@.%a@." TE.print left_env;
Format.eprintf "Right:@.%a@." TE.print right_env;
let joined_env =
T.cut_and_n_way_join scoped_env
[ left_env, Apply_cont_rewrite_id.create (), Inlinable;
right_env, Apply_cont_rewrite_id.create (), Inlinable ]
~params:Bound_parameters.empty ~cut_after:scope
~extra_allowed_names:Name_occurrences.empty
~extra_lifted_consts_in_use_envs:Symbol.Set.empty
in
Format.eprintf "Res:@.%a@." TE.print joined_env

let test_join_with_complex_extensions () =
let define ?(kind = K.value) env v =
let v' = Bound_var.create v Name_mode.normal in
TE.add_definition env (Bound_name.create_var v') kind
in
let env = create_env () in
let y = Variable.create "y" in
let x = Variable.create "x" in
let w = Variable.create "w" in
let z = Variable.create "z" in
let a = Variable.create "a" in
let b = Variable.create "b" in
let c = Variable.create "c" in
let d = Variable.create "d" in
let env = define env z in
let env = define env x in
let env = define env y in
let env = define env w in
let env = define ~kind:K.naked_immediate env a in
let env = define ~kind:K.naked_immediate env b in
let env = define ~kind:K.naked_immediate env c in
let env = define ~kind:K.naked_immediate env d in
let tag_0 = Tag.Scannable.zero in
let tag_1 = Option.get (Tag.Scannable.of_tag (Tag.create_exn 1)) in
let make tys =
T.variant
~const_ctors:(T.bottom K.naked_immediate)
~non_const_ctors:
(Tag.Scannable.Map.of_list
[ tag_0, (K.Block_shape.Scannable Value_only, tys);
tag_1, (K.Block_shape.Scannable Value_only, []) ])
Alloc_mode.For_types.heap
in
let env =
TE.add_equation env (Name.var z)
(make [T.unknown K.value; T.unknown K.value])
in
let scope = TE.current_scope env in
let scoped_env = TE.increment_scope env in
let left_env =
TE.add_equation scoped_env (Name.var x)
(T.tagged_immediate_alias_to ~naked_immediate:a)
in
let left_env =
TE.add_equation left_env (Name.var y)
(T.tagged_immediate_alias_to ~naked_immediate:a)
in
let left_env =
TE.add_equation left_env (Name.var w)
(T.tagged_immediate_alias_to ~naked_immediate:a)
in
let right_env =
TE.add_equation scoped_env (Name.var x)
(T.tagged_immediate_alias_to ~naked_immediate:b)
in
let right_env =
TE.add_equation right_env (Name.var y)
(T.tagged_immediate_alias_to ~naked_immediate:c)
in
let right_env =
TE.add_equation right_env (Name.var w)
(T.tagged_immediate_alias_to ~naked_immediate:d)
in
let ty_a =
make
[ T.tagged_immediate_alias_to ~naked_immediate:b;
T.tagged_immediate_alias_to ~naked_immediate:b ]
in
let ty_b =
make
[ T.tagged_immediate_alias_to ~naked_immediate:c;
T.tagged_immediate_alias_to ~naked_immediate:d ]
in
let left_env = TE.add_equation left_env (Name.var z) ty_a in
let right_env =
match T.meet right_env ty_a ty_b with
| Ok (ty, right_env) -> TE.add_equation right_env (Name.var z) ty
| Bottom -> assert false
in
Format.eprintf "Left:@.%a@." TE.print left_env;
Format.eprintf "Right:@.%a@." TE.print right_env;
let joined_env =
T.cut_and_n_way_join scoped_env
[ left_env, Apply_cont_rewrite_id.create (), Inlinable;
right_env, Apply_cont_rewrite_id.create (), Inlinable ]
~params:Bound_parameters.empty ~cut_after:scope
~extra_allowed_names:Name_occurrences.empty
~extra_lifted_consts_in_use_envs:Symbol.Set.empty
in
Format.eprintf "Res:@.%a@." TE.print joined_env

let test_meet_two_blocks () =
let define env v =
let v' = Bound_var.create v Name_mode.normal in
Expand Down Expand Up @@ -272,4 +419,8 @@ let () =
Format.eprintf "@.MEET ALIAS TO RECOVER @\n@.";
test_meet_recover_alias ();
Format.eprintf "@.MEET BOTTOM AFTER ALIAS@\n@.";
test_meet_bottom_after_alias ()
test_meet_bottom_after_alias ();
Format.eprintf "@.JOIN WITH EXTENSIONS@\n@.";
test_join_with_extensions ();
Format.eprintf "@.JOIN WITH COMPLEX EXTENSIONS@\n@.";
test_join_with_complex_extensions ()
Loading