Skip to content

Commit bee315d

Browse files
committed
Add flambda2-join-algorithm flag
1 parent 0446c7c commit bee315d

10 files changed

+188
-68
lines changed

driver/flambda_backend_args.ml

+27
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,15 @@ let mk_flambda2_advanced_meet f =
262262
Printf.sprintf " Use an advanced meet algorithm (deprecated) (Flambda 2 only)"
263263
;;
264264

265+
let mk_flambda2_join_algorithm f =
266+
"-flambda2-join-algorithm", Arg.Symbol (["binary"; "n-way"; "checked"], f),
267+
Printf.sprintf " Select the join algorithm to use (Flambda 2 only)\n \
268+
\ Valid values are: \n\
269+
\ \"binary\" is the legacy binary join;\n\
270+
\ \"n-way\" is the new n-way join;\n\
271+
\ \"checked\" runs both algorithms and compares them (use for \
272+
debugging)."
273+
;;
265274

266275
let mk_flambda2_join_points f =
267276
"-flambda2-join-points", Arg.Unit f,
@@ -760,6 +769,7 @@ module type Flambda_backend_options = sig
760769
val no_flambda2_result_types : unit -> unit
761770
val flambda2_basic_meet : unit -> unit
762771
val flambda2_advanced_meet : unit -> unit
772+
val flambda2_join_algorithm : string -> unit
763773
val flambda2_unbox_along_intra_function_control_flow : unit -> unit
764774
val no_flambda2_unbox_along_intra_function_control_flow : unit -> unit
765775
val flambda2_backend_cse_at_toplevel : unit -> unit
@@ -894,6 +904,7 @@ struct
894904
F.no_flambda2_result_types;
895905
mk_flambda2_basic_meet F.flambda2_basic_meet;
896906
mk_flambda2_advanced_meet F.flambda2_advanced_meet;
907+
mk_flambda2_join_algorithm F.flambda2_join_algorithm;
897908
mk_flambda2_unbox_along_intra_function_control_flow
898909
F.flambda2_unbox_along_intra_function_control_flow;
899910
mk_no_flambda2_unbox_along_intra_function_control_flow
@@ -1097,6 +1108,15 @@ module Flambda_backend_options_impl = struct
10971108
Flambda2.function_result_types := Flambda_backend_flags.Set Flambda_backend_flags.Never
10981109
let flambda2_basic_meet () = ()
10991110
let flambda2_advanced_meet () = ()
1111+
let flambda2_join_algorithm algorithm =
1112+
match algorithm with
1113+
| "binary" ->
1114+
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.Binary
1115+
| "n-way" ->
1116+
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.N_way
1117+
| "checked" ->
1118+
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.Checked
1119+
| _ -> () (* This should not occur as we use Arg.Symbol *)
11001120
let flambda2_unbox_along_intra_function_control_flow =
11011121
set Flambda2.unbox_along_intra_function_control_flow
11021122
let no_flambda2_unbox_along_intra_function_control_flow =
@@ -1424,6 +1444,13 @@ module Extra_params = struct
14241444
| _ ->
14251445
Misc.fatal_error "Syntax: flambda2-meet_algorithm=basic|advanced");
14261446
true
1447+
| "flambda2-join-algorithm" ->
1448+
(match String.lowercase_ascii v with
1449+
| "binary" | "n-way" | "checked" as v ->
1450+
Flambda_backend_options_impl.flambda2_join_algorithm v
1451+
| _ ->
1452+
Misc.fatal_error "Syntax: flambda2-join-algorithm=binary|n-way|checked");
1453+
true
14271454
| "flambda2-unbox-along-intra-function-control-flow" ->
14281455
set Flambda2.unbox_along_intra_function_control_flow
14291456
| "flambda2-backend-cse-at-toplevel" ->

