Skip to content

Commit 0b58fe6

Browse files
committed
arch extensions
1 parent 04d8e31 commit 0b58fe6

File tree

5 files changed

+86
-26
lines changed

5 files changed

+86
-26
lines changed

backend/amd64/arch.ml

+23-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ module Extension = struct
3030
| LZCNT
3131
| BMI
3232
| BMI2
33+
| AVX
34+
| AVX2
35+
| AVX512F
3336

3437
let rank = function
3538
| POPCNT -> 0
@@ -43,6 +46,9 @@ module Extension = struct
4346
| LZCNT -> 8
4447
| BMI -> 9
4548
| BMI2 -> 10
49+
| AVX -> 11
50+
| AVX2 -> 12
51+
| AVX512F -> 13
4652

4753
let compare left right = Int.compare (rank left) (rank right)
4854
end
@@ -62,6 +68,9 @@ module Extension = struct
6268
| LZCNT -> "LZCNT"
6369
| BMI -> "BMI"
6470
| BMI2 -> "BMI2"
71+
| AVX -> "AVX"
72+
| AVX2 -> "AVX2"
73+
| AVX512F -> "AVX512F"
6574

6675
let generation = function
6776
| POPCNT -> "Nehalem+"
@@ -75,18 +84,29 @@ module Extension = struct
7584
| LZCNT -> "Haswell+"
7685
| BMI -> "Haswell+"
7786
| BMI2 -> "Haswell+"
87+
| AVX -> "Sandybridge+"
88+
| AVX2 -> "Haswell+"
89+
| AVX512F -> "SkylakeXeon+"
7890

7991
let enabled_by_default = function
8092
| SSE3 | SSSE3 | SSE4_1 | SSE4_2
81-
| POPCNT | CLMUL | LZCNT | BMI | BMI2 -> true
82-
| PREFETCHW | PREFETCHWT1 -> false
93+
| POPCNT | CLMUL | LZCNT | BMI | BMI2 | AVX | AVX2 -> true
94+
| PREFETCHW | PREFETCHWT1 | AVX512F -> false
8395

84-
let all = Set.of_list [ POPCNT; PREFETCHW; PREFETCHWT1; SSE3; SSSE3; SSE4_1; SSE4_2; CLMUL; LZCNT; BMI; BMI2 ]
96+
let all = Set.of_list [ POPCNT; PREFETCHW; PREFETCHWT1; SSE3; SSSE3; SSE4_1; SSE4_2; CLMUL; LZCNT; BMI; BMI2; AVX; AVX2; AVX512F ]
8597
let config = ref (Set.filter enabled_by_default all)
8698

8799
let enabled t = Set.mem t !config
88100
let disabled t = not (enabled t)
89101

102+
let allow_vec256 () = List.exists (fun t -> enabled t) [AVX; AVX2; AVX512F]
103+
let allow_vec512 () = List.exists (fun t -> enabled t) [AVX512F]
104+
105+
let require_vec256 () =
106+
if not (allow_vec256 ()) then Misc.fatal_error "AVX or AVX512 is required for 256-bit vectors"
107+
let require_vec512 () =
108+
if not (allow_vec512 ()) then Misc.fatal_error "AVX512 is required for 512-bit vectors"
109+
90110
let args =
91111
let y t = "-f" ^ (name t |> String.lowercase_ascii) in
92112
let n t = "-fno-" ^ (name t |> String.lowercase_ascii) in

backend/amd64/arch.mli

+8
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,19 @@ module Extension : sig
3131
to Haswell, i.e. they do not cause an illegal instruction fault.
3232
That means code using LZCNT/TZCNT will silently produce wrong results. *)
3333
| BMI2
34+
| AVX
35+
| AVX2
36+
| AVX512F
3437

3538
val name : t -> string
3639

3740
val enabled : t -> bool
3841
val available : unit -> t list
42+
43+
val allow_vec256 : unit -> bool
44+
val allow_vec512 : unit -> bool
45+
val require_vec256 : unit -> unit
46+
val require_vec512 : unit -> unit
3947
end
4048

