From eb5f5e628227fe8cfb1ce2fa3b71534f327703eb Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 31 Mar 2024 15:53:47 +0200 Subject: [PATCH 01/12] Rename gen_compare to gen_equality These tests check the equality functionality. In preparation of adding a Sort ability (which defines a compare method), naming the quality tests `gen_compare` creates ambiguity, so this renames them. --- .../compiler/test_gen/src/{gen_compare.rs => gen_equality.rs} | 0 crates/compiler/test_gen/src/tests.rs | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename crates/compiler/test_gen/src/{gen_compare.rs => gen_equality.rs} (100%) diff --git a/crates/compiler/test_gen/src/gen_compare.rs b/crates/compiler/test_gen/src/gen_equality.rs similarity index 100% rename from crates/compiler/test_gen/src/gen_compare.rs rename to crates/compiler/test_gen/src/gen_equality.rs diff --git a/crates/compiler/test_gen/src/tests.rs b/crates/compiler/test_gen/src/tests.rs index befc7de4457..3338b02447d 100644 --- a/crates/compiler/test_gen/src/tests.rs +++ b/crates/compiler/test_gen/src/tests.rs @@ -5,9 +5,9 @@ #![allow(clippy::float_cmp)] pub mod gen_abilities; -pub mod gen_compare; pub mod gen_definitions; pub mod gen_dict; +pub mod gen_equality; pub mod gen_erased; pub mod gen_list; pub mod gen_num; From 458e6147aaf2e7095db74f197583635bb7a30a04 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 31 Mar 2024 15:55:31 +0200 Subject: [PATCH 02/12] Add Sort ability This adds a Sort ability, currently lacking any implementations of the ability. --- crates/compiler/builtins/roc/List.roc | 5 +++++ crates/compiler/builtins/roc/main.roc | 1 + crates/compiler/module/src/symbol.rs | 3 +++ 3 files changed, 9 insertions(+) diff --git a/crates/compiler/builtins/roc/List.roc b/crates/compiler/builtins/roc/List.roc index 3d6fe798a32..781a1c183b9 100644 --- a/crates/compiler/builtins/roc/List.roc +++ b/crates/compiler/builtins/roc/List.roc @@ -69,6 +69,8 @@ module [ walkBackwardsUntil, countIf, chunksOf, + Sort, + compare, ] import Bool exposing [Bool, Eq] @@ -1324,3 +1326,6 @@ iterBackwardsHelp = \list, state, f, prevIndex -> Break b -> Break b else Continue state + +Sort implements + compare : a, a -> [LessThan, Equal, GreaterThan] where a implements Sort diff --git a/crates/compiler/builtins/roc/main.roc b/crates/compiler/builtins/roc/main.roc index 5fb1daef3bc..cabef1aad26 100644 --- a/crates/compiler/builtins/roc/main.roc +++ b/crates/compiler/builtins/roc/main.roc @@ -6,6 +6,7 @@ package [ List, Dict, Set, + Sort Decode, Encode, Hash, diff --git a/crates/compiler/module/src/symbol.rs b/crates/compiler/module/src/symbol.rs index ac08fd59577..76f07aa124f 100644 --- a/crates/compiler/module/src/symbol.rs +++ b/crates/compiler/module/src/symbol.rs @@ -53,6 +53,7 @@ pub const DERIVABLE_ABILITIES: &[(Symbol, &[Symbol])] = &[ (Symbol::DECODE_DECODING, &[Symbol::DECODE_DECODER]), (Symbol::HASH_HASH_ABILITY, &[Symbol::HASH_HASH]), (Symbol::BOOL_EQ, &[Symbol::BOOL_IS_EQ]), + (Symbol::LIST_SORT, &[Symbol::LIST_COMPARE]), ( Symbol::INSPECT_INSPECT_ABILITY, &[Symbol::INSPECT_TO_INSPECTOR], @@ -1531,6 +1532,8 @@ define_builtins! { 86 LIST_WALK_WITH_INDEX_UNTIL: "walkWithIndexUntil" 87 LIST_CLONE: "clone" 88 LIST_LEN_USIZE: "lenUsize" + 89 LIST_SORT: "Sort" exposed_type=true + 90 LIST_COMPARE: "compare" } 7 RESULT: "Result" => { 0 RESULT_RESULT: "Result" exposed_type=true // the Result.Result type alias From 57997cc2921baac829fad3b8315c99d18267b072 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Mon, 1 Apr 2024 15:10:12 +0200 Subject: [PATCH 03/12] Add StructuralCompare skeleton I'm modeling implementation of the Sort ability for arbitrary types after how it's done for Eq/NotEq. --- crates/compiler/builtins/roc/List.roc | 4 + crates/compiler/can/src/builtins.rs | 1 + crates/compiler/derive_key/src/lib.rs | 16 +++ crates/compiler/gen_llvm/src/llvm/lowlevel.rs | 3 + crates/compiler/gen_wasm/src/low_level.rs | 4 + crates/compiler/module/src/low_level.rs | 2 + crates/compiler/module/src/symbol.rs | 3 +- .../compiler/mono/src/drop_specialization.rs | 2 +- crates/compiler/mono/src/inc_dec.rs | 2 +- crates/compiler/solve/src/ability.rs | 134 ++++++++++++++++++ crates/compiler/types/src/subs.rs | 3 + 11 files changed, 171 insertions(+), 3 deletions(-) diff --git a/crates/compiler/builtins/roc/List.roc b/crates/compiler/builtins/roc/List.roc index 781a1c183b9..6e7f52fd420 100644 --- a/crates/compiler/builtins/roc/List.roc +++ b/crates/compiler/builtins/roc/List.roc @@ -1329,3 +1329,7 @@ iterBackwardsHelp = \list, state, f, prevIndex -> Sort implements compare : a, a -> [LessThan, Equal, GreaterThan] where a implements Sort + +# INTERNAL COMPILER USE ONLY: used to lower calls to `compare` to structural +# compare via the `Sort` low-level for derived types. +structuralCompare : a, a -> [LessThan, Equal, GreaterThan] diff --git a/crates/compiler/can/src/builtins.rs b/crates/compiler/can/src/builtins.rs index 8191eea3996..652d1763c65 100644 --- a/crates/compiler/can/src/builtins.rs +++ b/crates/compiler/can/src/builtins.rs @@ -216,6 +216,7 @@ map_symbol_to_lowlevel_and_arity! { And; BOOL_AND; 2, Or; BOOL_OR; 2, Not; BOOL_NOT; 1, + Compare; LIST_STRUCTURAL_COMPARE; 2, BoxExpr; BOX_BOX_FUNCTION; 1, UnboxExpr; BOX_UNBOX; 1, Unreachable; LIST_UNREACHABLE; 1, diff --git a/crates/compiler/derive_key/src/lib.rs b/crates/compiler/derive_key/src/lib.rs index 7454dcc6aa9..5696376ec62 100644 --- a/crates/compiler/derive_key/src/lib.rs +++ b/crates/compiler/derive_key/src/lib.rs @@ -81,6 +81,7 @@ pub enum DeriveBuiltin { Decoder, Hash, IsEq, + Compare, ToInspector, } @@ -93,6 +94,7 @@ impl TryFrom for DeriveBuiltin { Symbol::DECODE_DECODER => Ok(DeriveBuiltin::Decoder), Symbol::HASH_HASH => Ok(DeriveBuiltin::Hash), Symbol::BOOL_IS_EQ => Ok(DeriveBuiltin::IsEq), + Symbol::LIST_COMPARE => Ok(DeriveBuiltin::Compare), Symbol::INSPECT_TO_INSPECTOR => Ok(DeriveBuiltin::ToInspector), _ => Err(value), } @@ -127,6 +129,13 @@ impl Derived { Symbol::BOOL_STRUCTURAL_EQ, )) } + DeriveBuiltin::Compare => { + // If obligation checking passes, we always lower derived implementations of + // `compare` to the `Compare` low-level, to be fulfilled by the backends. + Ok(Derived::SingleLambdaSetImmediate( + Symbol::LIST_STRUCTURAL_COMPARE, + )) + } DeriveBuiltin::ToInspector => match FlatInspectable::from_var(subs, var) { FlatInspectable::Immediate(imm) => Ok(Derived::Immediate(imm)), FlatInspectable::Key(repr) => Ok(Derived::Key(DeriveKey::ToInspector(repr))), @@ -161,6 +170,13 @@ impl Derived { Symbol::BOOL_STRUCTURAL_EQ, )) } + DeriveBuiltin::Compare => { + // If obligation checking passes, we always lower derived implementations of + // `compare` to the `Compare` low-level, to be fulfilled by the backends. + Ok(Derived::SingleLambdaSetImmediate( + Symbol::LIST_STRUCTURAL_COMPARE, + )) + } DeriveBuiltin::ToInspector => { match inspect::FlatInspectable::from_builtin_alias(symbol).unwrap() { FlatInspectable::Immediate(imm) => Ok(Derived::Immediate(imm)), diff --git a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs index 8a4de54ee97..5015206f917 100644 --- a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs +++ b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs @@ -1269,6 +1269,9 @@ pub(crate) fn run_low_level<'a, 'ctx>( let bool_val = env.builder.new_build_not(arg.into_int_value(), "bool_not"); BasicValueEnum::IntValue(bool_val) } + Compare => { + panic!("TODO: implement this") + } Hash => { unimplemented!() } diff --git a/crates/compiler/gen_wasm/src/low_level.rs b/crates/compiler/gen_wasm/src/low_level.rs index 0bab4cdbc95..5836e01f6b9 100644 --- a/crates/compiler/gen_wasm/src/low_level.rs +++ b/crates/compiler/gen_wasm/src/low_level.rs @@ -2087,6 +2087,10 @@ impl<'a> LowLevelCall<'a> { Eq | NotEq => self.eq_or_neq(backend), + Compare => { + panic!("TODO: implement this") + } + BoxExpr | UnboxExpr => { unreachable!("The {:?} operation is turned into mono Expr", self.lowlevel) } diff --git a/crates/compiler/module/src/low_level.rs b/crates/compiler/module/src/low_level.rs index ccc706d6d35..7bd92ffb3dd 100644 --- a/crates/compiler/module/src/low_level.rs +++ b/crates/compiler/module/src/low_level.rs @@ -113,6 +113,7 @@ pub enum LowLevel { And, Or, Not, + Compare, Hash, PtrCast, PtrStore, @@ -352,6 +353,7 @@ map_symbol_to_lowlevel! { And <= BOOL_AND; Or <= BOOL_OR; Not <= BOOL_NOT; + Compare <= LIST_STRUCTURAL_COMPARE; Unreachable <= LIST_UNREACHABLE; DictPseudoSeed <= DICT_PSEUDO_SEED; } diff --git a/crates/compiler/module/src/symbol.rs b/crates/compiler/module/src/symbol.rs index 76f07aa124f..f73633a8dae 100644 --- a/crates/compiler/module/src/symbol.rs +++ b/crates/compiler/module/src/symbol.rs @@ -112,7 +112,7 @@ impl Symbol { self, // The `structuralEq` call used deriving structural equality, which will wrap the `Eq` // low-level implementation. - &Self::BOOL_STRUCTURAL_EQ + &Self::BOOL_STRUCTURAL_EQ | &Self::LIST_STRUCTURAL_COMPARE ) } @@ -1534,6 +1534,7 @@ define_builtins! { 88 LIST_LEN_USIZE: "lenUsize" 89 LIST_SORT: "Sort" exposed_type=true 90 LIST_COMPARE: "compare" + unexposed 91 LIST_STRUCTURAL_COMPARE: "structuralCompare" } 7 RESULT: "Result" => { 0 RESULT_RESULT: "Result" exposed_type=true // the Result.Result type alias diff --git a/crates/compiler/mono/src/drop_specialization.rs b/crates/compiler/mono/src/drop_specialization.rs index a7b9fbfa074..da5c89cf2c6 100644 --- a/crates/compiler/mono/src/drop_specialization.rs +++ b/crates/compiler/mono/src/drop_specialization.rs @@ -1558,7 +1558,7 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC { | ListReleaseExcessCapacity | StrReleaseExcessCapacity => RC::Rc, - Eq | NotEq => RC::NoRc, + Eq | NotEq | Compare => RC::NoRc, And | Or | NumAdd | NumAddWrap | NumAddChecked | NumAddSaturated | NumSub | NumSubWrap | NumSubChecked | NumSubSaturated | NumMul | NumMulWrap | NumMulSaturated diff --git a/crates/compiler/mono/src/inc_dec.rs b/crates/compiler/mono/src/inc_dec.rs index fe3a9c9b001..34ec7ec9dba 100644 --- a/crates/compiler/mono/src/inc_dec.rs +++ b/crates/compiler/mono/src/inc_dec.rs @@ -1310,7 +1310,7 @@ pub(crate) fn lowlevel_borrow_signature(op: LowLevel) -> &'static [Ownership] { ListReleaseExcessCapacity => &[OWNED], StrReleaseExcessCapacity => &[OWNED], - Eq | NotEq => &[BORROWED, BORROWED], + Eq | NotEq | Compare => &[BORROWED, BORROWED], And | Or | NumAdd | NumAddWrap | NumAddChecked | NumAddSaturated | NumSub | NumSubWrap | NumSubChecked | NumSubSaturated | NumMul | NumMulWrap | NumMulSaturated diff --git a/crates/compiler/solve/src/ability.rs b/crates/compiler/solve/src/ability.rs index 05ee2dd45a8..4c08366e381 100644 --- a/crates/compiler/solve/src/ability.rs +++ b/crates/compiler/solve/src/ability.rs @@ -303,6 +303,13 @@ impl ObligationCache { Symbol::BOOL_EQ => Some(DeriveEq::is_derivable(self, abilities_store, subs, var)), + Symbol::LIST_COMPARE => Some(DeriveCompare::is_derivable( + self, + abilities_store, + subs, + var, + )), + Symbol::INSPECT_INSPECT_ABILITY => Some(DeriveInspect::is_derivable( self, abilities_store, @@ -1384,6 +1391,133 @@ impl DerivableVisitor for DeriveEq { } } +struct DeriveCompare; +impl DerivableVisitor for DeriveCompare { + const ABILITY: Symbol = Symbol::LIST_COMPARE; + const ABILITY_SLICE: SubsSlice = Subs::AB_COMPARE; + + #[inline(always)] + fn is_derivable_builtin_opaque(symbol: Symbol) -> bool { + is_builtin_fixed_int_alias(symbol) + || is_builtin_dec_alias(symbol) + || is_builtin_bool_alias(symbol) + } + + #[inline(always)] + fn visit_recursion(_var: Variable) -> Result { + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_apply(var: Variable, symbol: Symbol) -> Result { + if matches!( + symbol, + Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::BOX_BOX_TYPE, + ) { + Ok(Descend(true)) + } else { + Err(NotDerivable { + var, + context: NotDerivableContext::NoContext, + }) + } + } + + #[inline(always)] + fn visit_record( + subs: &Subs, + var: Variable, + fields: RecordFields, + ) -> Result { + for (field_name, _, field) in fields.iter_all() { + if subs[field].is_optional() { + return Err(NotDerivable { + var, + context: NotDerivableContext::DecodeOptionalRecordField( + subs[field_name].clone(), + ), + }); + } + } + + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_tuple( + _subs: &Subs, + _var: Variable, + _elems: TupleElems, + ) -> Result { + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_tag_union(_var: Variable) -> Result { + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_recursive_tag_union(_var: Variable) -> Result { + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_function_or_tag_union(_var: Variable) -> Result { + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_empty_record(_var: Variable) -> Result<(), NotDerivable> { + Ok(()) + } + + #[inline(always)] + fn visit_empty_tag_union(_var: Variable) -> Result<(), NotDerivable> { + Ok(()) + } + + #[inline(always)] + fn visit_alias(_var: Variable, _symbol: Symbol) -> Result { + Ok(Descend(true)) + } + + fn visit_floating_point_content( + var: Variable, + subs: &mut Subs, + content_var: Variable, + ) -> Result { + use roc_unify::unify::unify; + + // Of the floating-point types, + // only Dec implements Eq. + // TODO(checkmate): pass checkmate through + let unified = unify( + &mut with_checkmate!({ + on => UEnv::new(subs, None), + off => UEnv::new(subs), + }), + content_var, + Variable::DECIMAL, + UnificationMode::EQ, + Polarity::Pos, + ); + match unified { + roc_unify::unify::Unified::Success { .. } => Ok(Descend(false)), + roc_unify::unify::Unified::Failure(..) => Err(NotDerivable { + var, + context: NotDerivableContext::Eq(NotDerivableEq::FloatingPoint), + }), + } + } + + #[inline(always)] + fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> { + // Ranged numbers are allowed, because they are always possibly ints - floats can not have + // `isEq` derived, but if something were to be a float, we'd see it exactly as a float. + Ok(()) + } +} /// Determines what type implements an ability member of a specialized signature, given the /// [MustImplementAbility] constraints of the signature. pub fn type_implementing_specialization( diff --git a/crates/compiler/types/src/subs.rs b/crates/compiler/types/src/subs.rs index a14e61a817a..287e68098a8 100644 --- a/crates/compiler/types/src/subs.rs +++ b/crates/compiler/types/src/subs.rs @@ -1701,6 +1701,8 @@ impl Subs { pub const AB_EQ: SubsSlice = SubsSlice::new(4, 1); #[rustfmt::skip] pub const AB_INSPECT: SubsSlice = SubsSlice::new(5, 1); + #[rustfmt::skip] + pub const AB_COMPARE: SubsSlice = SubsSlice::new(6, 1); // END INIT-SymbolSubsSlice pub fn new() -> Self { @@ -1730,6 +1732,7 @@ impl Subs { symbol_names.push(Symbol::HASH_HASH_ABILITY); symbol_names.push(Symbol::BOOL_EQ); symbol_names.push(Symbol::INSPECT_INSPECT_ABILITY); + symbol_names.push(Symbol::LIST_COMPARE); // END INIT-SymbolSubsSlice // IFTTT INIT-VariableSubsSlice From 678d53e649a41b8fe27858710832aa16adb8b827 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sat, 4 May 2024 14:54:27 +0200 Subject: [PATCH 04/12] Add llvm generic_compare implementation for integers This implements the following calculation posted by Richard on Zulip: https://roc.zulipchat.com/#narrow/stream/304641-ideas/topic/ordering.2Fsorting.20ability/near/403858126 I'm using the compare.rs module that implements equality as a guide, some the code has some similarities with that, but also I'm starting with a blank page to not overwhelm myself, which might result in some differences (at least at first). The compare.rs module we might want to rename to eq.rs in a future commit. That would free up the name compare.rs to implement generic_eq. --- crates/compiler/gen_llvm/src/llvm/lowlevel.rs | 14 +- crates/compiler/gen_llvm/src/llvm/mod.rs | 1 + crates/compiler/gen_llvm/src/llvm/sort.rs | 199 ++++++++++++++++++ 3 files changed, 213 insertions(+), 1 deletion(-) create mode 100644 crates/compiler/gen_llvm/src/llvm/sort.rs diff --git a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs index 5015206f917..531c5cef683 100644 --- a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs +++ b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs @@ -55,6 +55,7 @@ use crate::llvm::{ LLVM_SUB_WITH_OVERFLOW, }, refcounting::PointerToRefcount, + sort::generic_compare, }; use super::{build::Env, convert::zig_dec_type}; @@ -1270,7 +1271,18 @@ pub(crate) fn run_low_level<'a, 'ctx>( BasicValueEnum::IntValue(bool_val) } Compare => { - panic!("TODO: implement this") + // Sort.compare : elem, elem -> [LessThan, Equal, GreaterThan] + arguments_with_layouts!((lhs_arg, lhs_layout), (rhs_arg, rhs_layout)); + + generic_compare( + env, + layout_interner, + layout_ids, + lhs_arg, + rhs_arg, + lhs_layout, + rhs_layout, + ) } Hash => { unimplemented!() diff --git a/crates/compiler/gen_llvm/src/llvm/mod.rs b/crates/compiler/gen_llvm/src/llvm/mod.rs index d606191338c..3d009a2abd5 100644 --- a/crates/compiler/gen_llvm/src/llvm/mod.rs +++ b/crates/compiler/gen_llvm/src/llvm/mod.rs @@ -9,6 +9,7 @@ pub mod externs; mod intrinsics; mod lowlevel; pub mod refcounting; +pub mod sort; mod align; mod erased; diff --git a/crates/compiler/gen_llvm/src/llvm/sort.rs b/crates/compiler/gen_llvm/src/llvm/sort.rs new file mode 100644 index 00000000000..0484a75d43b --- /dev/null +++ b/crates/compiler/gen_llvm/src/llvm/sort.rs @@ -0,0 +1,199 @@ +use super::build::BuilderExt; +use crate::llvm::build::Env; +use inkwell::values::{BasicValueEnum, IntValue}; +use inkwell::IntPredicate; +use roc_builtins::bitcode::IntWidth; +use roc_mono::layout::{ + Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner, +}; + +pub fn generic_compare<'a, 'ctx>( + env: &Env<'a, 'ctx, '_>, + layout_interner: &STLayoutInterner<'a>, + _layout_ids: &mut LayoutIds<'a>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + lhs_layout: InLayout<'a>, + _rhs_layout: InLayout<'a>, +) -> BasicValueEnum<'ctx> { + let lhs_repr = layout_interner.get_repr(lhs_layout); + let result = match lhs_repr { + LayoutRepr::Builtin(Builtin::Int(int_width)) => { + int_compare(env, lhs_val, rhs_val, int_width) + } + LayoutRepr::Builtin(Builtin::Float(_)) => todo!(), + LayoutRepr::Builtin(Builtin::Bool) => todo!(), + LayoutRepr::Builtin(Builtin::Decimal) => todo!(), + LayoutRepr::Builtin(Builtin::Str) => todo!(), + LayoutRepr::Builtin(Builtin::List(_)) => todo!(), + LayoutRepr::Struct(_) => todo!(), + LayoutRepr::LambdaSet(_) => unreachable!("cannot compare closures"), + LayoutRepr::FunctionPointer(_) => unreachable!("cannot compare function pointers"), + LayoutRepr::Erased(_) => unreachable!("cannot compare erased types"), + LayoutRepr::Union(_) => todo!(), + LayoutRepr::Ptr(_) => todo!(), + LayoutRepr::RecursivePointer(_) => todo!(), + }; + BasicValueEnum::IntValue(result) +} + +fn int_compare<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: IntWidth, +) -> IntValue<'ctx> { + // The following calculation will return 0 for equals, 1 for greater than, + // and 2 for less than. + // (a > b) + 2 * (a < b); + let lhs_gt_rhs = int_gt(env, lhs_val, rhs_val, builtin); + let lhs_lt_rhs = int_lt(env, lhs_val, rhs_val, builtin); + let two = env.ptr_int().const_int(2, false); + let lhs_lt_rhs_times_two = + env.builder + .new_build_int_mul(lhs_lt_rhs, two, "lhs_lt_rhs_times_two"); + env.builder + .new_build_int_sub(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare") +} + +fn int_lt<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: IntWidth, +) -> IntValue<'ctx> { + use IntWidth::*; + match builtin { + I128 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i28", + ), + I64 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i64", + ), + I32 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i32", + ), + I16 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i16", + ), + I8 => env.builder.new_build_int_compare( + IntPredicate::SLT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_i8", + ), + U128 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u128", + ), + U64 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u64", + ), + U32 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u32", + ), + U16 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u16", + ), + U8 => env.builder.new_build_int_compare( + IntPredicate::ULT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_lt_rhs_u8", + ), + } +} + +fn int_gt<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + builtin: IntWidth, +) -> IntValue<'ctx> { + use IntWidth::*; + match builtin { + I128 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i28", + ), + I64 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i64", + ), + I32 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i32", + ), + I16 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i16", + ), + I8 => env.builder.new_build_int_compare( + IntPredicate::SGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_i8", + ), + U128 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u128", + ), + U64 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u64", + ), + U32 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u32", + ), + U16 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u16", + ), + U8 => env.builder.new_build_int_compare( + IntPredicate::UGT, + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "lhs_gt_rhs_u8", + ), + } +} From e8039c79c0f767ce6f839eed16a8d6c9fd05538f Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sat, 18 May 2024 13:42:50 +0200 Subject: [PATCH 05/12] 'Fix' llvm test gen Not sure why, but if I don't choose 108 as the ident id I get an error that structureCompare is missing an implementation. --- crates/compiler/module/src/symbol.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/compiler/module/src/symbol.rs b/crates/compiler/module/src/symbol.rs index f73633a8dae..5761608bf37 100644 --- a/crates/compiler/module/src/symbol.rs +++ b/crates/compiler/module/src/symbol.rs @@ -1534,7 +1534,7 @@ define_builtins! { 88 LIST_LEN_USIZE: "lenUsize" 89 LIST_SORT: "Sort" exposed_type=true 90 LIST_COMPARE: "compare" - unexposed 91 LIST_STRUCTURAL_COMPARE: "structuralCompare" + unexposed 108 LIST_STRUCTURAL_COMPARE: "structuralCompare" } 7 RESULT: "Result" => { 0 RESULT_RESULT: "Result" exposed_type=true // the Result.Result type alias From fbceda6e2051cb2e14ecc2f66b64ffa6e7842c6e Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 16 Jun 2024 11:43:21 +0200 Subject: [PATCH 06/12] Add (failing) test for Sort code generation --- crates/compiler/test_gen/src/gen_compare.rs | 33 +++++++++++++++++++++ crates/compiler/test_gen/src/tests.rs | 1 + 2 files changed, 34 insertions(+) create mode 100644 crates/compiler/test_gen/src/gen_compare.rs diff --git a/crates/compiler/test_gen/src/gen_compare.rs b/crates/compiler/test_gen/src/gen_compare.rs new file mode 100644 index 00000000000..c528b49921a --- /dev/null +++ b/crates/compiler/test_gen/src/gen_compare.rs @@ -0,0 +1,33 @@ +#[cfg(feature = "gen-llvm")] +use crate::helpers::llvm::assert_evals_to; + +#[cfg(feature = "gen-dev")] +use crate::helpers::dev::assert_evals_to; + +#[cfg(feature = "gen-wasm")] +use crate::helpers::wasm::assert_evals_to; + +use indoc::indoc; + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))] +fn compare_i64() { + assert_evals_to!( + indoc!( + r#" + i : I64 + i = 1 + + j : I64 + j = 2 + + when List.compare i j is + Equals -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 2, + i64 + ); +} diff --git a/crates/compiler/test_gen/src/tests.rs b/crates/compiler/test_gen/src/tests.rs index 3338b02447d..55e48a5a748 100644 --- a/crates/compiler/test_gen/src/tests.rs +++ b/crates/compiler/test_gen/src/tests.rs @@ -5,6 +5,7 @@ #![allow(clippy::float_cmp)] pub mod gen_abilities; +pub mod gen_compare; pub mod gen_definitions; pub mod gen_dict; pub mod gen_equality; From c8d348dd0fe1ca8e02133ec9ae7076c3aa879304 Mon Sep 17 00:00:00 2001 From: Ben Plotke Date: Thu, 11 Jul 2024 05:16:50 -0700 Subject: [PATCH 07/12] Added LIST_SORT symbol --- crates/compiler/solve/src/ability.rs | 54 +++++++++++++++------------- crates/compiler/types/src/subs.rs | 2 ++ 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/crates/compiler/solve/src/ability.rs b/crates/compiler/solve/src/ability.rs index 4c08366e381..1fb1316d534 100644 --- a/crates/compiler/solve/src/ability.rs +++ b/crates/compiler/solve/src/ability.rs @@ -317,6 +317,8 @@ impl ObligationCache { var, )), + Symbol::LIST_SORT => Some(DeriveSort::is_derivable(self, abilities_store, subs, var)), + _ => None, }; @@ -399,6 +401,8 @@ impl ObligationCache { DeriveEq::ABILITY => DeriveEq::is_derivable_builtin_opaque(opaque), DeriveHash::ABILITY => DeriveHash::is_derivable_builtin_opaque(opaque), DeriveInspect::ABILITY => DeriveInspect::is_derivable_builtin_opaque(opaque), + DeriveSort::ABILITY => DeriveSort::is_derivable_builtin_opaque(opaque), + DeriveCompare::ABILITY => DeriveCompare::is_derivable_builtin_opaque(opaque), _ => false, }; @@ -858,6 +862,26 @@ trait DerivableVisitor { } } +struct DeriveSort; +impl DerivableVisitor for DeriveSort { + const ABILITY: Symbol = Symbol::LIST_SORT; + const ABILITY_SLICE: SubsSlice = Subs::AB_SORT; + + #[inline(always)] + fn is_derivable_builtin_opaque(_symbol: Symbol) -> bool { + true + } + + #[inline(always)] + fn visit_floating_point_content( + _var: Variable, + _subs: &mut Subs, + _content_var: Variable, + ) -> Result { + Ok(Descend(true)) + } +} + struct DeriveInspect; impl DerivableVisitor for DeriveInspect { const ABILITY: Symbol = Symbol::INSPECT_INSPECT_ABILITY; @@ -1482,33 +1506,13 @@ impl DerivableVisitor for DeriveCompare { Ok(Descend(true)) } + #[inline(always)] fn visit_floating_point_content( - var: Variable, - subs: &mut Subs, - content_var: Variable, + _var: Variable, + _subs: &mut Subs, + _content_var: Variable, ) -> Result { - use roc_unify::unify::unify; - - // Of the floating-point types, - // only Dec implements Eq. - // TODO(checkmate): pass checkmate through - let unified = unify( - &mut with_checkmate!({ - on => UEnv::new(subs, None), - off => UEnv::new(subs), - }), - content_var, - Variable::DECIMAL, - UnificationMode::EQ, - Polarity::Pos, - ); - match unified { - roc_unify::unify::Unified::Success { .. } => Ok(Descend(false)), - roc_unify::unify::Unified::Failure(..) => Err(NotDerivable { - var, - context: NotDerivableContext::Eq(NotDerivableEq::FloatingPoint), - }), - } + Ok(Descend(true)) } #[inline(always)] diff --git a/crates/compiler/types/src/subs.rs b/crates/compiler/types/src/subs.rs index 287e68098a8..d76199a0bae 100644 --- a/crates/compiler/types/src/subs.rs +++ b/crates/compiler/types/src/subs.rs @@ -1703,6 +1703,8 @@ impl Subs { pub const AB_INSPECT: SubsSlice = SubsSlice::new(5, 1); #[rustfmt::skip] pub const AB_COMPARE: SubsSlice = SubsSlice::new(6, 1); + #[rustfmt::skip] + pub const AB_SORT: SubsSlice = SubsSlice::new(7, 1); // END INIT-SymbolSubsSlice pub fn new() -> Self { From be1d2ed5fe23297bdb569e5c40e16ba884ba6798 Mon Sep 17 00:00:00 2001 From: Ben Plotke Date: Thu, 11 Jul 2024 05:17:20 -0700 Subject: [PATCH 08/12] implemented sort for float and bool --- crates/compiler/gen_llvm/src/llvm/sort.rs | 101 +++++++++++++++++++--- 1 file changed, 91 insertions(+), 10 deletions(-) diff --git a/crates/compiler/gen_llvm/src/llvm/sort.rs b/crates/compiler/gen_llvm/src/llvm/sort.rs index 0484a75d43b..71ac35a5cbb 100644 --- a/crates/compiler/gen_llvm/src/llvm/sort.rs +++ b/crates/compiler/gen_llvm/src/llvm/sort.rs @@ -1,8 +1,8 @@ use super::build::BuilderExt; use crate::llvm::build::Env; use inkwell::values::{BasicValueEnum, IntValue}; -use inkwell::IntPredicate; -use roc_builtins::bitcode::IntWidth; +use inkwell::{IntPredicate, FloatPredicate}; +use roc_builtins::bitcode::{IntWidth, FloatWidth}; use roc_mono::layout::{ Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner, }; @@ -21,8 +21,12 @@ pub fn generic_compare<'a, 'ctx>( LayoutRepr::Builtin(Builtin::Int(int_width)) => { int_compare(env, lhs_val, rhs_val, int_width) } - LayoutRepr::Builtin(Builtin::Float(_)) => todo!(), - LayoutRepr::Builtin(Builtin::Bool) => todo!(), + LayoutRepr::Builtin(Builtin::Float(float_width)) => { + float_cmp(env, lhs_val, rhs_val, float_width) + } + LayoutRepr::Builtin(Builtin::Bool) => { + bool_compare(env, lhs_val, rhs_val) + } LayoutRepr::Builtin(Builtin::Decimal) => todo!(), LayoutRepr::Builtin(Builtin::Str) => todo!(), LayoutRepr::Builtin(Builtin::List(_)) => todo!(), @@ -48,14 +52,16 @@ fn int_compare<'ctx>( // (a > b) + 2 * (a < b); let lhs_gt_rhs = int_gt(env, lhs_val, rhs_val, builtin); let lhs_lt_rhs = int_lt(env, lhs_val, rhs_val, builtin); - let two = env.ptr_int().const_int(2, false); + let two = env.context.i8_type().const_int(2, false); let lhs_lt_rhs_times_two = env.builder .new_build_int_mul(lhs_lt_rhs, two, "lhs_lt_rhs_times_two"); env.builder - .new_build_int_sub(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare") + .new_build_int_add(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare") } + + fn int_lt<'ctx>( env: &Env<'_, 'ctx, '_>, lhs_val: BasicValueEnum<'ctx>, @@ -63,7 +69,7 @@ fn int_lt<'ctx>( builtin: IntWidth, ) -> IntValue<'ctx> { use IntWidth::*; - match builtin { + let lhs_lt_rhs = match builtin { I128 => env.builder.new_build_int_compare( IntPredicate::SLT, lhs_val.into_int_value(), @@ -124,7 +130,8 @@ fn int_lt<'ctx>( rhs_val.into_int_value(), "lhs_lt_rhs_u8", ), - } + }; + env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_cast") } fn int_gt<'ctx>( @@ -134,7 +141,7 @@ fn int_gt<'ctx>( builtin: IntWidth, ) -> IntValue<'ctx> { use IntWidth::*; - match builtin { + let lhs_gt_rhs = match builtin { I128 => env.builder.new_build_int_compare( IntPredicate::SGT, lhs_val.into_int_value(), @@ -195,5 +202,79 @@ fn int_gt<'ctx>( rhs_val.into_int_value(), "lhs_gt_rhs_u8", ), - } + }; + env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_cast") +} + + +// Return 0 for equals, 1 for greater than, and 2 for less than. +// We consider NaNs to be smaller than non-NaNs +// We use the below expression to calculate this +// (a == a) + 2*(b == b) - (a < b) - 2*(a > b) - 3*(a == b) +fn float_cmp<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, + float_width: FloatWidth, +) -> IntValue<'ctx> { + use FloatWidth::*; + let type_str = match float_width { + F64 => "F64", + F32 => "F32", + }; + + let make_cmp = |operation, a: BasicValueEnum<'ctx>, b: BasicValueEnum<'ctx>, op_name: &str| { + let full_op_name = format!("{}_{}", op_name, type_str); + let bool_result = env.builder.new_build_float_compare( + operation, + a.into_float_value(), + b.into_float_value(), + &full_op_name, + ); + env.builder.new_build_int_cast_sign_flag(bool_result, env.context.i8_type(), false, &format!("{}_cast", full_op_name)) + }; + + let two = env.context.i8_type().const_int(2, false); + let three = env.context.i8_type().const_int(3, false); + + let lt_test = make_cmp(FloatPredicate::OLT, lhs_val, rhs_val, "rhs_lt_lhs"); + let gt_test = make_cmp(FloatPredicate::OGT, lhs_val, rhs_val, "lhs_gt_rhs"); + let eq_test = make_cmp(FloatPredicate::OEQ, lhs_val, rhs_val, "lhs_eq_rhs"); + let lhs_not_nan_test = make_cmp(FloatPredicate::OEQ, lhs_val, lhs_val, "lhs_not_NaN"); + let rhs_not_nan_test = make_cmp(FloatPredicate::OEQ, rhs_val, rhs_val, "rhs_not_NaN"); + + let rhs_not_nan_scaled = env.builder.new_build_int_mul(two, rhs_not_nan_test, "2 * rhs_not_nan"); + let gt_scaled = env.builder.new_build_int_mul(two, gt_test, "2 * lhs_gt_rhs"); + let eq_scaled = env.builder.new_build_int_mul(three, eq_test, "3 * lhs_eq_rhs"); + + let non_nans = env.builder.new_build_int_add(lhs_not_nan_test, rhs_not_nan_scaled, "(a == a) + 2*(b == b))"); + let minus_lt = env.builder.new_build_int_sub(non_nans, lt_test, "(a == a) + 2*(b == b) - (a < b"); + let minus_gt = env.builder.new_build_int_sub(minus_lt, gt_scaled, "(a == a) + 2*(b == b) - (a < b) - 2*(a > b)"); + env.builder.new_build_int_sub(minus_gt, eq_scaled, "float_compare") } + +// 1 1 0 +// 0 0 0 +// 0 1 1 +// 1 0 2 +fn bool_compare<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, +) -> IntValue<'ctx> { + + // (a < b) + let lhs_lt_rhs = env.builder.new_build_int_compare(IntPredicate::SLT, lhs_val.into_int_value(), rhs_val.into_int_value(), "lhs_lt_rhs_bool"); + let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte"); + + // (a > b) + let lhs_gt_rhs = env.builder.new_build_int_compare(IntPredicate::SGT, lhs_val.into_int_value(), rhs_val.into_int_value(), "lhs_gt_rhs_bool"); + let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte"); + + // (a > b) * 2 + let two = env.context.i8_type().const_int(2, false); + let lhs_gt_rhs_times_two = env.builder.new_build_int_mul(lhs_gt_rhs_byte, two, "lhs_gt_rhs_times_two"); + + // (a < b) + (a > b) * 2 + env.builder.new_build_int_add(lhs_lt_rhs_byte, lhs_gt_rhs_times_two, "bool_compare") +} \ No newline at end of file From 6a10d8ae2ef1c0b69ea22d44c7f4918df3a31d0b Mon Sep 17 00:00:00 2001 From: Ben Plotke Date: Fri, 12 Jul 2024 21:27:19 -0700 Subject: [PATCH 09/12] Implemented Dec compare --- crates/compiler/gen_llvm/src/llvm/sort.rs | 48 ++++++++++++++++++----- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/crates/compiler/gen_llvm/src/llvm/sort.rs b/crates/compiler/gen_llvm/src/llvm/sort.rs index 71ac35a5cbb..f0cf4c5d1dc 100644 --- a/crates/compiler/gen_llvm/src/llvm/sort.rs +++ b/crates/compiler/gen_llvm/src/llvm/sort.rs @@ -1,8 +1,9 @@ use super::build::BuilderExt; use crate::llvm::build::Env; +use crate::llvm::bitcode::call_bitcode_fn; use inkwell::values::{BasicValueEnum, IntValue}; use inkwell::{IntPredicate, FloatPredicate}; -use roc_builtins::bitcode::{IntWidth, FloatWidth}; +use roc_builtins::bitcode::{IntWidth, FloatWidth, NUM_LESS_THAN, NUM_GREATER_THAN}; use roc_mono::layout::{ Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner, }; @@ -27,7 +28,9 @@ pub fn generic_compare<'a, 'ctx>( LayoutRepr::Builtin(Builtin::Bool) => { bool_compare(env, lhs_val, rhs_val) } - LayoutRepr::Builtin(Builtin::Decimal) => todo!(), + LayoutRepr::Builtin(Builtin::Decimal) => { + dec_compare(env, lhs_val, rhs_val) + } LayoutRepr::Builtin(Builtin::Str) => todo!(), LayoutRepr::Builtin(Builtin::List(_)) => todo!(), LayoutRepr::Struct(_) => todo!(), @@ -237,7 +240,7 @@ fn float_cmp<'ctx>( let two = env.context.i8_type().const_int(2, false); let three = env.context.i8_type().const_int(3, false); - let lt_test = make_cmp(FloatPredicate::OLT, lhs_val, rhs_val, "rhs_lt_lhs"); + let lt_test = make_cmp(FloatPredicate::OLT, lhs_val, rhs_val, "lhs_lt_rhs"); let gt_test = make_cmp(FloatPredicate::OGT, lhs_val, rhs_val, "lhs_gt_rhs"); let eq_test = make_cmp(FloatPredicate::OEQ, lhs_val, rhs_val, "lhs_eq_rhs"); let lhs_not_nan_test = make_cmp(FloatPredicate::OEQ, lhs_val, lhs_val, "lhs_not_NaN"); @@ -263,18 +266,45 @@ fn bool_compare<'ctx>( rhs_val: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { + // Cast the input bools to ints because int comparison of bools does the opposite of what one would expect. + // I could just swap the arguments, but I do not want to rely on behavior which seems wrong + let lhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_val.into_int_value(), env.context.i8_type(), false, "lhs_byte"); + let rhs_byte = env.builder.new_build_int_cast_sign_flag(rhs_val.into_int_value(), env.context.i8_type(), false, "rhs_byte"); + // (a < b) - let lhs_lt_rhs = env.builder.new_build_int_compare(IntPredicate::SLT, lhs_val.into_int_value(), rhs_val.into_int_value(), "lhs_lt_rhs_bool"); + let lhs_lt_rhs = env.builder.new_build_int_compare(IntPredicate::SLT, lhs_byte, rhs_byte, "lhs_lt_rhs_bool"); let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte"); // (a > b) - let lhs_gt_rhs = env.builder.new_build_int_compare(IntPredicate::SGT, lhs_val.into_int_value(), rhs_val.into_int_value(), "lhs_gt_rhs_bool"); + let lhs_gt_rhs = env.builder.new_build_int_compare(IntPredicate::SGT, lhs_byte, rhs_byte, "lhs_gt_rhs_bool"); let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte"); - // (a > b) * 2 + // (a < b) * 2 + let two = env.context.i8_type().const_int(2, false); + let lhs_lt_rhs_times_two = env.builder.new_build_int_mul(lhs_lt_rhs_byte, two, "lhs_lt_rhs_times_two"); + + // (a > b) + (a < b) * 2 + env.builder.new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare") +} + +fn dec_compare<'ctx>( + env: &Env<'_, 'ctx, '_>, + lhs_val: BasicValueEnum<'ctx>, + rhs_val: BasicValueEnum<'ctx>, +) -> IntValue<'ctx> { + + // (a > b) + let lhs_gt_rhs = call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_GREATER_THAN[IntWidth::I128]).into_int_value(); + let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte"); + + // (a < b) + let lhs_lt_rhs = call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_LESS_THAN[IntWidth::I128]).into_int_value(); + let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte"); + + // (a < b) * 2 let two = env.context.i8_type().const_int(2, false); - let lhs_gt_rhs_times_two = env.builder.new_build_int_mul(lhs_gt_rhs_byte, two, "lhs_gt_rhs_times_two"); + let lhs_lt_rhs_times_two = env.builder.new_build_int_mul(lhs_lt_rhs_byte, two, "lhs_gt_rhs_times_two"); - // (a < b) + (a > b) * 2 - env.builder.new_build_int_add(lhs_lt_rhs_byte, lhs_gt_rhs_times_two, "bool_compare") + // (a > b) + (a < b) * 2 + env.builder.new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare") } \ No newline at end of file From 75f091d308d79a293729acff9382c44c148792e8 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sat, 13 Jul 2024 11:42:36 +0200 Subject: [PATCH 10/12] Add tests for Sort/compare ability We had a single test before, but it was failing due to a typo. This fixes the test, plus adds some more for the new sortable types we've added since. --- crates/compiler/test_gen/src/gen_compare.rs | 219 +++++++++++++++++++- 1 file changed, 216 insertions(+), 3 deletions(-) diff --git a/crates/compiler/test_gen/src/gen_compare.rs b/crates/compiler/test_gen/src/gen_compare.rs index c528b49921a..e3c22973e57 100644 --- a/crates/compiler/test_gen/src/gen_compare.rs +++ b/crates/compiler/test_gen/src/gen_compare.rs @@ -9,9 +9,86 @@ use crate::helpers::wasm::assert_evals_to; use indoc::indoc; +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))] +fn compare_u64() { + assert_evals_to!( + indoc!( + r#" + i : U64 + i = 1 + + j : U64 + j = 2 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 2, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : U64 + i = 2 + + j : U64 + j = 1 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : U64 + i = 1 + + j : U64 + j = 1 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 0, + u8 + ); +} + #[test] #[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))] fn compare_i64() { + assert_evals_to!( + indoc!( + r#" + i : I64 + i = -1 + + j : I64 + j = 1 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 2, + u8 + ); assert_evals_to!( indoc!( r#" @@ -19,15 +96,151 @@ fn compare_i64() { i = 1 j : I64 - j = 2 + j = -1 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : I64 + i = -1 + + j : I64 + j = -1 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 0, + u8 + ); +} + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))] +fn compare_f64() { + assert_evals_to!( + indoc!( + r#" + i : F64 + i = 1.2 + + j : F64 + j = 1.3 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 2, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : F64 + i = 1.3 + + j : F64 + j = 1.2 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : F64 + i = 1.2 + + j : F64 + j = 1.2 + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 0, + u8 + ); +} + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))] +fn compare_bool() { + assert_evals_to!( + indoc!( + r#" + i : Bool + i = Bool.false + + j : Bool + j = Bool.true when List.compare i j is - Equals -> 0 + Equal -> 0 GreaterThan -> 1 LessThan -> 2 "# ), 2, - i64 + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : Bool + i = Bool.true + + j : Bool + j = Bool.false + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : Bool + i = Bool.false + + j : Bool + j = Bool.false + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 0, + u8 ); } From 43104c3c422900d43b65847822265b2eafca7b4e Mon Sep 17 00:00:00 2001 From: Ben Plotke Date: Tue, 16 Jul 2024 08:53:13 -0700 Subject: [PATCH 11/12] Rewrote boolean sort to be simpler --- crates/compiler/gen_llvm/src/llvm/sort.rs | 153 ++++++++++++++-------- 1 file changed, 101 insertions(+), 52 deletions(-) diff --git a/crates/compiler/gen_llvm/src/llvm/sort.rs b/crates/compiler/gen_llvm/src/llvm/sort.rs index f0cf4c5d1dc..d3ab46685f6 100644 --- a/crates/compiler/gen_llvm/src/llvm/sort.rs +++ b/crates/compiler/gen_llvm/src/llvm/sort.rs @@ -1,9 +1,9 @@ use super::build::BuilderExt; -use crate::llvm::build::Env; use crate::llvm::bitcode::call_bitcode_fn; +use crate::llvm::build::Env; use inkwell::values::{BasicValueEnum, IntValue}; -use inkwell::{IntPredicate, FloatPredicate}; -use roc_builtins::bitcode::{IntWidth, FloatWidth, NUM_LESS_THAN, NUM_GREATER_THAN}; +use inkwell::{FloatPredicate, IntPredicate}; +use roc_builtins::bitcode::{FloatWidth, IntWidth, NUM_GREATER_THAN, NUM_LESS_THAN}; use roc_mono::layout::{ Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner, }; @@ -25,12 +25,8 @@ pub fn generic_compare<'a, 'ctx>( LayoutRepr::Builtin(Builtin::Float(float_width)) => { float_cmp(env, lhs_val, rhs_val, float_width) } - LayoutRepr::Builtin(Builtin::Bool) => { - bool_compare(env, lhs_val, rhs_val) - } - LayoutRepr::Builtin(Builtin::Decimal) => { - dec_compare(env, lhs_val, rhs_val) - } + LayoutRepr::Builtin(Builtin::Bool) => bool_compare(env, lhs_val, rhs_val), + LayoutRepr::Builtin(Builtin::Decimal) => dec_compare(env, lhs_val, rhs_val), LayoutRepr::Builtin(Builtin::Str) => todo!(), LayoutRepr::Builtin(Builtin::List(_)) => todo!(), LayoutRepr::Struct(_) => todo!(), @@ -63,8 +59,6 @@ fn int_compare<'ctx>( .new_build_int_add(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare") } - - fn int_lt<'ctx>( env: &Env<'_, 'ctx, '_>, lhs_val: BasicValueEnum<'ctx>, @@ -134,7 +128,12 @@ fn int_lt<'ctx>( "lhs_lt_rhs_u8", ), }; - env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_cast") + env.builder.new_build_int_cast_sign_flag( + lhs_lt_rhs, + env.context.i8_type(), + false, + "lhs_lt_rhs_cast", + ) } fn int_gt<'ctx>( @@ -206,10 +205,14 @@ fn int_gt<'ctx>( "lhs_gt_rhs_u8", ), }; - env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_cast") + env.builder.new_build_int_cast_sign_flag( + lhs_gt_rhs, + env.context.i8_type(), + false, + "lhs_gt_rhs_cast", + ) } - // Return 0 for equals, 1 for greater than, and 2 for less than. // We consider NaNs to be smaller than non-NaNs // We use the below expression to calculate this @@ -234,57 +237,89 @@ fn float_cmp<'ctx>( b.into_float_value(), &full_op_name, ); - env.builder.new_build_int_cast_sign_flag(bool_result, env.context.i8_type(), false, &format!("{}_cast", full_op_name)) + env.builder.new_build_int_cast_sign_flag( + bool_result, + env.context.i8_type(), + false, + &format!("{}_cast", full_op_name), + ) }; let two = env.context.i8_type().const_int(2, false); let three = env.context.i8_type().const_int(3, false); - let lt_test = make_cmp(FloatPredicate::OLT, lhs_val, rhs_val, "lhs_lt_rhs"); + let lt_test = make_cmp(FloatPredicate::OLT, lhs_val, rhs_val, "lhs_lt_rhs"); let gt_test = make_cmp(FloatPredicate::OGT, lhs_val, rhs_val, "lhs_gt_rhs"); let eq_test = make_cmp(FloatPredicate::OEQ, lhs_val, rhs_val, "lhs_eq_rhs"); let lhs_not_nan_test = make_cmp(FloatPredicate::OEQ, lhs_val, lhs_val, "lhs_not_NaN"); let rhs_not_nan_test = make_cmp(FloatPredicate::OEQ, rhs_val, rhs_val, "rhs_not_NaN"); - let rhs_not_nan_scaled = env.builder.new_build_int_mul(two, rhs_not_nan_test, "2 * rhs_not_nan"); - let gt_scaled = env.builder.new_build_int_mul(two, gt_test, "2 * lhs_gt_rhs"); - let eq_scaled = env.builder.new_build_int_mul(three, eq_test, "3 * lhs_eq_rhs"); + let rhs_not_nan_scaled = + env.builder + .new_build_int_mul(two, rhs_not_nan_test, "2 * rhs_not_nan"); + let gt_scaled = env + .builder + .new_build_int_mul(two, gt_test, "2 * lhs_gt_rhs"); + let eq_scaled = env + .builder + .new_build_int_mul(three, eq_test, "3 * lhs_eq_rhs"); - let non_nans = env.builder.new_build_int_add(lhs_not_nan_test, rhs_not_nan_scaled, "(a == a) + 2*(b == b))"); - let minus_lt = env.builder.new_build_int_sub(non_nans, lt_test, "(a == a) + 2*(b == b) - (a < b"); - let minus_gt = env.builder.new_build_int_sub(minus_lt, gt_scaled, "(a == a) + 2*(b == b) - (a < b) - 2*(a > b)"); - env.builder.new_build_int_sub(minus_gt, eq_scaled, "float_compare") + let non_nans = env.builder.new_build_int_add( + lhs_not_nan_test, + rhs_not_nan_scaled, + "(a == a) + 2*(b == b))", + ); + let minus_lt = + env.builder + .new_build_int_sub(non_nans, lt_test, "(a == a) + 2*(b == b) - (a < b"); + let minus_gt = env.builder.new_build_int_sub( + minus_lt, + gt_scaled, + "(a == a) + 2*(b == b) - (a < b) - 2*(a > b)", + ); + env.builder + .new_build_int_sub(minus_gt, eq_scaled, "float_compare") } // 1 1 0 // 0 0 0 -// 0 1 1 -// 1 0 2 +// 1 0 1 +// 0 1 2 fn bool_compare<'ctx>( env: &Env<'_, 'ctx, '_>, lhs_val: BasicValueEnum<'ctx>, rhs_val: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { + // a xor b + let rhs_not_equal_lhs = env.builder.new_build_xor( + lhs_val.into_int_value(), + rhs_val.into_int_value(), + "rhs_not_equal_lhs", + ); + let rhs_not_equal_lhs_byte = env.builder.new_build_int_cast_sign_flag( + rhs_not_equal_lhs, + env.context.i8_type(), + false, + "rhs_not_equal_lhs_byte", + ); - // Cast the input bools to ints because int comparison of bools does the opposite of what one would expect. - // I could just swap the arguments, but I do not want to rely on behavior which seems wrong - let lhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_val.into_int_value(), env.context.i8_type(), false, "lhs_byte"); - let rhs_byte = env.builder.new_build_int_cast_sign_flag(rhs_val.into_int_value(), env.context.i8_type(), false, "rhs_byte"); - - // (a < b) - let lhs_lt_rhs = env.builder.new_build_int_compare(IntPredicate::SLT, lhs_byte, rhs_byte, "lhs_lt_rhs_bool"); - let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte"); - - // (a > b) - let lhs_gt_rhs = env.builder.new_build_int_compare(IntPredicate::SGT, lhs_byte, rhs_byte, "lhs_gt_rhs_bool"); - let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte"); - - // (a < b) * 2 - let two = env.context.i8_type().const_int(2, false); - let lhs_lt_rhs_times_two = env.builder.new_build_int_mul(lhs_lt_rhs_byte, two, "lhs_lt_rhs_times_two"); + // b & !a + let not_lhs = env + .builder + .new_build_not(lhs_val.into_int_value(), "not_lhs"); + let rhs_only = env + .builder + .new_build_and(rhs_val.into_int_value(), not_lhs, "rhs_only"); + let rhs_only_byte = env.builder.new_build_int_cast_sign_flag( + rhs_only, + env.context.i8_type(), + false, + "rhs_only_byte", + ); - // (a > b) + (a < b) * 2 - env.builder.new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare") + // (a xor b) + (b & !a) + env.builder + .new_build_int_add(rhs_not_equal_lhs_byte, rhs_only_byte, "bool_compare") } fn dec_compare<'ctx>( @@ -292,19 +327,33 @@ fn dec_compare<'ctx>( lhs_val: BasicValueEnum<'ctx>, rhs_val: BasicValueEnum<'ctx>, ) -> IntValue<'ctx> { - // (a > b) - let lhs_gt_rhs = call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_GREATER_THAN[IntWidth::I128]).into_int_value(); - let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte"); - + let lhs_gt_rhs = call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_GREATER_THAN[IntWidth::I128]) + .into_int_value(); + let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag( + lhs_gt_rhs, + env.context.i8_type(), + false, + "lhs_gt_rhs_byte", + ); + // (a < b) - let lhs_lt_rhs = call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_LESS_THAN[IntWidth::I128]).into_int_value(); - let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte"); + let lhs_lt_rhs = + call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_LESS_THAN[IntWidth::I128]).into_int_value(); + let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag( + lhs_lt_rhs, + env.context.i8_type(), + false, + "lhs_lt_rhs_byte", + ); // (a < b) * 2 let two = env.context.i8_type().const_int(2, false); - let lhs_lt_rhs_times_two = env.builder.new_build_int_mul(lhs_lt_rhs_byte, two, "lhs_gt_rhs_times_two"); + let lhs_lt_rhs_times_two = + env.builder + .new_build_int_mul(lhs_lt_rhs_byte, two, "lhs_gt_rhs_times_two"); // (a > b) + (a < b) * 2 - env.builder.new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare") -} \ No newline at end of file + env.builder + .new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare") +} From c64a5b84f65aec2d15bb1c1380a96af650af4eaa Mon Sep 17 00:00:00 2001 From: Ben Plotke Date: Fri, 19 Jul 2024 03:53:19 -0700 Subject: [PATCH 12/12] Implemented sort compare for lists --- crates/compiler/gen_llvm/src/llvm/sort.rs | 258 +++++++++++++++++++- crates/compiler/solve/src/ability.rs | 12 + crates/compiler/test_gen/src/gen_compare.rs | 167 +++++++++++++ 3 files changed, 432 insertions(+), 5 deletions(-) diff --git a/crates/compiler/gen_llvm/src/llvm/sort.rs b/crates/compiler/gen_llvm/src/llvm/sort.rs index d3ab46685f6..ca3e8a9ab08 100644 --- a/crates/compiler/gen_llvm/src/llvm/sort.rs +++ b/crates/compiler/gen_llvm/src/llvm/sort.rs @@ -1,9 +1,13 @@ use super::build::BuilderExt; use crate::llvm::bitcode::call_bitcode_fn; -use crate::llvm::build::Env; -use inkwell::values::{BasicValueEnum, IntValue}; -use inkwell::{FloatPredicate, IntPredicate}; +use crate::llvm::build::{load_roc_value, Env, FAST_CALL_CONV}; +use crate::llvm::build_list::{list_len_usize, load_list_ptr}; +use crate::llvm::convert::{basic_type_from_layout, zig_list_type}; +use inkwell::types::BasicType; +use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, StructValue}; +use inkwell::{AddressSpace, FloatPredicate, IntPredicate}; use roc_builtins::bitcode::{FloatWidth, IntWidth, NUM_GREATER_THAN, NUM_LESS_THAN}; +use roc_module::symbol::Symbol; use roc_mono::layout::{ Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner, }; @@ -11,7 +15,7 @@ use roc_mono::layout::{ pub fn generic_compare<'a, 'ctx>( env: &Env<'a, 'ctx, '_>, layout_interner: &STLayoutInterner<'a>, - _layout_ids: &mut LayoutIds<'a>, + layout_ids: &mut LayoutIds<'a>, lhs_val: BasicValueEnum<'ctx>, rhs_val: BasicValueEnum<'ctx>, lhs_layout: InLayout<'a>, @@ -28,7 +32,15 @@ pub fn generic_compare<'a, 'ctx>( LayoutRepr::Builtin(Builtin::Bool) => bool_compare(env, lhs_val, rhs_val), LayoutRepr::Builtin(Builtin::Decimal) => dec_compare(env, lhs_val, rhs_val), LayoutRepr::Builtin(Builtin::Str) => todo!(), - LayoutRepr::Builtin(Builtin::List(_)) => todo!(), + LayoutRepr::Builtin(Builtin::List(elem)) => list_compare( + env, + layout_interner, + layout_ids, + elem, + layout_interner.get_repr(elem), + lhs_val.into_struct_value(), + rhs_val.into_struct_value(), + ), LayoutRepr::Struct(_) => todo!(), LayoutRepr::LambdaSet(_) => unreachable!("cannot compare closures"), LayoutRepr::FunctionPointer(_) => unreachable!("cannot compare function pointers"), @@ -357,3 +369,239 @@ fn dec_compare<'ctx>( env.builder .new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare") } + +fn list_compare<'a, 'ctx>( + env: &Env<'a, 'ctx, '_>, + layout_interner: &STLayoutInterner<'a>, + layout_ids: &mut LayoutIds<'a>, + elem_in_layout: InLayout<'a>, + element_layout: LayoutRepr<'a>, + list1: StructValue<'ctx>, + list2: StructValue<'ctx>, +) -> IntValue<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let symbol = Symbol::LIST_COMPARE; + let element_layout = if let LayoutRepr::RecursivePointer(rec) = element_layout { + layout_interner.get_repr(rec) + } else { + element_layout + }; + let fn_name = layout_ids + .get(symbol, &element_layout) + .to_symbol_string(symbol, &env.interns); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let arg_type = zig_list_type(env).into(); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.i8_type().into(), + &[arg_type, arg_type], + ); + + list_compare_help( + env, + layout_interner, + layout_ids, + function_value, + elem_in_layout, + element_layout, + ); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder.set_current_debug_location(di_location); + let call = env + .builder + .new_build_call(function, &[list1.into(), list2.into()], "list_cmp"); + + call.set_call_convention(FAST_CALL_CONV); + + call.try_as_basic_value().left().unwrap().into_int_value() +} + +fn list_compare_help<'a, 'ctx>( + env: &Env<'a, 'ctx, '_>, + layout_interner: &STLayoutInterner<'a>, + layout_ids: &mut LayoutIds<'a>, + parent: FunctionValue<'ctx>, + elem_in_layout: InLayout<'a>, + element_layout: LayoutRepr<'a>, +) { + let ctx = env.context; + let builder = env.builder; + + { + use inkwell::debug_info::AsDIScope; + + let func_scope = parent.get_subprogram().unwrap(); + let lexical_block = env.dibuilder.create_lexical_block( + /* scope */ func_scope.as_debug_info_scope(), + /* file */ env.compile_unit.get_file(), + /* line_no */ 0, + /* column_no */ 0, + ); + + let loc = env.dibuilder.create_debug_location( + ctx, + /* line */ 0, + /* column */ 0, + /* current_scope */ lexical_block.as_debug_info_scope(), + /* inlined_at */ None, + ); + builder.set_current_debug_location(loc); + } + + // Add args to scope + let mut it = parent.get_param_iter(); + let list1 = it.next().unwrap().into_struct_value(); + let list2 = it.next().unwrap().into_struct_value(); + + list1.set_name(Symbol::ARG_1.as_str(&env.interns)); + list2.set_name(Symbol::ARG_2.as_str(&env.interns)); + + let entry = ctx.append_basic_block(parent, "entry"); + let loop_bb = ctx.append_basic_block(parent, "loop_bb"); + let end_l1_bb = ctx.append_basic_block(parent, "end_l1_bb"); + let in_l1_bb = ctx.append_basic_block(parent, "in_l1_bb"); + let elem_compare_bb = ctx.append_basic_block(parent, "increment_bb"); + let not_eq_elems_bb = ctx.append_basic_block(parent, "not_eq_elems_bb"); + let increment_bb = ctx.append_basic_block(parent, "increment_bb"); + let return_eq = ctx.append_basic_block(parent, "return_eq"); + let return_gt = ctx.append_basic_block(parent, "return_gt"); + let return_lt = ctx.append_basic_block(parent, "return_lt"); + + builder.position_at_end(entry); + let len1 = list_len_usize(builder, list1); + let len2 = list_len_usize(builder, list2); + + // allocate a stack slot for the current index + let index_alloca = builder.new_build_alloca(env.ptr_int(), "index"); + builder.new_build_store(index_alloca, env.ptr_int().const_zero()); + + builder.new_build_unconditional_branch(loop_bb); + + builder.position_at_end(loop_bb); + + // load the current index + let index = builder + .new_build_load(env.ptr_int(), index_alloca, "index") + .into_int_value(); + + // true if there are no more elements in list 1 + let end_l1_cond = builder.new_build_int_compare(IntPredicate::EQ, len1, index, "end_l1_cond"); + + builder.new_build_conditional_branch(end_l1_cond, end_l1_bb, in_l1_bb); + + { + builder.position_at_end(end_l1_bb); + + // true if there are no more elements in list 2 + let eq_cond = builder.new_build_int_compare(IntPredicate::EQ, len2, index, "eq_cond"); + + // if both list have no more elements, eq + // else, list 2 still has more elements, so lt + builder.new_build_conditional_branch(eq_cond, return_eq, return_lt); + } + + { + builder.position_at_end(in_l1_bb); + + // list 2 has no more elements + let gt_cond = builder.new_build_int_compare(IntPredicate::EQ, len2, index, "gt_cond"); + + // if list 2 has no more elements, since list 1 still has more, gt + // else, compare the elements at the current index + builder.new_build_conditional_branch(gt_cond, return_gt, elem_compare_bb); + } + + { + builder.position_at_end(elem_compare_bb); + + let element_type = basic_type_from_layout(env, layout_interner, element_layout); + let ptr_type = element_type.ptr_type(AddressSpace::default()); + let ptr1 = load_list_ptr(builder, list1, ptr_type); + let ptr2 = load_list_ptr(builder, list2, ptr_type); + + let elem1 = { + let elem_ptr = unsafe { + builder.new_build_in_bounds_gep(element_type, ptr1, &[index], "load_index") + }; + load_roc_value(env, layout_interner, element_layout, elem_ptr, "get_elem") + }; + + let elem2 = { + let elem_ptr = unsafe { + builder.new_build_in_bounds_gep(element_type, ptr2, &[index], "load_index") + }; + load_roc_value(env, layout_interner, element_layout, elem_ptr, "get_elem") + }; + + let elem_cmp = generic_compare( + env, + layout_interner, + layout_ids, + elem1, + elem2, + elem_in_layout, + elem_in_layout, + ) + .into_int_value(); + + // true if elements are equal + let increment_cond = builder.new_build_int_compare( + IntPredicate::EQ, + elem_cmp, + ctx.i8_type().const_int(0, false), + "increment_cond", + ); + + // if elements are equal, increment the pointers + // else, return gt or lt + builder.new_build_conditional_branch(increment_cond, increment_bb, not_eq_elems_bb); + + { + builder.position_at_end(not_eq_elems_bb); + + // When elements compare not equal, we return the element comparison + builder.new_build_return(Some(&elem_cmp)); + } + } + + { + builder.position_at_end(increment_bb); + + let one = env.ptr_int().const_int(1, false); + + // increment the index + let next_index = builder.new_build_int_add(index, one, "nextindex"); + + builder.new_build_store(index_alloca, next_index); + + // jump back to the top of the loop + builder.new_build_unconditional_branch(loop_bb); + } + + { + builder.position_at_end(return_eq); + builder.new_build_return(Some(&ctx.i8_type().const_int(0, false))); + } + + { + builder.position_at_end(return_gt); + builder.new_build_return(Some(&ctx.i8_type().const_int(1, false))); + } + + { + builder.position_at_end(return_lt); + builder.new_build_return(Some(&ctx.i8_type().const_int(2, false))); + } +} diff --git a/crates/compiler/solve/src/ability.rs b/crates/compiler/solve/src/ability.rs index 1fb1316d534..ec7c2a9cc85 100644 --- a/crates/compiler/solve/src/ability.rs +++ b/crates/compiler/solve/src/ability.rs @@ -872,6 +872,18 @@ impl DerivableVisitor for DeriveSort { true } + #[inline(always)] + fn visit_apply(var: Variable, symbol: Symbol) -> Result { + if matches!(symbol, Symbol::LIST_LIST,) { + Ok(Descend(true)) + } else { + Err(NotDerivable { + var, + context: NotDerivableContext::NoContext, + }) + } + } + #[inline(always)] fn visit_floating_point_content( _var: Variable, diff --git a/crates/compiler/test_gen/src/gen_compare.rs b/crates/compiler/test_gen/src/gen_compare.rs index e3c22973e57..0e304d744dd 100644 --- a/crates/compiler/test_gen/src/gen_compare.rs +++ b/crates/compiler/test_gen/src/gen_compare.rs @@ -244,3 +244,170 @@ fn compare_bool() { u8 ); } + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))] +fn list() { + assert_evals_to!( + indoc!( + r#" + i : List Bool + i = [] + + j : List Bool + j = [] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 0, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List Bool + i = [Bool.true] + + j : List Bool + j = [] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List Bool + i = [] + + j : List Bool + j = [Bool.true] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 2, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List Bool + i = [Bool.true] + + j : List Bool + j = [Bool.true] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 0, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List Bool + i = [Bool.true] + + j : List Bool + j = [Bool.false] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List Bool + i = [Bool.false] + + j : List Bool + j = [Bool.true] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 2, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List I64 + i = [1, 2] + + j : List I64 + j = [1, 1] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List I64 + i = [1] + + j : List I64 + j = [1, 1] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 2, + u8 + ); + assert_evals_to!( + indoc!( + r#" + i : List (List I64) + i = [[0], [1, 2]] + + j : List (List I64) + j = [[0], [1, 1]] + + when List.compare i j is + Equal -> 0 + GreaterThan -> 1 + LessThan -> 2 + "# + ), + 1, + u8 + ); +}