Skip to content

Commit b3f80b5

Browse files
committed
poc tagged union primitive catch all
1 parent fd0404d commit b3f80b5

19 files changed

+799
-30
lines changed

compiler/core/js_dump.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,13 +922,17 @@ and expression_desc cxt ~(level : int) f x : cxt =
922922
| None -> L.tag
923923
| Some s -> s
924924
in
925+
let is_primitive_catch_all =
926+
Ast_untagged_variants.has_primitive_catchall p.attrs
927+
in
925928
let tails =
926929
Ext_list.filter_map tails (fun ((f, optional), x) ->
927930
match x.expression_desc with
928931
| Undefined _ when optional -> None
929932
| _ -> Some (f, x))
930933
in
931934
if untagged then tails
935+
else if is_primitive_catch_all then tails
932936
else
933937
( Js_op.Lit tag_name,
934938
(* TAG:xx for inline records *)

compiler/core/lam_compile.ml

Lines changed: 123 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,86 @@ let get_literal_cases (sw_names : Ast_untagged_variants.switch_names option) =
183183
| {name; tag_type = None} -> res := String name :: !res));
184184
!res
185185

186+
let has_explicit_tag_name (sw_names : Ast_untagged_variants.switch_names option)
187+
: bool =
188+
match sw_names with
189+
| None -> false
190+
| Some {blocks} ->
191+
Array.exists
192+
(fun {Ast_untagged_variants.tag_name} -> tag_name <> None)
193+
blocks
194+
195+
let discriminant_expr ~untagged ~sw_names ~tag_name (e : E.t) : E.t =
196+
if untagged && has_explicit_tag_name sw_names then E.tag ~name:tag_name e
197+
else e
198+
199+
let split_sw_blocks_by_catchall sw_blocks get_block_tag =
200+
let is_literal_block (i, _) =
201+
match get_block_tag i with
202+
| Some {Ast_untagged_variants.tag_type = Some (Untagged _)} -> false
203+
| Some {Ast_untagged_variants.tag_type = Some _} -> true
204+
| _ -> false
205+
in
206+
let literals = List.filter is_literal_block sw_blocks in
207+
let untagged_only =
208+
List.filter
209+
(fun (i, _) ->
210+
match get_block_tag i with
211+
| Some {Ast_untagged_variants.tag_type = Some (Untagged _)} -> true
212+
| _ -> false)
213+
sw_blocks
214+
in
215+
(literals, untagged_only)
216+
217+
let block_literal_cases_for_guard sw_blocks get_block_tag =
218+
List.filter_map
219+
(fun (i, _) ->
220+
match get_block_tag i with
221+
| Some {Ast_untagged_variants.tag_type = Some t} -> Some t
222+
| _ -> None)
223+
sw_blocks
224+
225+
let all_literal_cases_with_block_tags
226+
(sw_names : Ast_untagged_variants.switch_names option) :
227+
Ast_untagged_variants.tag_type list =
228+
match sw_names with
229+
| None -> []
230+
| Some {blocks; _} as names -> (
231+
match
232+
Array.find_opt
233+
(fun {Ast_untagged_variants.tag_name} -> tag_name <> None)
234+
blocks
235+
with
236+
| None -> get_literal_cases names
237+
| Some _ ->
238+
let acc = ref (get_literal_cases names) in
239+
Ext_array.iter blocks (function
240+
| {Ast_untagged_variants.block_type = None; tag} -> (
241+
match tag.tag_type with
242+
| Some t -> acc := t :: !acc
243+
| None -> acc := String tag.name :: !acc)
244+
| _ -> ());
245+
!acc)
246+
247+
(* Compile the split path for tagged unions with literal block tags and a
248+
primitive catch-all on the discriminant: first try literal tags on the
249+
discriminant value, otherwise fall back to the primitive catch-all cases. *)
250+
let compile_literal_then_catchall ~cxt ~discr ~block_cases ~default
251+
~get_block_tag sw_blocks_literal_only sw_blocks_untagged_only :
252+
initialization =
253+
[
254+
S.if_
255+
(E.is_a_literal_case
256+
~literal_cases:
257+
(block_literal_cases_for_guard sw_blocks_literal_only get_block_tag)
258+
~block_cases discr)
259+
(compile_cases ~cxt ~switch_exp:discr ~block_cases ~default
260+
~get_tag:get_block_tag sw_blocks_literal_only)
261+
~else_:
262+
(compile_cases ~untagged:true ~cxt ~switch_exp:discr ~block_cases
263+
~default ~get_tag:get_block_tag sw_blocks_untagged_only);
264+
]
265+
186266
let has_null_undefined_other
187267
(sw_names : Ast_untagged_variants.switch_names option) =
188268
let null, undefined, other = (ref false, ref false, ref false) in
@@ -700,7 +780,13 @@ let compile output_prefix =
700780
Some tag
701781
in
702782
let tag_name = get_tag_name sw_names in
703-
let untagged = block_cases <> [] in
783+
(* Whether this switch includes block (non-const) cases. Used to decide
784+
whether to compile via the untagged/block path in case lowering. *)
785+
let has_block_cases = block_cases <> [] in
786+
(* For tagged unions with a primitive catch-all on the discriminant:
787+
- Guard first on literal cases against the discriminant value.
788+
- If none match, fall back to the primitive typeof checks (catch-alls).
789+
This mirrors unboxed variant handling but targets the tag field. *)
704790
let compile_whole (cxt : Lam_compile_context.t) =
705791
match
706792
compile_lambda {cxt with continuation = NeedValue Not_tail} switch_arg
@@ -710,20 +796,37 @@ let compile output_prefix =
710796
block
711797
@
712798
if sw_consts_full && sw_consts = [] then
713-
compile_cases ~block_cases ~untagged ~cxt
714-
~switch_exp:(if untagged then e else E.tag ~name:tag_name e)
715-
~default:sw_blocks_default ~get_tag:get_block_tag sw_blocks
799+
let has_explicit = has_explicit_tag_name sw_names in
800+
let sw_blocks_literal_only, sw_blocks_untagged_only =
801+
split_sw_blocks_by_catchall sw_blocks get_block_tag
802+
in
803+
let has_literal_block_tags = sw_blocks_literal_only <> [] in
804+
if has_block_cases && has_explicit && has_literal_block_tags then
805+
let discr =
806+
discriminant_expr ~untagged:has_block_cases ~sw_names ~tag_name e
807+
in
808+
compile_literal_then_catchall ~cxt ~discr ~block_cases
809+
~default:sw_blocks_default ~get_block_tag sw_blocks_literal_only
810+
sw_blocks_untagged_only
811+
else
812+
compile_cases ~block_cases ~untagged:has_block_cases ~cxt
813+
~switch_exp:
814+
(if has_block_cases then e else E.tag ~name:tag_name e)
815+
~default:sw_blocks_default ~get_tag:get_block_tag sw_blocks
716816
else if sw_blocks_full && sw_blocks = [] then
717817
compile_cases ~cxt ~switch_exp:e ~block_cases ~default:sw_num_default
718818
~get_tag:get_const_tag sw_consts
719819
else
720820
(* [e] will be used twice *)
721821
let dispatch e =
722822
let is_a_literal_case () =
723-
if untagged then
724-
E.is_a_literal_case
725-
~literal_cases:(get_literal_cases sw_names)
726-
~block_cases e
823+
if has_block_cases then
824+
let lit_e =
825+
discriminant_expr ~untagged:has_block_cases ~sw_names
826+
~tag_name e
827+
in
828+
let lit_cases = all_literal_cases_with_block_tags sw_names in
829+
E.is_a_literal_case ~literal_cases:lit_cases ~block_cases lit_e
727830
else
728831
E.is_int_tag
729832
~has_null_undefined_other:(has_null_undefined_other sw_names)
@@ -737,27 +840,32 @@ let compile output_prefix =
737840
| _ -> false
738841
in
739842
if
740-
untagged
843+
has_block_cases
741844
&& List.length sw_consts = 0
742845
&& eq_default sw_num_default sw_blocks_default
743846
then
744847
let literal_cases = get_literal_cases sw_names in
745848
let has_null_case =
746849
List.mem Ast_untagged_variants.Null literal_cases
747850
in
748-
compile_cases ~untagged ~cxt
749-
~switch_exp:(if untagged then e else E.tag ~name:tag_name e)
851+
compile_cases ~untagged:has_block_cases ~cxt
852+
~switch_exp:
853+
(if has_block_cases then e else E.tag ~name:tag_name e)
750854
~block_cases ~has_null_case ~default:sw_blocks_default
751855
~get_tag:get_block_tag sw_blocks
752856
else
753857
[
754858
S.if_ (is_a_literal_case ())
755-
(compile_cases ~cxt ~switch_exp:e ~block_cases
756-
~default:sw_num_default ~get_tag:get_const_tag sw_consts)
859+
(compile_cases ~cxt
860+
~switch_exp:
861+
(discriminant_expr ~untagged:has_block_cases ~sw_names
862+
~tag_name e)
863+
~block_cases ~default:sw_num_default ~get_tag:get_const_tag
864+
sw_consts)
757865
~else_:
758-
(compile_cases ~untagged ~cxt
866+
(compile_cases ~untagged:has_block_cases ~cxt
759867
~switch_exp:
760-
(if untagged then e else E.tag ~name:tag_name e)
868+
(if has_block_cases then e else E.tag ~name:tag_name e)
761869
~block_cases ~default:sw_blocks_default
762870
~get_tag:get_block_tag sw_blocks);
763871
]

0 commit comments

Comments
 (0)