@@ -183,6 +183,86 @@ let get_literal_cases (sw_names : Ast_untagged_variants.switch_names option) =
183183 | {name; tag_type = None } -> res := String name :: ! res));
184184 ! res
185185
186+ let has_explicit_tag_name (sw_names : Ast_untagged_variants.switch_names option )
187+ : bool =
188+ match sw_names with
189+ | None -> false
190+ | Some {blocks} ->
191+ Array. exists
192+ (fun {Ast_untagged_variants. tag_name} -> tag_name <> None )
193+ blocks
194+
195+ let discriminant_expr ~untagged ~sw_names ~tag_name (e : E.t ) : E.t =
196+ if untagged && has_explicit_tag_name sw_names then E. tag ~name: tag_name e
197+ else e
198+
199+ let split_sw_blocks_by_catchall sw_blocks get_block_tag =
200+ let is_literal_block (i , _ ) =
201+ match get_block_tag i with
202+ | Some {Ast_untagged_variants. tag_type = Some (Untagged _ )} -> false
203+ | Some {Ast_untagged_variants. tag_type = Some _ } -> true
204+ | _ -> false
205+ in
206+ let literals = List. filter is_literal_block sw_blocks in
207+ let untagged_only =
208+ List. filter
209+ (fun (i , _ ) ->
210+ match get_block_tag i with
211+ | Some {Ast_untagged_variants. tag_type = Some (Untagged _ )} -> true
212+ | _ -> false )
213+ sw_blocks
214+ in
215+ (literals, untagged_only)
216+
217+ let block_literal_cases_for_guard sw_blocks get_block_tag =
218+ List. filter_map
219+ (fun (i , _ ) ->
220+ match get_block_tag i with
221+ | Some {Ast_untagged_variants. tag_type = Some t } -> Some t
222+ | _ -> None )
223+ sw_blocks
224+
225+ let all_literal_cases_with_block_tags
226+ (sw_names : Ast_untagged_variants.switch_names option ) :
227+ Ast_untagged_variants. tag_type list =
228+ match sw_names with
229+ | None -> []
230+ | Some {blocks; _} as names -> (
231+ match
232+ Array. find_opt
233+ (fun {Ast_untagged_variants. tag_name} -> tag_name <> None )
234+ blocks
235+ with
236+ | None -> get_literal_cases names
237+ | Some _ ->
238+ let acc = ref (get_literal_cases names) in
239+ Ext_array. iter blocks (function
240+ | {Ast_untagged_variants. block_type = None ; tag} -> (
241+ match tag.tag_type with
242+ | Some t -> acc := t :: ! acc
243+ | None -> acc := String tag.name :: ! acc)
244+ | _ -> () );
245+ ! acc)
246+
247+ (* Compile the split path for tagged unions with literal block tags and a
248+ primitive catch-all on the discriminant: first try literal tags on the
249+ discriminant value, otherwise fall back to the primitive catch-all cases. *)
250+ let compile_literal_then_catchall ~cxt ~discr ~block_cases ~default
251+ ~get_block_tag sw_blocks_literal_only sw_blocks_untagged_only :
252+ initialization =
253+ [
254+ S. if_
255+ (E. is_a_literal_case
256+ ~literal_cases:
257+ (block_literal_cases_for_guard sw_blocks_literal_only get_block_tag)
258+ ~block_cases discr)
259+ (compile_cases ~cxt ~switch_exp: discr ~block_cases ~default
260+ ~get_tag: get_block_tag sw_blocks_literal_only)
261+ ~else_:
262+ (compile_cases ~untagged: true ~cxt ~switch_exp: discr ~block_cases
263+ ~default ~get_tag: get_block_tag sw_blocks_untagged_only);
264+ ]
265+
186266let has_null_undefined_other
187267 (sw_names : Ast_untagged_variants.switch_names option ) =
188268 let null, undefined, other = (ref false , ref false , ref false ) in
@@ -700,7 +780,13 @@ let compile output_prefix =
700780 Some tag
701781 in
702782 let tag_name = get_tag_name sw_names in
703- let untagged = block_cases <> [] in
783+ (* Whether this switch includes block (non-const) cases. Used to decide
784+ whether to compile via the untagged/block path in case lowering. *)
785+ let has_block_cases = block_cases <> [] in
786+ (* For tagged unions with a primitive catch-all on the discriminant:
787+ - Guard first on literal cases against the discriminant value.
788+ - If none match, fall back to the primitive typeof checks (catch-alls).
789+ This mirrors unboxed variant handling but targets the tag field. *)
704790 let compile_whole (cxt : Lam_compile_context.t ) =
705791 match
706792 compile_lambda {cxt with continuation = NeedValue Not_tail } switch_arg
@@ -710,20 +796,37 @@ let compile output_prefix =
710796 block
711797 @
712798 if sw_consts_full && sw_consts = [] then
713- compile_cases ~block_cases ~untagged ~cxt
714- ~switch_exp: (if untagged then e else E. tag ~name: tag_name e)
715- ~default: sw_blocks_default ~get_tag: get_block_tag sw_blocks
799+ let has_explicit = has_explicit_tag_name sw_names in
800+ let sw_blocks_literal_only, sw_blocks_untagged_only =
801+ split_sw_blocks_by_catchall sw_blocks get_block_tag
802+ in
803+ let has_literal_block_tags = sw_blocks_literal_only <> [] in
804+ if has_block_cases && has_explicit && has_literal_block_tags then
805+ let discr =
806+ discriminant_expr ~untagged: has_block_cases ~sw_names ~tag_name e
807+ in
808+ compile_literal_then_catchall ~cxt ~discr ~block_cases
809+ ~default: sw_blocks_default ~get_block_tag sw_blocks_literal_only
810+ sw_blocks_untagged_only
811+ else
812+ compile_cases ~block_cases ~untagged: has_block_cases ~cxt
813+ ~switch_exp:
814+ (if has_block_cases then e else E. tag ~name: tag_name e)
815+ ~default: sw_blocks_default ~get_tag: get_block_tag sw_blocks
716816 else if sw_blocks_full && sw_blocks = [] then
717817 compile_cases ~cxt ~switch_exp: e ~block_cases ~default: sw_num_default
718818 ~get_tag: get_const_tag sw_consts
719819 else
720820 (* [e] will be used twice *)
721821 let dispatch e =
722822 let is_a_literal_case () =
723- if untagged then
724- E. is_a_literal_case
725- ~literal_cases: (get_literal_cases sw_names)
726- ~block_cases e
823+ if has_block_cases then
824+ let lit_e =
825+ discriminant_expr ~untagged: has_block_cases ~sw_names
826+ ~tag_name e
827+ in
828+ let lit_cases = all_literal_cases_with_block_tags sw_names in
829+ E. is_a_literal_case ~literal_cases: lit_cases ~block_cases lit_e
727830 else
728831 E. is_int_tag
729832 ~has_null_undefined_other: (has_null_undefined_other sw_names)
@@ -737,27 +840,32 @@ let compile output_prefix =
737840 | _ -> false
738841 in
739842 if
740- untagged
843+ has_block_cases
741844 && List. length sw_consts = 0
742845 && eq_default sw_num_default sw_blocks_default
743846 then
744847 let literal_cases = get_literal_cases sw_names in
745848 let has_null_case =
746849 List. mem Ast_untagged_variants. Null literal_cases
747850 in
748- compile_cases ~untagged ~cxt
749- ~switch_exp: (if untagged then e else E. tag ~name: tag_name e)
851+ compile_cases ~untagged: has_block_cases ~cxt
852+ ~switch_exp:
853+ (if has_block_cases then e else E. tag ~name: tag_name e)
750854 ~block_cases ~has_null_case ~default: sw_blocks_default
751855 ~get_tag: get_block_tag sw_blocks
752856 else
753857 [
754858 S. if_ (is_a_literal_case () )
755- (compile_cases ~cxt ~switch_exp: e ~block_cases
756- ~default: sw_num_default ~get_tag: get_const_tag sw_consts)
859+ (compile_cases ~cxt
860+ ~switch_exp:
861+ (discriminant_expr ~untagged: has_block_cases ~sw_names
862+ ~tag_name e)
863+ ~block_cases ~default: sw_num_default ~get_tag: get_const_tag
864+ sw_consts)
757865 ~else_:
758- (compile_cases ~untagged ~cxt
866+ (compile_cases ~untagged: has_block_cases ~cxt
759867 ~switch_exp:
760- (if untagged then e else E. tag ~name: tag_name e)
868+ (if has_block_cases then e else E. tag ~name: tag_name e)
761869 ~block_cases ~default: sw_blocks_default
762870 ~get_tag: get_block_tag sw_blocks);
763871 ]
0 commit comments