4149
val trap_notes : bool ref

backend/amd64/emit.ml

+12-4
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,12 @@ let register_name typ r : X86_ast.arg =
144144
match (typ : Cmm.machtype_component) with
145145
| Int | Val | Addr -> Reg64 int_reg_name.(r)
146146
| Float | Float32 | Vec128 | Valx2 -> Regf xmm_reg_name.(r - 100)
147-
| Vec256 -> Regf ymm_reg_name.(r - 100)
148-
| Vec512 -> Regf zmm_reg_name.(r - 100)
147+
| Vec256 ->
148+
Arch.Extension.require_vec256 ();
149+
Regf ymm_reg_name.(r - 100)
150+
| Vec512 ->
151+
Arch.Extension.require_vec512 ();
152+
Regf zmm_reg_name.(r - 100)
149153

150154
let phys_rax = phys_reg Int 0
151155

@@ -375,8 +379,12 @@ let x86_data_type_for_stack_slot : Cmm.machtype_component -> X86_ast.data_type =
375379
function
376380
| Float -> REAL8
377381
| Vec128 -> VEC128
378-
| Vec256 -> VEC256
379-
| Vec512 -> VEC512
382+
| Vec256 ->
383+
Arch.Extension.require_vec256 ();
384+
VEC256
385+
| Vec512 ->
386+
Arch.Extension.require_vec512 ();
387+
VEC512
380388
| Valx2 -> VEC128
381389
| Int | Addr | Val -> QWORD
382390
| Float32 -> REAL4

backend/amd64/proc.ml

+40-16
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,32 @@ let hard_vec256_reg = Array.map (fun r -> {r with Reg.typ = Vec256}) hard_float_
107107
let hard_vec512_reg = Array.map (fun r -> {r with Reg.typ = Vec512}) hard_float_reg
108108
let hard_float32_reg = Array.map (fun r -> {r with Reg.typ = Float32}) hard_float_reg
109109

110+
let add_hard_vec256_regs list ~f =
111+
if Arch.Extension.allow_vec256 ()
112+
then f hard_vec256_reg :: list else list
113+
114+
let add_hard_vec512_regs list ~f =
115+
if Arch.Extension.allow_vec512 ()
116+
then f hard_vec512_reg :: list else list
117+
110118
let all_phys_regs =
111-
Array.concat [hard_int_reg; hard_float_reg; hard_float32_reg; hard_vec128_reg; hard_vec256_reg; hard_vec512_reg]
119+
[hard_int_reg; hard_float_reg; hard_float32_reg; hard_vec128_reg]
120+
|> add_hard_vec256_regs ~f:(fun regs -> regs)
121+
|> add_hard_vec512_regs ~f:(fun regs -> regs)
122+
|> Array.concat
112123

113124
let phys_reg ty n =
114125
match (ty : machtype_component) with
115126
| Int | Addr | Val -> hard_int_reg.(n)
116127
| Float -> hard_float_reg.(n - 100)
117128
| Float32 -> hard_float32_reg.(n - 100)
118129
| Vec128 | Valx2 -> hard_vec128_reg.(n - 100)
119-
| Vec256 -> hard_vec256_reg.(n - 100)
120-
| Vec512 -> hard_vec512_reg.(n - 100)
130+
| Vec256 ->
131+
Arch.Extension.require_vec256 ();
132+
hard_vec256_reg.(n - 100)
133+
| Vec512 ->
134+
Arch.Extension.require_vec512 ();
135+
hard_vec512_reg.(n - 100)
121136

122137
let rax = phys_reg Int 0
123138
let rdi = phys_reg Int 2
@@ -128,9 +143,14 @@ let r11 = phys_reg Int 11
128143
let rbp = phys_reg Int 12
129144