driver/flambda_backend_args.mli

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ module type Flambda_backend_options = sig
8888
val no_flambda2_result_types : unit -> unit
8989
val flambda2_basic_meet : unit -> unit
9090
val flambda2_advanced_meet : unit -> unit
91+
val flambda2_join_algorithm : string -> unit
9192
val flambda2_unbox_along_intra_function_control_flow : unit -> unit
9293
val no_flambda2_unbox_along_intra_function_control_flow : unit -> unit
9394
val flambda2_backend_cse_at_toplevel : unit -> unit

driver/flambda_backend_flags.ml

+5
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ let long_frames_threshold = ref max_long_frames_threshold (* -debug-long-frames-
9494
let caml_apply_inline_fast_path = ref false (* -caml-apply-inline-fast-path *)
9595

9696
type function_result_types = Never | Functors_only | All_functions
97+
type join_algorithm = Binary | N_way | Checked
9798
type opt_level = Oclassic | O2 | O3
9899
type 'a or_default = Set of 'a | Default
99100

@@ -126,6 +127,7 @@ module Flambda2 = struct
126127
let backend_cse_at_toplevel = false
127128
let cse_depth = 2
128129
let join_depth = 5
130+
let join_algorithm = Binary
129131
let function_result_types = Never
130132
let enable_reaper = false
131133
let unicode = true
@@ -138,6 +140,7 @@ module Flambda2 = struct
138140
backend_cse_at_toplevel : bool;
139141
cse_depth : int;
140142
join_depth : int;
143+
join_algorithm : join_algorithm;
141144
function_result_types : function_result_types;
142145
enable_reaper : bool;
143146
unicode : bool;
@@ -150,6 +153,7 @@ module Flambda2 = struct
150153
backend_cse_at_toplevel = Default.backend_cse_at_toplevel;
151154
cse_depth = Default.cse_depth;
152155
join_depth = Default.join_depth;
156+
join_algorithm = Default.join_algorithm;
153157
function_result_types = Default.function_result_types;
154158
enable_reaper = Default.enable_reaper;
155159
unicode = Default.unicode;
@@ -182,6 +186,7 @@ module Flambda2 = struct
182186
let backend_cse_at_toplevel = ref Default
183187
let cse_depth = ref Default
184188
let join_depth = ref Default
189+
let join_algorithm = ref Default
185190
let unicode = ref Default
186191
let function_result_types = ref Default
187192
let enable_reaper = ref Default

driver/flambda_backend_flags.mli

+4
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ val long_frames_threshold : int ref
8181
val caml_apply_inline_fast_path : bool ref
8282

8383
type function_result_types = Never | Functors_only | All_functions
84+
type join_algorithm = Binary | N_way | Checked
8485
type opt_level = Oclassic | O2 | O3
8586
type 'a or_default = Set of 'a | Default
8687

@@ -107,6 +108,7 @@ module Flambda2 : sig
107108
val backend_cse_at_toplevel : bool
108109
val cse_depth : int
109110
val join_depth : int
111+
val join_algorithm : join_algorithm
110112
val function_result_types : function_result_types
111113
val enable_reaper : bool
112114

@@ -123,6 +125,7 @@ module Flambda2 : sig
123125
backend_cse_at_toplevel : bool;
124126
cse_depth : int;
125127
join_depth : int;
128+
join_algorithm : join_algorithm;
126129
function_result_types : function_result_types;
127130
enable_reaper : bool;
128131

@@ -139,6 +142,7 @@ module Flambda2 : sig
139142
val backend_cse_at_toplevel : bool or_default ref
140143
val cse_depth : int or_default ref
141144
val join_depth : int or_default ref
145+
val join_algorithm : join_algorithm or_default ref
142146
val enable_reaper : bool or_default ref
143147

144148
val unicode : bool or_default ref

middle_end/flambda2/types/flambda2_types.ml

+8-9
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,18 @@ module Typing_env = struct
1818
include Typing_env
1919

2020
let add_equation t name ty =
21-
add_equation t name ty ~meet_type:Meet_and_join.meet_type
21+
add_equation t name ty ~meet_type:(Meet.meet_type ())
2222

2323
let add_equations_on_params t ~params ~param_types =
2424
add_equations_on_params t ~params ~param_types
25-
~meet_type:Meet_and_join.meet_type
25+
~meet_type:(Meet.meet_type ())
2626

2727
let add_env_extension t extension =
28-
add_env_extension t extension ~meet_type:Meet_and_join.meet_type
28+
add_env_extension t extension ~meet_type:(Meet.meet_type ())
2929

3030
let add_env_extension_with_extra_variables t extension =
3131
add_env_extension_with_extra_variables t extension
32-
~meet_type:Meet_and_join.meet_type
32+
~meet_type:(Meet.meet_type ())
3333

3434
module Alias_set = Aliases.Alias_set
3535
end
@@ -43,7 +43,7 @@ type typing_env_extension = Typing_env_extension.t
4343
include Type_grammar
4444
include More_type_creators
4545
include Expand_head
46-
include Meet_and_join
46+
include Meet
4747
include Provers
4848
include Reify
4949
include Join_levels
@@ -54,10 +54,9 @@ let remove_outermost_alias env ty =
5454

5555
module Equal_types_for_debug = struct
5656
let equal_type env t1 t2 =
57-
Equal_types_for_debug.equal_type ~meet_type:Meet_and_join.meet_type env t1
58-
t2
57+
Equal_types_for_debug.equal_type ~meet_type:(Meet.meet_type ()) env t1 t2
5958

6059
let equal_env_extension env ext1 ext2 =
61-
Equal_types_for_debug.equal_env_extension ~meet_type:Meet_and_join.meet_type
62-
env ext1 ext2
60+
Equal_types_for_debug.equal_env_extension ~meet_type:(Meet.meet_type ()) env
61+
ext1 ext2
6362
end

middle_end/flambda2/types/join_levels.ml

+61-59
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ let check_join_inputs ~env_at_fork _envs_with_levels ~params
4444
extra_lifted_consts_in_use_envs
4545

4646
let cut_and_n_way_join definition_typing_env ts_and_use_ids ~params ~cut_after
47-
~extra_lifted_consts_in_use_envs ~extra_allowed_names:_ =
47+
~extra_lifted_consts_in_use_envs =
4848
let params = Bound_parameters.to_list params in
4949
check_join_inputs ~env_at_fork:definition_typing_env ts_and_use_ids ~params
5050
~extra_lifted_consts_in_use_envs;
@@ -53,61 +53,63 @@ let cut_and_n_way_join definition_typing_env ts_and_use_ids ~params ~cut_after
5353
~n_way_join_type:Meet_and_n_way_join.n_way_join definition_typing_env
5454
~cut_after ts
5555

56-
let ignore_names =
57-
String.split_on_char ','
58-
(Option.value ~default:""
59-
(Sys.getenv_opt "FLAMBDA2_JOIN_DEBUG_IGNORE_NAMES"))
60-
61-
let cut_and_n_way_join_checked definition_typing_env ts_and_use_ids ~params
62-
~cut_after ~extra_lifted_consts_in_use_envs ~extra_allowed_names =
63-
let scope = TE.current_scope definition_typing_env in
64-
let typing_env = TE.increment_scope definition_typing_env in
65-
let old_joined_env =
66-
Join_levels_old.cut_and_n_way_join typing_env ts_and_use_ids ~params
67-
~cut_after ~extra_lifted_consts_in_use_envs ~extra_allowed_names
68-
in
69-
let old_joined_level = TE.cut old_joined_env ~cut_after:scope in
70-
let new_joined_env =
71-
cut_and_n_way_join typing_env ts_and_use_ids ~params ~cut_after
72-
~extra_lifted_consts_in_use_envs ~extra_allowed_names
73-
in
74-
let new_joined_level = TE.cut new_joined_env ~cut_after:scope in
75-
(let distinct_names =
76-
Equal_types_for_debug.names_with_non_equal_types_level_ignoring_name_mode
77-
~meet_type:Meet_and_join.meet_type typing_env old_joined_level
78-
new_joined_level
79-
in
80-
let distinct_names =
81-
Name.Set.filter
82-
(fun name ->
83-
match Name.must_be_var_opt name with
84-
| Some var ->
85-
let raw_name = Variable.raw_name var in
86-
not (List.exists (String.equal raw_name) ignore_names)
87-
| None -> true)
88-
distinct_names
89-
in
90-
if not (Name.Set.is_empty distinct_names)
91-
then (
92-
Format.eprintf "@[<v 1>%s Distinct joins %s@ " (String.make 22 '=')
93-
(String.make 22 '=');
94-
if Flambda_features.debug_flambda2 ()
95-
then
96-
List.iteri
97-
(fun i (t, _, _) ->
98-
let level = TE.cut t ~cut_after in
99-
Format.eprintf "@[<v 1>-- Level %d --@ %a@]@ " i TEL.print level)
100-
ts_and_use_ids;
101-
Format.eprintf "@[<v 1>-- Old join --@ %a@]@ " TEL.print old_joined_level;
102-
Format.eprintf "@[<v 1>-- New join --@ %a@]@ " TEL.print new_joined_level;
103-
Format.eprintf "@[Names with distinct types:@ %a@]" Name.Set.print
104-
distinct_names;
105-
Format.eprintf "@]@\n%s@." (String.make 60 '=')));
106-
TE.add_env_extension_from_level definition_typing_env new_joined_level
107-
~meet_type:Meet_and_join.meet_type
108-
109-
let cut_and_n_way_join =
110-
match Sys.getenv "FLAMBDA2_JOIN_ALGORITHM" with
111-
| "old" -> Join_levels_old.cut_and_n_way_join
112-
| "checked" -> cut_and_n_way_join_checked
113-
| _ | (exception Not_found) -> cut_and_n_way_join
56+
let cut_and_n_way_join definition_typing_env ts_and_use_ids ~params ~cut_after
57+
~extra_lifted_consts_in_use_envs ~extra_allowed_names =
58+
match Flambda_features.join_algorithm () with
59+
| Binary ->
60+
Join_levels_old.cut_and_n_way_join definition_typing_env ts_and_use_ids
61+
~params ~cut_after ~extra_lifted_consts_in_use_envs ~extra_allowed_names
62+
| N_way ->
63+
cut_and_n_way_join definition_typing_env ts_and_use_ids ~params ~cut_after
64+
~extra_lifted_consts_in_use_envs
65+
| Checked ->
66+
let ignore_names =
67+
String.split_on_char ','
68+
(Option.value ~default:""
69+
(Sys.getenv_opt "FLAMBDA2_JOIN_DEBUG_IGNORE_NAMES"))
70+
in
71+
let scope = TE.current_scope definition_typing_env in
72+
let typing_env = TE.increment_scope definition_typing_env in
73+
let old_joined_env =
74+
Join_levels_old.cut_and_n_way_join typing_env ts_and_use_ids ~params
75+
~cut_after ~extra_lifted_consts_in_use_envs ~extra_allowed_names
76+
in
77+
let old_joined_level = TE.cut old_joined_env ~cut_after:scope in
78+
let new_joined_env =
79+
cut_and_n_way_join typing_env ts_and_use_ids ~params ~cut_after
80+
~extra_lifted_consts_in_use_envs
81+
in
82+
let new_joined_level = TE.cut new_joined_env ~cut_after:scope in
83+
(let distinct_names =
84+
Equal_types_for_debug.names_with_non_equal_types_level_ignoring_name_mode
85+
~meet_type:(Meet.meet_type ()) typing_env old_joined_level
86+
new_joined_level
87+
in
88+
let distinct_names =
89+
Name.Set.filter
90+
(fun name ->
91+
match Name.must_be_var_opt name with
92+
| Some var ->
93+
let raw_name = Variable.raw_name var in
94+
not (List.exists (String.equal raw_name) ignore_names)
95+
| None -> true)
96+
distinct_names
97+
in
98+
if not (Name.Set.is_empty distinct_names)
99+
then (
100+
Format.eprintf "@[<v 1>%s Distinct joins %s@ " (String.make 22 '=')
101+
(String.make 22 '=');
102+
if Flambda_features.debug_flambda2 ()
103+
then
104+
List.iteri
105+
(fun i (t, _, _) ->
106+
let level = TE.cut t ~cut_after in
107+
Format.eprintf "@[<v 1>-- Level %d --@ %a@]@ " i TEL.print level)
108+
ts_and_use_ids;
109+
Format.eprintf "@[<v 1>-- Old join --@ %a@]@ " TEL.print old_joined_level;
110+
Format.eprintf "@[<v 1>-- New join --@ %a@]@ " TEL.print new_joined_level;
111+
Format.eprintf "@[Names with distinct types:@ %a@]" Name.Set.print
112+
distinct_names;
113+
Format.eprintf "@]@\n%s@." (String.make 60 '=')));
114+
TE.add_env_extension_from_level definition_typing_env new_joined_level
115+
~meet_type:(Meet.meet_type ())

middle_end/flambda2/types/meet.ml

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
(**************************************************************************)
2+
(* *)
3+
(* OCaml *)
4+
(* *)
5+
(* Vincent Laviron, OCamlPro *)
6+
(* Basile Clément, OCamlPro *)
7+
(* *)
8+
(* Copyright 2024 OCamlPro SAS *)
9+
(* *)
10+
(* All rights reserved. This file is distributed under the terms of *)
11+
(* the GNU Lesser General Public License version 2.1, with the *)
12+
(* special exception on linking described in the file LICENSE. *)
13+
(* *)
14+
(**************************************************************************)
15+
16+
let meet env t1 t2 =
17+
if Flambda_features.use_n_way_join ()
18+
then Meet_and_n_way_join.meet env t1 t2
19+
else Meet_and_join.meet env t1 t2
20+
21+
let[@inline] meet_type () =
22+
if Flambda_features.use_n_way_join ()
23+
then Meet_and_n_way_join.meet_type
24+
else Meet_and_join.meet_type
25+
26+
let meet_shape env t ~shape =
27+
if Flambda_features.use_n_way_join ()
28+
then Meet_and_n_way_join.meet_shape env t ~shape
29+
else Meet_and_join.meet_shape env t ~shape

middle_end/flambda2/types/meet.mli

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
(**************************************************************************)
2+
(* *)
3+
(* OCaml *)
4+
(* *)
5+
(* Vincent Laviron, OCamlPro *)
6+
(* Basile Clément, OCamlPro *)
7+
(* *)
8+
(* Copyright 2024 OCamlPro SAS *)
9+
(* *)
10+
(* All rights reserved. This file is distributed under the terms of *)
11+
(* the GNU Lesser General Public License version 2.1, with the *)
12+
(* special exception on linking described in the file LICENSE. *)
13+
(* *)
14+
(**************************************************************************)
15+
16+
val meet :
17+
Typing_env.t ->
18+
Type_grammar.t ->
19+
Type_grammar.t ->
20+
(Type_grammar.t * Typing_env.t) Or_bottom.t
21+
22+
val meet_type : unit -> Typing_env.meet_type
23+
24+
val meet_shape :
25+
Typing_env.t ->
26+
Type_grammar.t ->
27+
shape:Type_grammar.t ->
28+
Typing_env.t Or_bottom.t

0 commit comments

Comments
 (0)