Skip to content

CSE of involutions #410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: flambda2.0-stable
Choose a base branch
from
8 changes: 8 additions & 0 deletions doc/missed_optimisations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# List of missed optimisations to look at later

This should be compiled to the identity. But the CSE environment seems to lose
the relationship in the join after the first switch.

let f x =
let not_x = if x then false else true in
if not_x then false else true
18 changes: 18 additions & 0 deletions flambdatest/mlexamples/involutions.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

let f x =
not (not (not x))

let g x =
~- (~- (~- x))

let h x =
Int64.(neg (neg (neg x)))

let i x =
Int32.(neg (neg (neg x)))

let j x =
Nativeint.(neg (neg (neg x)))

let k x =
~-. (~-. (~-. x))
14 changes: 9 additions & 5 deletions middle_end/flambda/simplify/simplify_primitive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ let apply_cse dacc ~original_prim =
match DE.find_cse (DA.denv dacc) with_fixed_value with
| None -> None
| Some simple ->
let canonical =
match
TE.get_canonical_simple_exn (DA.typing_env dacc) simple
~min_name_mode:NM.normal
~name_mode_of_existing_simple:NM.normal
in
match canonical with
| exception Not_found -> None
with
| exception Not_found ->
(* CR pchambart: this exception was not caught for some time, it is
expected never to happen, hence the fatal_error *)
Misc.fatal_errorf "No canonical simple for the CSE candidate: %a"
Simple.print simple
| simple -> Some simple

let try_cse dacc ~original_prim ~simplified_args_with_tys ~min_name_mode
Expand All @@ -47,12 +50,13 @@ let try_cse dacc ~original_prim ~simplified_args_with_tys ~min_name_mode
if not (Name_mode.equal min_name_mode Name_mode.normal) then Not_applied dacc
else
let result_var = VB.var result_var in
let args = List.map fst simplified_args_with_tys in
let original_prim = P.update_args original_prim args in
match apply_cse dacc ~original_prim with
| Some replace_with ->
let named = Named.create_simple replace_with in
let ty = T.alias_type_of (P.result_kind' original_prim) replace_with in
let env_extension = TEE.one_equation (Name.var result_var) ty in
let args = List.map fst simplified_args_with_tys in
let simplified_named =
let cost_metrics =
Cost_metrics.notify_removed
Expand Down
10 changes: 10 additions & 0 deletions middle_end/flambda/simplify/simplify_unary_primitive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -507,4 +507,14 @@ let simplify_unary_primitive dacc (prim : P.unary_primitive)
let reachable, env_extension, dacc =
simplifier dacc ~original_term ~arg ~arg_ty ~result_var
in
let dacc =
if P.is_an_involution prim then
DA.map_denv dacc ~f:(fun denv ->
let prim : P.t =
Unary (prim, Simple.var (Var_in_binding_pos.var result_var))
in
DE.add_cse denv (P.Eligible_for_cse.create_exn prim) ~bound_to:arg)
else
dacc
in
reachable, env_extension, [arg], dacc
25 changes: 25 additions & 0 deletions middle_end/flambda/terms/flambda_primitive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,18 @@ let unary_classify_for_printing p =
| Select_closure _
| Project_var _ -> Destructive

let is_an_involution (p : unary_primitive) =
match p with
| Boolean_not -> true
| Int_arith (_, Neg) -> true
| Float_arith Neg -> true
(* Checked through using the following SMT formula with Z3
{[ (set-logic QF_FP)
(declare-fun x () (_ FloatingPoint 11 53))
(assert (not (= x (fp.neg (fp.neg x)))))
(check-sat) ]} *)
| _ -> false

type binary_int_arith_op =
| Add | Sub | Mul | Div | Mod | And | Or | Xor

Expand Down Expand Up @@ -1626,6 +1638,19 @@ let args t =
| Ternary (_, x0, x1, x2) -> [x0; x1; x2]
| Variadic (_, xs) -> xs

let update_args t args =
match t, args with
| Nullary _, [] -> t
| Unary (p, _), [x0] -> Unary (p, x0)
| Binary (p, _, _), [x0; x1] -> Binary (p, x0, x1)
| Ternary (p, _, _, _), [x0; x1; x2] -> Ternary (p, x0, x1, x2)
| Variadic (p, l), xs ->
assert(List.length l = List.length xs);
Variadic (p, xs)
| _, _ ->
Misc.fatal_errorf "Wrong arity for updating primitive %a with arguments %a"
print t Simple.List.print args

let result_kind (t : t) =
match t with
| Nullary prim -> result_kind_of_nullary_primitive prim
Expand Down
6 changes: 6 additions & 0 deletions middle_end/flambda/terms/flambda_primitive.mli
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ include Contains_ids.S with type t := t

val args : t -> Simple.t list

val update_args : t -> Simple.t list -> t

(** Simpler version (e.g. for [Inlining_cost]), where only the actual
primitive matters, not the arguments. *)
module Without_args : sig
Expand Down Expand Up @@ -404,6 +406,10 @@ val at_most_generative_effects : t -> bool
and no other effects. *)
val only_generative_effects : t -> bool

(** Returns [true] iff the primitive is an involution. i.e.
x = Prim(y) iff y = Prim(x) *)
val is_an_involution : unary_primitive -> bool

module Eligible_for_cse : sig
(** Primitive applications that may be replaced by a variable which is let
bound to a single instance of such application. Primitives that are
Expand Down
45 changes: 45 additions & 0 deletions testsuite/tests/flambda2/involution.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
(* TEST
* flambda
** native
*)

(* This tests that involutions are effectively detected.
It checks that for an involution f, f (f x) == x
This is verified by testing that the result of the physical equality
is a known constant, through an allocation test. *)

external minor_words : unit -> (float [@unboxed])
= "caml_gc_minor_words" "caml_gc_minor_words_unboxed"

let[@inline never] check_no_alloc test_name f x =
let before = minor_words () in
let _ = Sys.opaque_identity ((f[@inlined never]) x) in
let after = minor_words () in
let diff = after -. before in
if diff = 0. then
Format.printf "No allocs for test '%s'@." test_name
else
Format.printf "Some allocs for test '%s'@." test_name

let () =
let[@inline] test_known f =
let tester x =
let x = Sys.opaque_identity x in
let n = if (f[@inlined hint]) ((f[@inlined hint]) x) == x then 1 else 0 in
Int64.of_int n
in
let () = Sys.opaque_identity () in
tester
in

(* Test the test: this should allocate *)
check_no_alloc "not an involution" (test_known (succ)) 1;
(* Actual tests *)
check_no_alloc "not" (test_known (not)) true;
check_no_alloc "~-." (test_known (~-.)) 42.;
check_no_alloc "~-" (test_known (~-)) 42;
check_no_alloc "Int32.neg" (test_known Int32.neg) 42l;
check_no_alloc "Int64.neg" (test_known Int64.neg) 42L;
check_no_alloc "Nativeint.neg" (test_known Nativeint.neg) 42n;
()