|
| 1 | +//! Expression and value lowering helpers shared across the Yul emitter. |
| 2 | +
|
| 3 | +use hir::hir_def::{ |
| 4 | + Expr, ExprId, LitKind, Stmt, StmtId, |
| 5 | + expr::{ArithBinOp, BinOp, CompBinOp, LogicalBinOp, UnOp}, |
| 6 | +}; |
| 7 | +use mir::{CallOrigin, ValueId, ValueOrigin, ir::SyntheticValue}; |
| 8 | + |
| 9 | +use crate::yul::{doc::YulDoc, state::BlockState}; |
| 10 | + |
| 11 | +use super::{ |
| 12 | + YulError, |
| 13 | + function::FunctionEmitter, |
| 14 | + util::{function_name, try_collapse_cast_shim}, |
| 15 | +}; |
| 16 | + |
| 17 | +impl<'db> FunctionEmitter<'db> { |
| 18 | + /// Lowers a MIR `ValueId` into a Yul expression string. |
| 19 | + /// |
| 20 | + /// * `value_id` - Identifier selecting the MIR value. |
| 21 | + /// * `state` - Current bindings for previously-evaluated expressions. |
| 22 | + /// |
| 23 | + /// Returns the Yul expression referencing the value or an error if unsupported. |
| 24 | + pub(super) fn lower_value( |
| 25 | + &self, |
| 26 | + value_id: ValueId, |
| 27 | + state: &BlockState, |
| 28 | + ) -> Result<String, YulError> { |
| 29 | + let value = self.mir_func.body.value(value_id); |
| 30 | + match &value.origin { |
| 31 | + ValueOrigin::Expr(expr_id) => { |
| 32 | + if let Some(temp) = self.match_values.get(expr_id) { |
| 33 | + Ok(temp.clone()) |
| 34 | + } else { |
| 35 | + self.lower_expr(*expr_id, state) |
| 36 | + } |
| 37 | + } |
| 38 | + ValueOrigin::Call(call) => self.lower_call_value(call, state), |
| 39 | + ValueOrigin::Intrinsic(intr) => self.lower_intrinsic_value(intr, state), |
| 40 | + ValueOrigin::Synthetic(synth) => self.lower_synthetic_value(synth), |
| 41 | + _ => Err(YulError::Unsupported( |
| 42 | + "only expression-derived values are supported".into(), |
| 43 | + )), |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + /// Lowers a HIR expression into a Yul expression string. |
| 48 | + /// |
| 49 | + /// * `expr_id` - Expression to render. |
| 50 | + /// * `state` - Binding state used for nested expressions. |
| 51 | + /// |
| 52 | + /// Returns the fully-lowered Yul expression. |
| 53 | + pub(super) fn lower_expr( |
| 54 | + &self, |
| 55 | + expr_id: ExprId, |
| 56 | + state: &BlockState, |
| 57 | + ) -> Result<String, YulError> { |
| 58 | + if let Some(temp) = self.expr_temps.get(&expr_id) { |
| 59 | + return Ok(temp.clone()); |
| 60 | + } |
| 61 | + if let Some(temp) = self.match_values.get(&expr_id) { |
| 62 | + return Ok(temp.clone()); |
| 63 | + } |
| 64 | + if let Some(value_id) = self.mir_func.body.expr_values.get(&expr_id) { |
| 65 | + let value = self.mir_func.body.value(*value_id); |
| 66 | + match &value.origin { |
| 67 | + ValueOrigin::Call(call) => return self.lower_call_value(call, state), |
| 68 | + ValueOrigin::Synthetic(synth) => { |
| 69 | + return self.lower_synthetic_value(synth); |
| 70 | + } |
| 71 | + _ => {} |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + let expr = self.expect_expr(expr_id)?; |
| 76 | + match expr { |
| 77 | + Expr::Lit(LitKind::Int(int_id)) => Ok(int_id.data(self.db).to_string()), |
| 78 | + Expr::Lit(LitKind::Bool(value)) => Ok(if *value { "1" } else { "0" }.into()), |
| 79 | + Expr::Lit(LitKind::String(str_id)) => Ok(format!( |
| 80 | + "0x{}", |
| 81 | + hex::encode(str_id.data(self.db).as_bytes()) |
| 82 | + )), |
| 83 | + Expr::Un(inner, op) => { |
| 84 | + let value = self.lower_expr(*inner, state)?; |
| 85 | + match op { |
| 86 | + UnOp::Minus => Ok(format!("sub(0, {value})")), |
| 87 | + UnOp::Not => Ok(format!("iszero({value})")), |
| 88 | + UnOp::Plus => Ok(value), |
| 89 | + UnOp::BitNot => Ok(format!("not({value})")), |
| 90 | + } |
| 91 | + } |
| 92 | + Expr::Tuple(values) => { |
| 93 | + let parts = values |
| 94 | + .iter() |
| 95 | + .map(|expr| self.lower_expr(*expr, state)) |
| 96 | + .collect::<Result<Vec<_>, _>>()?; |
| 97 | + Ok(format!("tuple({})", parts.join(", "))) |
| 98 | + } |
| 99 | + Expr::Call(callee, call_args) => { |
| 100 | + let callee_expr = self.lower_expr(*callee, state)?; |
| 101 | + let mut lowered_args = Vec::with_capacity(call_args.len()); |
| 102 | + for arg in call_args { |
| 103 | + lowered_args.push(self.lower_expr(arg.expr, state)?); |
| 104 | + } |
| 105 | + if let Some(arg) = try_collapse_cast_shim(&callee_expr, &lowered_args)? { |
| 106 | + return Ok(arg); |
| 107 | + } |
| 108 | + if lowered_args.is_empty() { |
| 109 | + Ok(format!("{callee_expr}()")) |
| 110 | + } else { |
| 111 | + Ok(format!("{callee_expr}({})", lowered_args.join(", "))) |
| 112 | + } |
| 113 | + } |
| 114 | + Expr::Bin(lhs, rhs, bin_op) => match bin_op { |
| 115 | + BinOp::Arith(op) => { |
| 116 | + let left = self.lower_expr(*lhs, state)?; |
| 117 | + let right = self.lower_expr(*rhs, state)?; |
| 118 | + match op { |
| 119 | + ArithBinOp::Add => Ok(format!("add({left}, {right})")), |
| 120 | + ArithBinOp::Sub => Ok(format!("sub({left}, {right})")), |
| 121 | + ArithBinOp::Mul => Ok(format!("mul({left}, {right})")), |
| 122 | + ArithBinOp::Div => Ok(format!("div({left}, {right})")), |
| 123 | + ArithBinOp::Rem => Ok(format!("mod({left}, {right})")), |
| 124 | + ArithBinOp::Pow => Ok(format!("exp({left}, {right})")), |
| 125 | + ArithBinOp::LShift => Ok(format!("shl({right}, {left})")), |
| 126 | + ArithBinOp::RShift => Ok(format!("shr({right}, {left})")), |
| 127 | + ArithBinOp::BitAnd => Ok(format!("and({left}, {right})")), |
| 128 | + ArithBinOp::BitOr => Ok(format!("or({left}, {right})")), |
| 129 | + ArithBinOp::BitXor => Ok(format!("xor({left}, {right})")), |
| 130 | + } |
| 131 | + } |
| 132 | + BinOp::Comp(op) => { |
| 133 | + let left = self.lower_expr(*lhs, state)?; |
| 134 | + let right = self.lower_expr(*rhs, state)?; |
| 135 | + let expr = match op { |
| 136 | + CompBinOp::Eq => format!("eq({left}, {right})"), |
| 137 | + CompBinOp::NotEq => format!("iszero(eq({left}, {right}))"), |
| 138 | + CompBinOp::Lt => format!("lt({left}, {right})"), |
| 139 | + CompBinOp::LtEq => format!("iszero(gt({left}, {right}))"), |
| 140 | + CompBinOp::Gt => format!("gt({left}, {right})"), |
| 141 | + CompBinOp::GtEq => format!("iszero(lt({left}, {right}))"), |
| 142 | + }; |
| 143 | + Ok(expr) |
| 144 | + } |
| 145 | + BinOp::Logical(op) => { |
| 146 | + let left = self.lower_expr(*lhs, state)?; |
| 147 | + let right = self.lower_expr(*rhs, state)?; |
| 148 | + let func = match op { |
| 149 | + LogicalBinOp::And => "and", |
| 150 | + LogicalBinOp::Or => "or", |
| 151 | + }; |
| 152 | + Ok(format!("{func}({left}, {right})")) |
| 153 | + } |
| 154 | + _ => Err(YulError::Unsupported( |
| 155 | + "only arithmetic/logical binary expressions are supported right now".into(), |
| 156 | + )), |
| 157 | + }, |
| 158 | + Expr::Block(stmts) => { |
| 159 | + if let Some(expr) = self.last_expr(stmts) { |
| 160 | + self.lower_expr(expr, state) |
| 161 | + } else { |
| 162 | + Ok("0".into()) |
| 163 | + } |
| 164 | + } |
| 165 | + Expr::Path(path) => { |
| 166 | + let original = self |
| 167 | + .path_ident(*path) |
| 168 | + .ok_or_else(|| YulError::Unsupported("unsupported path expression".into()))?; |
| 169 | + Ok(state.resolve_name(&original)) |
| 170 | + } |
| 171 | + Expr::Field(..) => { |
| 172 | + if let Some(value_id) = self.mir_func.body.expr_values.get(&expr_id) { |
| 173 | + self.lower_value(*value_id, state) |
| 174 | + } else { |
| 175 | + let ty = self.mir_func.typed_body.expr_ty(self.db, expr_id); |
| 176 | + Err(YulError::Unsupported(format!( |
| 177 | + "field expressions should be rewritten before codegen (expr type {})", |
| 178 | + ty.pretty_print(self.db) |
| 179 | + ))) |
| 180 | + } |
| 181 | + } |
| 182 | + Expr::RecordInit(..) => { |
| 183 | + if let Some(temp) = self.expr_temps.get(&expr_id) { |
| 184 | + Ok(temp.clone()) |
| 185 | + } else { |
| 186 | + Err(YulError::Unsupported( |
| 187 | + "record initializers should be lowered before codegen".into(), |
| 188 | + )) |
| 189 | + } |
| 190 | + } |
| 191 | + other => Err(YulError::Unsupported(format!( |
| 192 | + "only simple expressions are supported: {other:?}" |
| 193 | + ))), |
| 194 | + } |
| 195 | + } |
| 196 | + |
| 197 | + /// Returns the last expression statement in a block, if any. |
| 198 | + /// |
| 199 | + /// * `stmts` - Slice of statement IDs to inspect. |
| 200 | + /// |
| 201 | + /// Returns the expression ID for the trailing expression statement when present. |
| 202 | + fn last_expr(&self, stmts: &[StmtId]) -> Option<ExprId> { |
| 203 | + stmts.iter().rev().find_map(|stmt_id| { |
| 204 | + let Ok(stmt) = self.expect_stmt(*stmt_id) else { |
| 205 | + return None; |
| 206 | + }; |
| 207 | + if let Stmt::Expr(expr) = stmt { |
| 208 | + Some(*expr) |
| 209 | + } else { |
| 210 | + None |
| 211 | + } |
| 212 | + }) |
| 213 | + } |
| 214 | + |
| 215 | + /// Lowers a MIR call into a Yul function invocation. |
| 216 | + /// |
| 217 | + /// * `call` - Call origin describing the callee and arguments. |
| 218 | + /// * `state` - Binding state used to lower argument expressions. |
| 219 | + /// |
| 220 | + /// Returns the Yul invocation string for the call. |
| 221 | + pub(super) fn lower_call_value( |
| 222 | + &self, |
| 223 | + call: &CallOrigin<'_>, |
| 224 | + state: &BlockState, |
| 225 | + ) -> Result<String, YulError> { |
| 226 | + let callee = if let Some(name) = &call.resolved_name { |
| 227 | + name.clone() |
| 228 | + } else { |
| 229 | + let Some(func) = call.callable.func_def.hir_func_def(self.db) else { |
| 230 | + return Err(YulError::Unsupported( |
| 231 | + "callable without hir function definition is not supported yet".into(), |
| 232 | + )); |
| 233 | + }; |
| 234 | + function_name(self.db, func) |
| 235 | + }; |
| 236 | + let mut lowered_args = Vec::with_capacity(call.args.len()); |
| 237 | + for &arg in &call.args { |
| 238 | + lowered_args.push(self.lower_value(arg, state)?); |
| 239 | + } |
| 240 | + if let Some(arg) = try_collapse_cast_shim(&callee, &lowered_args)? { |
| 241 | + return Ok(arg); |
| 242 | + } |
| 243 | + if lowered_args.is_empty() { |
| 244 | + Ok(format!("{callee}()")) |
| 245 | + } else { |
| 246 | + Ok(format!("{callee}({})", lowered_args.join(", "))) |
| 247 | + } |
| 248 | + } |
| 249 | + |
| 250 | + /// Lowers special MIR synthetic values such as constants into Yul expressions. |
| 251 | + /// |
| 252 | + /// * `value` - Synthetic value emitted during MIR construction. |
| 253 | + /// |
| 254 | + /// Returns the literal Yul expression for the synthetic value. |
| 255 | + fn lower_synthetic_value(&self, value: &SyntheticValue) -> Result<String, YulError> { |
| 256 | + match value { |
| 257 | + SyntheticValue::Int(int) => Ok(int.to_string()), |
| 258 | + SyntheticValue::Bool(flag) => Ok(if *flag { "1" } else { "0" }.into()), |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + /// Lowers expressions that may require extra statements (e.g. `if`). |
| 263 | + /// |
| 264 | + /// * `expr_id` - Expression to lower. |
| 265 | + /// * `docs` - Doc list to append emitted statements into. |
| 266 | + /// * `state` - Binding state for allocating temporaries. |
| 267 | + /// |
| 268 | + /// Returns either the inline expression or the name of a temporary containing the result. |
| 269 | + pub(super) fn lower_expr_with_statements( |
| 270 | + &mut self, |
| 271 | + expr_id: ExprId, |
| 272 | + docs: &mut Vec<YulDoc>, |
| 273 | + state: &mut BlockState, |
| 274 | + ) -> Result<String, YulError> { |
| 275 | + if let Some(temp) = self.expr_temps.get(&expr_id) { |
| 276 | + return Ok(temp.clone()); |
| 277 | + } |
| 278 | + if let Some(temp) = self.match_values.get(&expr_id) { |
| 279 | + return Ok(temp.clone()); |
| 280 | + } |
| 281 | + |
| 282 | + let expr = self.expect_expr(expr_id)?; |
| 283 | + if let Expr::If(cond, then_expr, else_expr) = expr { |
| 284 | + let temp = state.alloc_local(); |
| 285 | + docs.push(YulDoc::line(format!("let {temp} := 0"))); |
| 286 | + let cond_expr = self.lower_expr(*cond, state)?; |
| 287 | + let then_expr_str = self.lower_expr(*then_expr, state)?; |
| 288 | + docs.push(YulDoc::block( |
| 289 | + format!("if {cond_expr} "), |
| 290 | + vec![YulDoc::line(format!("{temp} := {then_expr_str}"))], |
| 291 | + )); |
| 292 | + if let Some(else_expr) = else_expr { |
| 293 | + let else_expr_str = self.lower_expr(*else_expr, state)?; |
| 294 | + docs.push(YulDoc::block( |
| 295 | + format!("if iszero({cond_expr}) "), |
| 296 | + vec![YulDoc::line(format!("{temp} := {else_expr_str}"))], |
| 297 | + )); |
| 298 | + } |
| 299 | + Ok(temp) |
| 300 | + } else { |
| 301 | + self.lower_expr(expr_id, state) |
| 302 | + } |
| 303 | + } |
| 304 | + |
| 305 | + /// Returns `true` when the given expression's type is the unit tuple. |
| 306 | + /// |
| 307 | + /// * `expr_id` - Expression identifier whose type should be tested. |
| 308 | + /// |
| 309 | + /// Returns `true` if the expression's type is the unit tuple. |
| 310 | + pub(super) fn expr_is_unit(&self, expr_id: ExprId) -> bool { |
| 311 | + let ty = self.mir_func.typed_body.expr_ty(self.db, expr_id); |
| 312 | + ty.is_tuple(self.db) && ty.field_count(self.db) == 0 |
| 313 | + } |
| 314 | +} |
0 commit comments