Skip to content

Commit

Permalink
Merge pull request #1 from jwoudenberg/float-bool-implementations
Browse files Browse the repository at this point in the history
Float bool implementations
  • Loading branch information
jwoudenberg authored Jul 13, 2024
2 parents fbceda6 + f8a7d1b commit 24c0274
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 36 deletions.
133 changes: 122 additions & 11 deletions crates/compiler/gen_llvm/src/llvm/sort.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::build::BuilderExt;
use crate::llvm::build::Env;
use crate::llvm::bitcode::call_bitcode_fn;
use inkwell::values::{BasicValueEnum, IntValue};
use inkwell::IntPredicate;
use roc_builtins::bitcode::IntWidth;
use inkwell::{IntPredicate, FloatPredicate};
use roc_builtins::bitcode::{IntWidth, FloatWidth, NUM_LESS_THAN, NUM_GREATER_THAN};
use roc_mono::layout::{
Builtin, InLayout, LayoutIds, LayoutInterner, LayoutRepr, STLayoutInterner,
};
Expand All @@ -21,9 +22,15 @@ pub fn generic_compare<'a, 'ctx>(
LayoutRepr::Builtin(Builtin::Int(int_width)) => {
int_compare(env, lhs_val, rhs_val, int_width)
}
LayoutRepr::Builtin(Builtin::Float(_)) => todo!(),
LayoutRepr::Builtin(Builtin::Bool) => todo!(),
LayoutRepr::Builtin(Builtin::Decimal) => todo!(),
LayoutRepr::Builtin(Builtin::Float(float_width)) => {
float_cmp(env, lhs_val, rhs_val, float_width)
}
LayoutRepr::Builtin(Builtin::Bool) => {
bool_compare(env, lhs_val, rhs_val)
}
LayoutRepr::Builtin(Builtin::Decimal) => {
dec_compare(env, lhs_val, rhs_val)
}
LayoutRepr::Builtin(Builtin::Str) => todo!(),
LayoutRepr::Builtin(Builtin::List(_)) => todo!(),
LayoutRepr::Struct(_) => todo!(),
Expand All @@ -48,22 +55,24 @@ fn int_compare<'ctx>(
// (a > b) + 2 * (a < b);
let lhs_gt_rhs = int_gt(env, lhs_val, rhs_val, builtin);
let lhs_lt_rhs = int_lt(env, lhs_val, rhs_val, builtin);
let two = env.ptr_int().const_int(2, false);
let two = env.context.i8_type().const_int(2, false);
let lhs_lt_rhs_times_two =
env.builder
.new_build_int_mul(lhs_lt_rhs, two, "lhs_lt_rhs_times_two");
env.builder
.new_build_int_sub(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare")
.new_build_int_add(lhs_gt_rhs, lhs_lt_rhs_times_two, "int_compare")
}



fn int_lt<'ctx>(
env: &Env<'_, 'ctx, '_>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
builtin: IntWidth,
) -> IntValue<'ctx> {
use IntWidth::*;
match builtin {
let lhs_lt_rhs = match builtin {
I128 => env.builder.new_build_int_compare(
IntPredicate::SLT,
lhs_val.into_int_value(),
Expand Down Expand Up @@ -124,7 +133,8 @@ fn int_lt<'ctx>(
rhs_val.into_int_value(),
"lhs_lt_rhs_u8",
),
}
};
env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_cast")
}

fn int_gt<'ctx>(
Expand All @@ -134,7 +144,7 @@ fn int_gt<'ctx>(
builtin: IntWidth,
) -> IntValue<'ctx> {
use IntWidth::*;
match builtin {
let lhs_gt_rhs = match builtin {
I128 => env.builder.new_build_int_compare(
IntPredicate::SGT,
lhs_val.into_int_value(),
Expand Down Expand Up @@ -195,5 +205,106 @@ fn int_gt<'ctx>(
rhs_val.into_int_value(),
"lhs_gt_rhs_u8",
),
}
};
env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_cast")
}