130145
(* CSE needs to know that all versions of xmm15 are destroyed. *)
131-
let destroy_xmm n =
132-
[| phys_reg Float (100 + n); phys_reg Float32 (100 + n);
133-
phys_reg Vec128 (100 + n); phys_reg Vec256 (100 + n); phys_reg Vec512 (100 + n) |]
146+
let destroy_xmm =
147+
let types =
148+
([ Float; Float32; Vec128 ] : machtype_component list)
149+
|> add_hard_vec256_regs ~f:(fun _ -> Vec256)
150+
|> add_hard_vec512_regs ~f:(fun _ -> Vec512)
151+
|> Array.of_list
152+
in
153+
fun n -> Array.map (fun t -> phys_reg t (100 + n)) types
134154

135155
let destroyed_by_plt_stub =
136156
if not X86_proc.use_plt then [| |] else [| r10; r11 |]
@@ -189,6 +209,7 @@ let calling_conventions
189209
ofs := !ofs + size_vec128
190210
end
191211
| Vec256 ->
212+
Arch.Extension.require_vec256 ();
192213
if !float <= last_float then begin
193214
loc.(i) <- phys_reg Vec256 !float;
194215
incr float
@@ -198,6 +219,7 @@ let calling_conventions
198219
ofs := !ofs + size_vec256
199220
end
200221
| Vec512 ->
222+
Arch.Extension.require_vec512 ();
201223
if !float <= last_float then begin
202224
loc.(i) <- phys_reg Vec512 !float;
203225
incr float
@@ -390,21 +412,23 @@ let int_regs_destroyed_at_c_call =
390412

391413
let destroyed_at_c_call_win64 =
392414
(* Win64: rbx, rbp, rsi, rdi, r12-r15, xmm6-xmm15 preserved *)
393-
Array.concat [
394-
Array.map (phys_reg Int) int_regs_destroyed_at_c_call_win64;
415+
[ Array.map (phys_reg Int) int_regs_destroyed_at_c_call_win64;
395416
Array.sub hard_float_reg 0 6;
396417
Array.sub hard_float32_reg 0 6;
397-
Array.sub hard_vec128_reg 0 6
398-
]
418+
Array.sub hard_vec128_reg 0 6 ]
419+
|> add_hard_vec256_regs ~f:(fun regs -> Array.sub regs 0 6)
420+
|> add_hard_vec512_regs ~f:(fun regs -> Array.sub regs 0 6)
421+
|> Array.concat
399422

400423
let destroyed_at_c_call_unix =
401424
(* Unix: rbx, rbp, r12-r15 preserved *)
402-
Array.concat [
403-
Array.map (phys_reg Int) int_regs_destroyed_at_c_call;
404-
hard_float_reg;
405-
hard_float32_reg;
406-
hard_vec128_reg
407-
]
425+
[ Array.map (phys_reg Int) int_regs_destroyed_at_c_call;
426+
hard_float_reg;
427+
hard_float32_reg;
428+
hard_vec128_reg ]
429+
|> add_hard_vec256_regs ~f:(fun regs -> regs)
430+
|> add_hard_vec512_regs ~f:(fun regs -> regs)
431+
|> Array.concat
408432

409433
let destroyed_at_c_call =
410434
(* C calling conventions preserve rbx, but it is clobbered

flambda-backend/tests/simd/dune

+3-3
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
(libraries simd_test_helpers stdlib_stable stdlib_upstream_compatible)
5656
(foreign_archives stubs)
5757
(ocamlopt_flags
58-
(:standard -extension simd_alpha)))
58+
(:standard -extension simd_beta)))
5959

6060
(rule
6161
(enabled_if
@@ -219,7 +219,7 @@
219219
(libraries simd_test_helpers stdlib_stable stdlib_upstream_compatible)
220220
(foreign_archives stubs)
221221
(ocamlopt_flags
222-
(:standard -nodynlink -extension simd_alpha)))
222+
(:standard -nodynlink -extension simd_beta)))
223223

224224
(rule
225225
(enabled_if
@@ -391,7 +391,7 @@
391391
(<> %{system} macosx))
392392
(foreign_archives stubs)
393393
(ocamlopt_flags
394-
(:standard -internal-assembler -extension simd_alpha)))
394+
(:standard -internal-assembler -extension simd_beta)))
395395

396396
(rule
397397
(enabled_if

0 commit comments

Comments
 (0)