Skip to content

Commit

Permalink
Simplify tail-call checks
Browse files Browse the repository at this point in the history
  • Loading branch information
smores56 committed Oct 26, 2024
1 parent 33d8681 commit 4989178
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 68 deletions.
22 changes: 11 additions & 11 deletions crates/compiler/can/src/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use crate::expr::ClosureData;
use crate::expr::Declarations;
use crate::expr::Expr::{self, *};
use crate::expr::StructAccessorData;
use crate::expr::TailCall;
use crate::expr::{canonicalize_expr, Output, Recursive};
use crate::pattern::{canonicalize_def_header_pattern, BindingsFromPattern, Pattern};
use crate::procedure::QualifiedReference;
Expand Down Expand Up @@ -2606,16 +2605,17 @@ fn canonicalize_pending_body<'a>(
env.tailcallable_symbol = outer_tailcallable;

// The closure is self tail recursive iff it tail calls itself (by defined name).
let is_recursive = match can_output.tail_call {
TailCall::NoneMade => Recursive::NotRecursive,
TailCall::Inconsistent => Recursive::Recursive,
TailCall::CallsTo(tail_symbol) => {
if tail_symbol == *defined_symbol {
Recursive::TailRecursive
} else {
Recursive::Recursive
}
}
let is_recursive = if can_output
.early_tail_calls
.iter()
.chain(std::iter::once(&can_output.final_tail_call))
.all(|tail_call| {
matches!(tail_call, Some(tail_symbol) if tail_symbol == defined_symbol)
})
{
Recursive::TailRecursive
} else {
Recursive::NotRecursive
};

closure_data.recursive = is_recursive;
Expand Down
75 changes: 18 additions & 57 deletions crates/compiler/can/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,64 +39,22 @@ pub type PendingDerives = VecMap<Symbol, (Type, Vec<Loc<Symbol>>)>;
#[derive(Clone, Default, Debug)]
pub struct Output {
pub references: References,
pub tail_call: TailCall,
pub early_tail_calls: Vec<Option<Symbol>>,
pub final_tail_call: Option<Symbol>,
pub introduced_variables: IntroducedVariables,
pub aliases: VecMap<Symbol, Alias>,
pub non_closures: VecSet<Symbol>,
pub pending_derives: PendingDerives,
}

#[derive(Clone, Copy, Default, Debug)]
pub enum TailCall {
#[default]
NoneMade,
CallsTo(Symbol),
Inconsistent,
}

impl TailCall {
pub fn for_expr(expr: &Expr) -> Self {
match expr {
Expr::Call(fn_expr, _, _) => match **fn_expr {
(
_,
Loc {
value: Expr::Var(symbol, _),
..
},
_,
_,
) => Self::CallsTo(symbol),
_ => Self::NoneMade,
},
_ => Self::NoneMade,
}
}

pub fn merge(self, other: Self) -> Self {
match self {
TailCall::NoneMade => other,
TailCall::Inconsistent => TailCall::Inconsistent,
TailCall::CallsTo(our_symbol) => match other {
TailCall::NoneMade => TailCall::CallsTo(our_symbol),
TailCall::Inconsistent => TailCall::Inconsistent,
TailCall::CallsTo(other_symbol) => {
if our_symbol == other_symbol {
TailCall::CallsTo(our_symbol)
} else {
TailCall::Inconsistent
}
}
},
}
}
}

impl Output {
pub fn union(&mut self, other: Self) {
self.references.union_mut(&other.references);

self.tail_call = self.tail_call.merge(other.tail_call);
self.early_tail_calls.extend(other.early_tail_calls);
if let (None, Some(later)) = (self.final_tail_call, other.final_tail_call) {
self.final_tail_call = Some(later);
}

self.introduced_variables
.union_owned(other.introduced_variables);
Expand Down Expand Up @@ -768,7 +726,6 @@ pub fn canonicalize_expr<'a>(

let output = Output {
references,
tail_call: TailCall::NoneMade,
..Default::default()
};

Expand Down Expand Up @@ -836,7 +793,6 @@ pub fn canonicalize_expr<'a>(

let output = Output {
references,
tail_call: TailCall::NoneMade,
..Default::default()
};

Expand Down Expand Up @@ -952,13 +908,18 @@ pub fn canonicalize_expr<'a>(

output.union(fn_expr_output);

output.tail_call = TailCall::NoneMade;
// Default: We're not tail-calling a symbol (by name), we're tail-calling a function value.
output.final_tail_call = None;

let expr = match fn_expr.value {
Var(symbol, _) => {
output.references.insert_call(symbol);

output.tail_call = TailCall::CallsTo(symbol);
// we're tail-calling a symbol by name, check if it's the tail-callable symbol
output.final_tail_call = match &env.tailcallable_symbol {
Some(tc_sym) if *tc_sym == symbol => Some(symbol),
Some(_) | None => None,
};

Call(
Box::new((
Expand Down Expand Up @@ -1084,7 +1045,7 @@ pub fn canonicalize_expr<'a>(
canonicalize_expr(env, var_store, scope, loc_cond.region, &loc_cond.value);

// the condition can never be a tail-call
output.tail_call = TailCall::NoneMade;
output.final_tail_call = None;

let mut can_branches = Vec::with_capacity(branches.len());

Expand All @@ -1109,7 +1070,7 @@ pub fn canonicalize_expr<'a>(
// if code gen mistakenly thinks this is a tail call just because its condition
// happened to be one. (The condition gave us our initial output value.)
if branches.is_empty() {
output.tail_call = TailCall::NoneMade;
output.final_tail_call = None;
}

// Incorporate all three expressions into a combined Output value.
Expand Down Expand Up @@ -1319,17 +1280,17 @@ pub fn canonicalize_expr<'a>(
});
}

let (loc_return_expr, output1) = canonicalize_expr(
let (loc_return_expr, mut output1) = canonicalize_expr(
env,
var_store,
scope,
return_expr.region,
&return_expr.value,
);

output.union(output1);
output1.early_tail_calls.push(output1.final_tail_call);

output.tail_call = TailCall::for_expr(&loc_return_expr.value);
output.union(output1);

let return_var = var_store.fresh();

Expand Down

0 comments on commit 4989178

Please sign in to comment.