// Return 0 for equals, 1 for greater than, and 2 for less than.
// We consider NaNs to be smaller than non-NaNs
// We use the below expression to calculate this
// (a == a) + 2*(b == b) - (a < b) - 2*(a > b) - 3*(a == b)
fn float_cmp<'ctx>(
env: &Env<'_, 'ctx, '_>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
float_width: FloatWidth,
) -> IntValue<'ctx> {
use FloatWidth::*;
let type_str = match float_width {
F64 => "F64",
F32 => "F32",
};

let make_cmp = |operation, a: BasicValueEnum<'ctx>, b: BasicValueEnum<'ctx>, op_name: &str| {
let full_op_name = format!("{}_{}", op_name, type_str);
let bool_result = env.builder.new_build_float_compare(
operation,
a.into_float_value(),
b.into_float_value(),
&full_op_name,
);
env.builder.new_build_int_cast_sign_flag(bool_result, env.context.i8_type(), false, &format!("{}_cast", full_op_name))
};

let two = env.context.i8_type().const_int(2, false);
let three = env.context.i8_type().const_int(3, false);

let lt_test = make_cmp(FloatPredicate::OLT, lhs_val, rhs_val, "lhs_lt_rhs");
let gt_test = make_cmp(FloatPredicate::OGT, lhs_val, rhs_val, "lhs_gt_rhs");
let eq_test = make_cmp(FloatPredicate::OEQ, lhs_val, rhs_val, "lhs_eq_rhs");
let lhs_not_nan_test = make_cmp(FloatPredicate::OEQ, lhs_val, lhs_val, "lhs_not_NaN");
let rhs_not_nan_test = make_cmp(FloatPredicate::OEQ, rhs_val, rhs_val, "rhs_not_NaN");

let rhs_not_nan_scaled = env.builder.new_build_int_mul(two, rhs_not_nan_test, "2 * rhs_not_nan");
let gt_scaled = env.builder.new_build_int_mul(two, gt_test, "2 * lhs_gt_rhs");
let eq_scaled = env.builder.new_build_int_mul(three, eq_test, "3 * lhs_eq_rhs");

let non_nans = env.builder.new_build_int_add(lhs_not_nan_test, rhs_not_nan_scaled, "(a == a) + 2*(b == b))");
let minus_lt = env.builder.new_build_int_sub(non_nans, lt_test, "(a == a) + 2*(b == b) - (a < b");
let minus_gt = env.builder.new_build_int_sub(minus_lt, gt_scaled, "(a == a) + 2*(b == b) - (a < b) - 2*(a > b)");
env.builder.new_build_int_sub(minus_gt, eq_scaled, "float_compare")
}

// 1 1 0
// 0 0 0
// 0 1 1
// 1 0 2
fn bool_compare<'ctx>(
env: &Env<'_, 'ctx, '_>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {

// Cast the input bools to ints because int comparison of bools does the opposite of what one would expect.
// I could just swap the arguments, but I do not want to rely on behavior which seems wrong
let lhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_val.into_int_value(), env.context.i8_type(), false, "lhs_byte");
let rhs_byte = env.builder.new_build_int_cast_sign_flag(rhs_val.into_int_value(), env.context.i8_type(), false, "rhs_byte");

// (a < b)
let lhs_lt_rhs = env.builder.new_build_int_compare(IntPredicate::SLT, lhs_byte, rhs_byte, "lhs_lt_rhs_bool");
let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte");

// (a > b)
let lhs_gt_rhs = env.builder.new_build_int_compare(IntPredicate::SGT, lhs_byte, rhs_byte, "lhs_gt_rhs_bool");
let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte");

// (a < b) * 2
let two = env.context.i8_type().const_int(2, false);
let lhs_lt_rhs_times_two = env.builder.new_build_int_mul(lhs_lt_rhs_byte, two, "lhs_lt_rhs_times_two");

// (a > b) + (a < b) * 2
env.builder.new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare")
}

fn dec_compare<'ctx>(
env: &Env<'_, 'ctx, '_>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {

// (a > b)
let lhs_gt_rhs = call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_GREATER_THAN[IntWidth::I128]).into_int_value();
let lhs_gt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_gt_rhs, env.context.i8_type(), false, "lhs_gt_rhs_byte");

// (a < b)
let lhs_lt_rhs = call_bitcode_fn(env, &[lhs_val, rhs_val], &NUM_LESS_THAN[IntWidth::I128]).into_int_value();
let lhs_lt_rhs_byte = env.builder.new_build_int_cast_sign_flag(lhs_lt_rhs, env.context.i8_type(), false, "lhs_lt_rhs_byte");

// (a < b) * 2
let two = env.context.i8_type().const_int(2, false);
let lhs_lt_rhs_times_two = env.builder.new_build_int_mul(lhs_lt_rhs_byte, two, "lhs_gt_rhs_times_two");

// (a > b) + (a < b) * 2
env.builder.new_build_int_add(lhs_gt_rhs_byte, lhs_lt_rhs_times_two, "bool_compare")
}
54 changes: 29 additions & 25 deletions crates/compiler/solve/src/ability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ impl ObligationCache {
var,
)),

Symbol::LIST_SORT => Some(DeriveSort::is_derivable(self, abilities_store, subs, var)),

_ => None,
};

Expand Down Expand Up @@ -399,6 +401,8 @@ impl ObligationCache {
DeriveEq::ABILITY => DeriveEq::is_derivable_builtin_opaque(opaque),
DeriveHash::ABILITY => DeriveHash::is_derivable_builtin_opaque(opaque),
DeriveInspect::ABILITY => DeriveInspect::is_derivable_builtin_opaque(opaque),
DeriveSort::ABILITY => DeriveSort::is_derivable_builtin_opaque(opaque),
DeriveCompare::ABILITY => DeriveCompare::is_derivable_builtin_opaque(opaque),
_ => false,
};

Expand Down Expand Up @@ -858,6 +862,26 @@ trait DerivableVisitor {
}
}

struct DeriveSort;
impl DerivableVisitor for DeriveSort {
const ABILITY: Symbol = Symbol::LIST_SORT;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_SORT;

#[inline(always)]
fn is_derivable_builtin_opaque(_symbol: Symbol) -> bool {
true
}

#[inline(always)]
fn visit_floating_point_content(
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
}

struct DeriveInspect;
impl DerivableVisitor for DeriveInspect {
const ABILITY: Symbol = Symbol::INSPECT_INSPECT_ABILITY;
Expand Down Expand Up @@ -1482,33 +1506,13 @@ impl DerivableVisitor for DeriveCompare {
Ok(Descend(true))
}

#[inline(always)]
fn visit_floating_point_content(
var: Variable,
subs: &mut Subs,
content_var: Variable,
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
use roc_unify::unify::unify;

// Of the floating-point types,
// only Dec implements Eq.
// TODO(checkmate): pass checkmate through
let unified = unify(
&mut with_checkmate!({
on => UEnv::new(subs, None),
off => UEnv::new(subs),
}),
content_var,
Variable::DECIMAL,
UnificationMode::EQ,
Polarity::Pos,
);
match unified {
roc_unify::unify::Unified::Success { .. } => Ok(Descend(false)),
roc_unify::unify::Unified::Failure(..) => Err(NotDerivable {
var,
context: NotDerivableContext::Eq(NotDerivableEq::FloatingPoint),
}),
}
Ok(Descend(true))
}

#[inline(always)]
Expand Down
2 changes: 2 additions & 0 deletions crates/compiler/types/src/subs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,8 @@ impl Subs {
pub const AB_INSPECT: SubsSlice<Symbol> = SubsSlice::new(5, 1);
#[rustfmt::skip]
pub const AB_COMPARE: SubsSlice<Symbol> = SubsSlice::new(6, 1);
#[rustfmt::skip]
pub const AB_SORT: SubsSlice<Symbol> = SubsSlice::new(7, 1);
// END INIT-SymbolSubsSlice

pub fn new() -> Self {
Expand Down

0 comments on commit 24c0274

Please sign in to comment.