Skip to content

Commit b7c8ad3

Browse files
authored
Vectorizer refactor heuristic for select_and_join (#3449)
* Refactor [Block.find_last_instruction], cache [Computation.last_pos] * Improve heuristics in [Computation.select_and_join] using [last_pos]
1 parent 22f81d8 commit b7c8ad3

File tree

1 file changed

+81
-49
lines changed

1 file changed

+81
-49
lines changed

backend/cfg/vectorize.ml

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,11 @@ module Block : sig
355355

356356
val find : t -> Instruction.Id.t -> Instruction.t
357357

358-
(** [find_last_instruction t instrs] returns instruction [i]
359-
from [instrs] such that [i] appears after
360-
all other instructions from [instrs] according to the order of instructions
361-
in this basic block. Raises if [instrs] is empty. *)
362-
val find_last_instruction : t -> Instruction.Id.t list -> Instruction.t
358+
(** [find_last_instruction_id_and_pos group block] returns scalar instruction [i] from
359+
[group] and its position [pos] such that [i] appears after all other instructions
360+
from [group] according to the order of instructions in this basic [block]. *)
361+
val find_last_instruction_id_and_pos :
362+
t -> Instruction.t list -> Instruction.Id.t * int
363363

364364
val get_live_regs_before_terminator : t -> State.live_regs
365365

@@ -417,28 +417,29 @@ end = struct
417417
let get_live_regs_before_terminator t =
418418
State.liveness t.state t.block.terminator.id
419419

420-
let find_last_instruction t instructions =
421-
let instruction_set = Instruction.Id.Set.of_list instructions in
422-
let terminator = terminator t in
423-
if Instruction.Id.Set.mem (Instruction.id terminator) instruction_set
424-
then terminator
425-
else
426-
let body = t.block.body in
427-
let rec find_last cell_option =
428-
match cell_option with
429-
| None ->
430-
Misc.fatal_errorf "Vectorizer.find_last_instruction in block %a"
431-
Label.print t.block.start ()
432-
| Some cell ->
433-
let current_instruction = Instruction.basic (DLL.value cell) in
434-
let current_instruction_id = Instruction.id current_instruction in
435-
if Instruction.Id.Set.exists
436-
(Instruction.Id.equal current_instruction_id)
437-
instruction_set
438-
then current_instruction
439-
else find_last (DLL.prev cell)
440-
in
441-
find_last (DLL.last_cell body)
420+
let find_last_instruction_id_and_pos t instructions =
421+
let get instr =
422+
let id = Instruction.id instr in
423+
let pos = pos t id in
424+
id, pos
425+
in
426+
let rec loop instructions last_id last_pos =
427+
match instructions with
428+
| [] -> last_id, last_pos
429+
| hd :: tl ->
430+
let hd_id, hd_pos = get hd in
431+
if Int.compare hd_pos last_pos > 0
432+
then loop tl hd_id hd_pos
433+
else loop tl last_id last_pos
434+
in
435+
let loop_non_empty instructions =
436+
match instructions with
437+
| [] -> assert false
438+
| hd :: tl ->
439+
let last_id, last_pos = get hd in
440+
loop tl last_id last_pos
441+
in
442+
loop_non_empty instructions
442443
end
443444

444445
(* CR-someday gyorsh: Dependencies computed below can be used for other
@@ -2345,12 +2346,16 @@ end = struct
23452346

23462347
type t =
23472348
{ groups : Group.t Instruction.Id.Map.t;
2348-
(* [all_instructions] is all the scalar instructions in the computations.
2349-
It is an optimization to cache this value here. It is used for ruling
2350-
out computations that are invalid or not implementable, and to estimate
2351-
cost/benefit of vectorized computations. *)
23522349
all_scalar_instructions : Instruction.Id.Set.t;
2353-
new_positions : int Instruction.Id.Map.t
2350+
(** [all_scalar_instructions] is all the scalar instructions in the
2351+
computations. It is an optimization to cache this value here. It is used
2352+
for ruling out computations that are invalid or not implementable, and to
2353+
estimate cost/benefit of vectorized computations. *)
2354+
new_positions : int Instruction.Id.Map.t;
2355+
(** [new_positions] is used for validation. *)
2356+
last_pos : int option
2357+
(** [last_pos] the position in the block body of the last scalar instruction, used
2358+
for heuristics. [None] for empty computations. *)
23542359
}
23552360

23562361
let num_groups t = Instruction.Id.Map.cardinal t.groups
@@ -2592,19 +2597,31 @@ end = struct
25922597
&& respects_register_order_constraints t deps
25932598
&& not (is_dependency_of_outside_body t block deps)
25942599

2595-
(** The key is the last instruction id, for now. This is the place
2596-
where the vectorized intructions will be inserted. *)
2597-
let get_key block instruction_ids =
2598-
let last_instruction = Block.find_last_instruction block instruction_ids in
2599-
Instruction.id last_instruction
2600+
(** The key is the last instruction id, for now. This is the place in the body of the
2601+
block where the vectorized instructions will be inserted. *)
2602+
let get_key group block =
2603+
let id, _pos =
2604+
Block.find_last_instruction_id_and_pos block
2605+
(Group.scalar_instructions group)
2606+
in
2607+
id
2608+
2609+
let get_last_pos group block =
2610+
let _id, pos =
2611+
Block.find_last_instruction_id_and_pos block
2612+
(Group.scalar_instructions group)
2613+
in
2614+
pos
26002615

26012616
(** Returns the dependencies of arguments at position [arg_i]
26022617
of each instruction in [instruction_ids]. Returns None if
26032618
one of the instruction's dependencies is None for [arg_i]. *)
2604-
let get_deps deps ~arg_i instruction_ids =
2619+
let get_deps deps ~arg_i group =
26052620
Misc.Stdlib.List.map_option
2606-
(Dependencies.get_direct_dependency_of_arg deps ~arg_i)
2607-
instruction_ids
2621+
(fun instruction ->
2622+
let id = Instruction.id instruction in
2623+
Dependencies.get_direct_dependency_of_arg deps ~arg_i id)
2624+
(Group.scalar_instructions group)
26082625

26092626
let all_instructions map =
26102627
Instruction.Id.Map.fold
@@ -2634,7 +2651,8 @@ end = struct
26342651
let empty =
26352652
{ groups = Instruction.Id.Map.empty;
26362653
all_scalar_instructions = Instruction.Id.Set.empty;
2637-
new_positions = Instruction.Id.Map.empty
2654+
new_positions = Instruction.Id.Map.empty;
2655+
last_pos = None
26382656
}
26392657

26402658
(* CR gyorsh: if same instruction belongs to two groups, is it handled
@@ -2649,10 +2667,7 @@ end = struct
26492667
match group with
26502668
| None -> None
26512669
| Some (group : Group.t) -> (
2652-
let instruction_ids =
2653-
Group.scalar_instructions group |> List.map Instruction.id
2654-
in
2655-
let key = get_key block instruction_ids in
2670+
let key = get_key group block in
26562671
(* Is there another group with the same key already in the tree? If the
26572672
key instruction of the group is already in another group, and the other
26582673
group is different from this group, we won't vectorize this for
@@ -2674,7 +2689,7 @@ end = struct
26742689
(* CR-someday gyorsh: refer directly to [Reg.t] instead of
26752690
positional [arg_i]. Currently, the code assumes that address
26762691
args are always at the end. *)
2677-
match get_deps deps ~arg_i instruction_ids with
2692+
match get_deps deps ~arg_i group with
26782693
| None ->
26792694
(* At least one of the arguments has a dependency outside the
26802695
block. Currently, not supported. *)
@@ -2706,14 +2721,21 @@ end = struct
27062721
let t =
27072722
{ groups = map;
27082723
all_scalar_instructions = all_instructions map;
2709-
new_positions = new_positions map block
2724+
new_positions = new_positions map block;
2725+
last_pos = Some (get_last_pos root block)
27102726
}
27112727
in
27122728
State.dump_debug (Block.state block)
27132729
"Computation.from_seed build finished\n%a\n" (dump ~block) t;
27142730
assert (seed_address_does_not_depend_on_tree t block deps seed);
27152731
if is_valid t block deps then Some t else None
27162732

2733+
let max_pos o1 o2 =
2734+
match o1, o2 with
2735+
| Some p1, Some p2 -> Some (Int.max p1 p2)
2736+
| None, None -> None
2737+
| (Some _ as res), None | None, (Some _ as res) -> res
2738+
27172739
let join t1 t2 =
27182740
{ groups =
27192741
Instruction.Id.Map.union
@@ -2739,7 +2761,8 @@ end = struct
27392761
pos2=%d"
27402762
Instruction.Id.print key pos1 pos2;
27412763
Some pos1)
2742-
t1.new_positions t2.new_positions
2764+
t1.new_positions t2.new_positions;
2765+
last_pos = max_pos t1.last_pos t2.last_pos
27432766
}
27442767

27452768
(** address registers and vectorizable registers of [t] and [t'] are compatible, i.e.,
@@ -2801,7 +2824,16 @@ end = struct
28012824
| trees ->
28022825
(* sort by cost, ascending *)
28032826
let compare_cost t1 t2 = Int.compare (cost t1) (cost t2) in
2804-
let trees = List.sort compare_cost trees in
2827+
let compare_cost_and_last_pos t1 t2 =
2828+
let c = compare_cost t1 t2 in
2829+
if not (c = 0)
2830+
then c
2831+
else
2832+
(* heuristic to prioritize groups that appear later, it reduces the
2833+
chance they are a dependency of the rest of the body. *)
2834+
Int.neg (Option.compare Int.compare t1.last_pos t2.last_pos)
2835+
in
2836+
let trees = List.sort compare_cost_and_last_pos trees in
28052837
let rec loop trees acc =
28062838
match trees with
28072839
| [] -> acc

0 commit comments

Comments
 (0)