diff --git a/Veir/Passes/CSE.lean b/Veir/Passes/CSE.lean index 42dfad197..666cbf533 100644 --- a/Veir/Passes/CSE.lean +++ b/Veir/Passes/CSE.lean @@ -127,7 +127,6 @@ def key? (ctx : IRContext OpCode) (op : OperationPtr) : Option Key := do return ordinaryKey ctx op kind | _ => none -set_option warn.sorry false in /-- Perform CSE on a single BB: Walk the operations, building up a hash of available values. For any operation whose value is already available, replace it with the earlier one. -/ @@ -143,14 +142,13 @@ def processBlock if let some key := key? ctx.raw op then match available[key]? with | some earlier => - ctx := WfRewriter.replaceValue ctx (op.getResult 0) (earlier.getResult 0) sorry sorry sorry - ctx := WfRewriter.eraseOp ctx op sorry sorry sorry + ctx := WfRewriter.replaceValue! ctx (op.getResult 0) (earlier.getResult 0) + ctx := WfRewriter.eraseOp! ctx op | none => available := available.insert key op current := next return ctx -set_option warn.sorry false in def processAllBlocks (ctx : WfIRContext OpCode) : WfIRContext OpCode := Id.run do let mut ctx := ctx diff --git a/Veir/Passes/Canonicalize.lean b/Veir/Passes/Canonicalize.lean index b77dec4d5..011b2bcfe 100644 --- a/Veir/Passes/Canonicalize.lean +++ b/Veir/Passes/Canonicalize.lean @@ -16,7 +16,6 @@ def isConstOperand (ctx : IRContext OpCode) (v : ValuePtr) : Bool := | some defOp => (defOp.getOpType! ctx).isConstantLike | none => false -set_option warn.sorry false in def commutativeConstantRHS (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let opType := op.getOpType! rewriter.ctx.raw @@ -28,9 +27,9 @@ def commutativeConstantRHS (rewriter : PatternRewriter OpCode) (op : OperationPt if reordered == operands then return rewriter let resultTypes := op.getResultTypes! rewriter.ctx.raw let properties := op.getProperties! rewriter.ctx.raw opType - let (rewriter, newOp) ← rewriter.createOp opType resultTypes reordered - #[] #[] properties (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! opType resultTypes reordered + #[] #[] properties (some $ .before op) + return rewriter.replaceOp! op newOp /-! ## Pass implementation -/ diff --git a/Veir/Passes/CastsReconciliation/Reconciliation.lean b/Veir/Passes/CastsReconciliation/Reconciliation.lean index d6ea0d91a..7de64e225 100644 --- a/Veir/Passes/CastsReconciliation/Reconciliation.lean +++ b/Veir/Passes/CastsReconciliation/Reconciliation.lean @@ -29,18 +29,17 @@ def isPreservingIntegerTypeRoundTrip (inputType interType : TypeAttr) : Bool := | _, _ => false /- Reconciles round-trip casts of the form X->Y->X if allowed for these types by `legal X Y` -/ -set_option warn.sorry false in def reconcilePairingCast (legal : TypeAttr → TypeAttr → Bool) (rewriter : PatternRewriter OpCode) - (op : OperationPtr) (opInBounds : op.InBounds rewriter.ctx.raw) : + (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do - let some cast := matchCastOp op rewriter.ctx.raw | return rewriter + let some _ := matchCastOp op rewriter.ctx.raw | return rewriter let input := op.getOperand! rewriter.ctx.raw 0 /- Note that reconciliation matches on the second casting operation, so the input type of this op would be the intermediate type -/ let interType := input.getType! rewriter.ctx.raw let resultType := ((op.getResult 0).get! rewriter.ctx.raw).type /- If the operand's parent is a cast operation -/ let .opResult op' := input | return rewriter - let some cast := matchCastOp op'.op rewriter.ctx.raw | return rewriter + let some _ := matchCastOp op'.op rewriter.ctx.raw | return rewriter let parentInput := (op'.op.getOperand! rewriter.ctx.raw 0) /- And the result's type coincides with the parent operation operand's type -/ let inputType := parentInput.getType! rewriter.ctx.raw @@ -48,27 +47,26 @@ def reconcilePairingCast (legal : TypeAttr → TypeAttr → Bool) (rewriter : Pa /- And the reconciliation is legal -/ if ¬ legal inputType interType then return rewriter /- Replace the initial operation's output with the parent operations input -/ - let rewriter := rewriter.replaceValue (op.getResult 0) parentInput sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) parentInput /- Erase the redundant cast operation -/ - let rewriter ← rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.eraseOp! op /- If unused and side-effect-free, erase the parent cast operation as well. These need to be erased in this order, otherwise the parent operation will always be used. -/ if ¬ op'.op.hasUses! rewriter.ctx.raw && ¬ op'.op.hasSideEffects rewriter.ctx.raw then - rewriter.eraseOp op'.op sorry sorry sorry + return rewriter.eraseOp! op'.op else return rewriter -set_option warn.sorry false in def reconcileIdentityCast (rewriter : PatternRewriter OpCode) (op : OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do - let some cast := matchCastOp op rewriter.ctx.raw | return rewriter + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + let some _ := matchCastOp op rewriter.ctx.raw | return rewriter /- get the input and output types -/ let input := op.getOperand! rewriter.ctx.raw 0 let inputType := input.getType! rewriter.ctx.raw let resultType := ((op.getResult 0).get! rewriter.ctx.raw).type if inputType ≠ resultType then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) input sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) input + return rewriter.eraseOp! op def CastReconcilePass.impl (ctx : WfIRContext OpCode) (op : OperationPtr) (_ : op.InBounds ctx.raw) : ExceptT String IO (WfIRContext OpCode) := do diff --git a/Veir/Passes/DCE/dce.lean b/Veir/Passes/DCE/dce.lean index 49b8f3f4b..6748fbee7 100644 --- a/Veir/Passes/DCE/dce.lean +++ b/Veir/Passes/DCE/dce.lean @@ -6,12 +6,11 @@ namespace Veir /-! We implement a dead code elimination pass. -/ -set_option warn.sorry false in def eliminateDeadOp (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do /- delete operations that are not used and have no side effects -/ if ¬ op.hasUses! rewriter.ctx.raw && ¬ op.hasSideEffects rewriter.ctx.raw then - rewriter.eraseOp op sorry sorry sorry + return rewriter.eraseOp! op else return rewriter diff --git a/Veir/Passes/InstCombine.lean b/Veir/Passes/InstCombine.lean index cd9eb4b20..a0868319d 100644 --- a/Veir/Passes/InstCombine.lean +++ b/Veir/Passes/InstCombine.lean @@ -13,7 +13,6 @@ namespace Veir /-! ## Pattern Rewrites -/ -set_option warn.sorry false in /-- Rewrites `x * 2` to `x + x`. -/ def mulITwoToAddi (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -23,15 +22,14 @@ def mulITwoToAddi (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op. | return rewriter if cst.value ≠ 2 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .add) #[lhs.getType! rewriter.ctx.raw] #[lhs, lhs] - #[] #[] properties (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .add) #[lhs.getType! rewriter.ctx.raw] #[lhs, lhs] + #[] #[] properties (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in /-- Rewrites `x * 0` to `0`. -/ def mulIZeroToCst (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do - let some (lhs, rhs, properties) := matchMuli op rewriter.ctx + let some (lhs, rhs, _) := matchMuli op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter @@ -40,11 +38,10 @@ def mulIZeroToCst (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op. let .integerType type := (lhs.getType! rewriter.ctx.raw).val | return rewriter let cstProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk 0 type)) - let (rewriter, newOp) ← rewriter.createOp (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] - #[] #[] cstProp (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] + #[] #[] cstProp (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in /-- Rewrites `x + 0` to `x`. -/ def addiZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -54,10 +51,9 @@ def addiZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.In | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `x * 1` to `x`. -/ def mulIOneToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -67,10 +63,9 @@ def mulIOneToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InB | return rewriter if cst.value ≠ 1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `x - 0` to `x`. -/ def subiZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -80,10 +75,9 @@ def subiZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.In | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `x - x` to `0`. -/ def subiSelfToZero (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -94,11 +88,10 @@ def subiSelfToZero (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op let .integerType type := (lhs.getType! rewriter.ctx.raw).val | return rewriter let cstProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk 0 type)) - let (rewriter, newOp) ← rewriter.createOp (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] - #[] #[] cstProp (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] + #[] #[] cstProp (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in /-- Rewrites `x & x` to `x`. -/ def andiSelfToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -106,10 +99,9 @@ def andiSelfToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.In | return rewriter if lhs ≠ rhs then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `x & 0` to `0`. -/ def andiZeroToZero (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -122,11 +114,10 @@ def andiZeroToZero (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op let .integerType type := (lhs.getType! rewriter.ctx.raw).val | return rewriter let cstProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk 0 type)) - let (rewriter, newOp) ← rewriter.createOp (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] - #[] #[] cstProp (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] + #[] #[] cstProp (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in /-- Rewrites `x | 0` to `x`. -/ def oriZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -136,10 +127,9 @@ def oriZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InB | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `x | x` to `x`. -/ def oriSelfToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -147,10 +137,9 @@ def oriSelfToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InB | return rewriter if lhs ≠ rhs then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `x ^ 0` to `x`. -/ def xoriZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -160,10 +149,9 @@ def xoriZeroToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.In | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `x ^ x` to `0`. -/ def xoriSelfToZero (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -174,11 +162,10 @@ def xoriSelfToZero (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op let .integerType type := (lhs.getType! rewriter.ctx.raw).val | return rewriter let cstProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk 0 type)) - let (rewriter, newOp) ← rewriter.createOp (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] - #[] #[] cstProp (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .mlir__constant) #[lhs.getType! rewriter.ctx.raw] #[] + #[] #[] cstProp (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in /-- Rewrites `~~x` to `x`. -/ def notNotToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -186,10 +173,9 @@ def notNotToX (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBo | return rewriter let some inner := matchNot outerNotted rewriter.ctx | return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) inner sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) inner + return rewriter.eraseOp! op -set_option warn.sorry false in /-- Rewrites `~(~a & ~b)` to `a | b` (DeMorgan). -/ /- TODO: the precondition should be strengthened by some hasOneUse() checks -/ def deMorganAndToOr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : @@ -206,11 +192,10 @@ def deMorganAndToOr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : o | return rewriter let resultType := a.getType! rewriter.ctx.raw let orProps : DisjointProperties := { disjoint := false } - let (rewriter, newOp) ← rewriter.createOp (.llvm .or) #[resultType] #[a, b] - #[] #[] orProps (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .or) #[resultType] #[a, b] + #[] #[] orProps (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in /-- Rewrites `~(~a | ~b)` to `a & b` (DeMorgan). -/ /- TODO: the precondition should be strengthened by some hasOneUse() checks -/ def deMorganOrToAnd (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : @@ -226,9 +211,9 @@ def deMorganOrToAnd (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : o let some b := matchNot orR rewriter.ctx | return rewriter let resultType := a.getType! rewriter.ctx.raw - let (rewriter, newOp) ← rewriter.createOp (.llvm .and) #[resultType] #[a, b] - #[] #[] () (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .and) #[resultType] #[a, b] + #[] #[] () (some $ .before op) + return rewriter.replaceOp! op newOp def InstCombinePass.impl (ctx : WfIRContext OpCode) (op : OperationPtr) (_ : op.InBounds ctx.raw) : ExceptT String IO (WfIRContext OpCode) := do diff --git a/Veir/Passes/InstructionSelection/Common.lean b/Veir/Passes/InstructionSelection/Common.lean index ef07dd29d..52d55d648 100644 --- a/Veir/Passes/InstructionSelection/Common.lean +++ b/Veir/Passes/InstructionSelection/Common.lean @@ -7,18 +7,16 @@ namespace Veir Shared helpers for the RISC-V instruction-selection lowering patterns. -/ -set_option warn.sorry false in /-- Insert `unrealized_conversion_cast : (typeof v) -> !riscv.reg` before `op`, returning the updated rewriter and the register-typed result value. -/ def castToReg (rewriter : PatternRewriter OpCode) (op : OperationPtr) (v : ValuePtr) : Option (PatternRewriter OpCode × ValuePtr) := do - let (rewriter, castOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) - #[RegisterType.mk] #[v] #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) + #[RegisterType.mk] #[v] #[] #[] () (some $ .before op) return (rewriter, castOp.getResult 0) -set_option warn.sorry false in /-- Cast the register value `reg` back to `op`'s result type and replace `op` with the cast. The target type is read from `op`, so this is type-agnostic (it also @@ -27,8 +25,8 @@ set_option warn.sorry false in def replaceWithReg (rewriter : PatternRewriter OpCode) (op : OperationPtr) (reg : ValuePtr) : Option (PatternRewriter OpCode) := do let type := ((op.getResult 0).get! rewriter.ctx.raw).type - let (rewriter, castOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) - #[type] #[reg] #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) + #[type] #[reg] #[] #[] () (some $ .before op) + return rewriter.replaceOp! op castOp end Veir diff --git a/Veir/Passes/InstructionSelection/RISCV64.lean b/Veir/Passes/InstructionSelection/RISCV64.lean index caea34f7f..63f8d4407 100644 --- a/Veir/Passes/InstructionSelection/RISCV64.lean +++ b/Veir/Passes/InstructionSelection/RISCV64.lean @@ -19,7 +19,6 @@ namespace Veir def isLegalExtOpWidth (w : Nat) : Bool := w = 8 ∨ w = 16 ∨ w = 32 -set_option warn.sorry false in /-- `llvm.intr.ctlz` -> `riscv.clz`. -/ @@ -38,7 +37,6 @@ def ctlz (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (retOp.getResult 0) -set_option warn.sorry false in /-- `llvm.intr.cttz` -> `riscv.ctz`. -/ @@ -57,7 +55,6 @@ def cttz (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (retOp.getResult 0) -set_option warn.sorry false in /-- `llvm.intr.ctpop` -> `riscv.cpop`. -/ @@ -76,7 +73,6 @@ def ctpop (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (retOp.getResult 0) -set_option warn.sorry false in /-- `llvm.intr.bswap` -> `riscv.rev8`. -/ @@ -98,7 +94,6 @@ def bswap (rewriter : PatternRewriter OpCode) (op : OperationPtr) else replaceWithReg rewriter op (rev8Op.getResult 0) -set_option warn.sorry false in /-- One SWAR bit-reversal stage: `((x & mask) << shamt) | ((x >> shamt) & mask)`. @@ -122,7 +117,6 @@ def bitreverseStage (mask shamt : Int) (rewriter : PatternRewriter OpCode) #[lowShiftOp.getResult 0, highOp.getResult 0] #[] #[] () (some $ .before op) return (rewriter, orOp.getResult 0) -set_option warn.sorry false in /-- `llvm.intr.bitreverse` -> mask/shift/or stages followed by `riscv.rev8`. -/ @@ -152,7 +146,6 @@ def bitreverse (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (retOp.getResult 0) -set_option warn.sorry false in /-- llvm.constant -> riscv.li -/ def constant (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -166,13 +159,12 @@ def constant (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBou #[] #[] {value := const} (some $ .before op) let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[newOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.add -> riscv.add -/ def add (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do - let some (lhs, rhs, properties) := matchAdd op rewriter.ctx | return rewriter + let some (lhs, rhs, _) := matchAdd op rewriter.ctx | return rewriter /- support `i64` and `i32` (experiment) -/ let .integerType ltype := (lhs.getType! rewriter.ctx.raw).val | return rewriter if ltype.bitwidth ≠ 64 ∧ ltype.bitwidth ≠ 32 then return rewriter @@ -197,9 +189,8 @@ def add (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds r /- Cast back result for type consistency-/ let (rewriter, castAddOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[addOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castAddOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castAddOp -set_option warn.sorry false in /-- llvm.and -> riscv.and -/ def and (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -223,13 +214,12 @@ def and (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds r /- Cast back result for type consistency-/ let (rewriter, castAddOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[andOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castAddOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castAddOp -set_option warn.sorry false in /-- llvm.ashr -> riscv.sra -/ def ashr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do - let some (lhs, rhs, properties) := matchAshr op rewriter.ctx | return rewriter + let some (lhs, rhs, _) := matchAshr op rewriter.ctx | return rewriter /- support `i64` and `i32` -/ let .integerType ltype := (lhs.getType! rewriter.ctx.raw).val | return rewriter if ltype.bitwidth ≠ 64 ∧ ltype.bitwidth ≠ 32 then return rewriter @@ -254,9 +244,8 @@ def ashr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castSraOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[sraOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castSraOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castSraOp -set_option warn.sorry false in /-- llvm.icmp eq lhs rhs -> riscv.sltiu (riscv.xor lhs rhs) 1 llvm.icmp ne lhs rhs -> riscv.sltu 0 (riscv.xor lhs rhs) @@ -295,7 +284,7 @@ def icmp (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds pure (rewriter, lcastOp.getResult 0, rcastOp.getResult 0) /- Casting back result for type consistency is always necessary. -/ let type := ((op.getResult 0).get! rewriter.ctx.raw).type - let .integerType type' := type.val | rewriter + let .integerType _ := type.val | rewriter /- Match depending on the predicate and build correct lowering. -/ let (rewriter, retOp) ← match property.predicate with | .eq => @@ -370,9 +359,8 @@ def icmp (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds pure (rewriter, retOp) let (rewriter, castEqOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[retOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castEqOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castEqOp -set_option warn.sorry false in /-- llvm.or -> riscv.or -/ def or (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -396,9 +384,8 @@ def or (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds re /- Cast back result for type consistency-/ let (rewriter, castOrOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[orOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOrOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOrOp -set_option warn.sorry false in /-- llvm.xor -> riscv.xor -/ def xor (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -422,9 +409,8 @@ def xor (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds r /- Cast back result for type consistency-/ let (rewriter, castXorOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[xorOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castXorOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castXorOp -set_option warn.sorry false in /-- llvm.mul -> riscv.mul -/ def mul (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -453,9 +439,8 @@ def mul (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds r /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[mulOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.sdiv -> riscv.div -/ def sdiv (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -484,9 +469,8 @@ def sdiv (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[divOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.udiv -> riscv.divu -/ def udiv (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -515,9 +499,8 @@ def udiv (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[divuOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.srem -> riscv.rem -/ def srem (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -546,9 +529,8 @@ def srem (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[remOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.urem -> riscv.remu -/ def urem (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -577,9 +559,8 @@ def urem (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[remuOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.sub -> riscv.sub -/ def sub (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -608,9 +589,8 @@ def sub (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds r /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[subOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.sext %x `i8` to `i32` -> riscv.sextb %x llvm.sext %x `i8` to `i64` -> riscv.sextb %x @@ -648,9 +628,8 @@ def sext (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[retOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.zext %x `i16` to `i64` -> riscv.zexth %x llvm.zext %x `i16` to `i32` -> riscv.zexth %x @@ -686,9 +665,8 @@ def zext (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[retOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.trunc %x iX to iY -> builtin_unrealized_conversion_cast (!riscv.reg) : iY where `iY`'s width is smaller than `iX`'s. @@ -706,9 +684,8 @@ def trunc (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Then, cast register to expected output width. -/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[opCastOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.shl -> riscv.sll -/ def shl (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -737,9 +714,8 @@ def shl (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds r /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[mulOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.shl -> riscv.srl -/ def lshr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -768,9 +744,8 @@ def lshr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds /- Cast back result for type consistency-/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[mulOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.load -> riscv.ld (i64) / riscv.lw (i32) / riscv.lb (i8) -/ def load (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -796,9 +771,8 @@ def load (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds #[] #[] zero (some $ .before op) let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[ldOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.store -> riscv.sd (i64) / riscv.sw (i32) / riscv.sb (i8) -/ def store (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -825,9 +799,8 @@ def store (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds else rewriter.createOp! (.riscv .sd) #[] #[pcastOp.getResult 0, valcastOp.getResult 0] #[] #[] zero (some $ .before op) - rewriter.replaceOp op sdOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op sdOp -set_option warn.sorry false in /-- Lower a single-dynamic-index `llvm.getelementptr` computing `ptr + idx * scale`, where `scale` is the byte size of the element type. @@ -885,7 +858,7 @@ def getelementptr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op. /- Cast the resulting register back to `!llvm.ptr`. -/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[retOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp /-! ## Zicond branchless `select` lowering @@ -901,7 +874,6 @@ def getelementptr (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op. `selectCzeronez` are registered before `selectGeneral`. -/ -set_option warn.sorry false in /-- `select c t 0` -> `riscv.czeroeqz t c`. -/ @@ -917,7 +889,6 @@ def selectCzeroeqz (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (czOp.getResult 0) -set_option warn.sorry false in /-- `select c 0 f` -> `riscv.czeronez f c`. -/ @@ -933,7 +904,6 @@ def selectCzeronez (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (czOp.getResult 0) -set_option warn.sorry false in /-- General branchless select: `select c t f` -> `or (czero.eqz t c) (czero.nez f c)`. @@ -964,7 +934,6 @@ def selectGeneral (rewriter : PatternRewriter OpCode) (op : OperationPtr) multi-instruction expansion and is intentionally left unselected. -/ -set_option warn.sorry false in /-- llvm.intr.smax -> riscv.max -/ def smax (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -988,7 +957,6 @@ def smax (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (maxOp.getResult 0) -set_option warn.sorry false in /-- llvm.intr.smin -> riscv.min -/ def smin (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -1012,7 +980,6 @@ def smin (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (minOp.getResult 0) -set_option warn.sorry false in /-- llvm.intr.umax -> riscv.maxu -/ def umax (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -1025,7 +992,6 @@ def umax (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (maxuOp.getResult 0) -set_option warn.sorry false in /-- llvm.intr.umin -> riscv.minu -/ def umin (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -1038,7 +1004,6 @@ def umin (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (minuOp.getResult 0) -set_option warn.sorry false in /-- llvm.intr.fshl with identical data operands is a rotate-left: -> riscv.rol. The general (distinct-operand) funnel shift is left unselected. -/ def fshl (rewriter : PatternRewriter OpCode) (op : OperationPtr) @@ -1058,7 +1023,6 @@ def fshl (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (rolOp.getResult 0) -set_option warn.sorry false in /-- llvm.intr.fshr with identical data operands is a rotate-right: -> riscv.ror. The general (distinct-operand) funnel shift is left unselected. -/ def fshr (rewriter : PatternRewriter OpCode) (op : OperationPtr) @@ -1078,7 +1042,6 @@ def fshr (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] () (some $ .before op) replaceWithReg rewriter op (rorOp.getResult 0) -set_option warn.sorry false in /-- llvm.intr.fshr with identical data operands and a constant shift amount is a constant rotate-right: -> riscv.rori (mirrors `PatGprImm`). -/ def fshrConst (rewriter : PatternRewriter OpCode) (op : OperationPtr) @@ -1103,7 +1066,6 @@ def fshrConst (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] imm (some $ .before op) replaceWithReg rewriter op (roriOp.getResult 0) -set_option warn.sorry false in /-- llvm.intr.fshl with identical data operands and a constant shift amount is a constant rotate-left. There is no `roli`, so (like LLVM) it lowers to `riscv.rori` with the negated immediate `(64 - amt) mod 64`. -/ @@ -1133,7 +1095,6 @@ def fshlConst (rewriter : PatternRewriter OpCode) (op : OperationPtr) replaceWithReg rewriter op (roriOp.getResult 0) -set_option warn.sorry false in /-- llvm.mlir.poison -> riscv.li 0 -/ def poisonConst (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -1143,7 +1104,6 @@ def poisonConst (rewriter : PatternRewriter OpCode) (op : OperationPtr) #[] #[] imm (some $ .before op) replaceWithReg rewriter op (liOp.getResult 0) -set_option warn.sorry false in /-- llvm.freeze arg : Int w -> unrealized_conversion_cast (unrealized_conversion_cast arg : Int w -> Reg) : Reg -> Int w -/ def freeze (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : @@ -1159,7 +1119,7 @@ def freeze (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBound /- Then, cast register to expected output width. -/ let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[type] #[opCastOp.getResult 0] #[] #[] () (some $ .before op) - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + return rewriter.replaceOp! op castOp /-! # Pass implementation -/ diff --git a/Veir/Passes/InstructionSelection/RISCV64Sdag.lean b/Veir/Passes/InstructionSelection/RISCV64Sdag.lean index ba7170092..b41ba7af8 100644 --- a/Veir/Passes/InstructionSelection/RISCV64Sdag.lean +++ b/Veir/Passes/InstructionSelection/RISCV64Sdag.lean @@ -19,7 +19,6 @@ def singleSetBit (x : BitVec 64) : Option Int := /-! # SelectionDAG Lowering Patterns -/ -set_option warn.sorry false in /-- `and x (not y)` -> `riscv.andn x y`. The `not` may appear on either operand. -/ @@ -36,11 +35,10 @@ def andn (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds | none => none) | return rewriter let (rewriter, xReg) ← castToReg rewriter op x let (rewriter, yReg) ← castToReg rewriter op y - let (rewriter, andnOp) ← rewriter.createOp (.riscv .andn) #[RegisterType.mk] #[xReg, yReg] - #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, andnOp) := rewriter.createOp! (.riscv .andn) #[RegisterType.mk] #[xReg, yReg] + #[] #[] () (some $ .before op) replaceWithReg rewriter op (andnOp.getResult 0) -set_option warn.sorry false in /-- `or x (not y)` -> `riscv.orn x y`. The `not` may appear on either operand. -/ @@ -57,11 +55,10 @@ def orn (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds r | none => none) | return rewriter let (rewriter, xReg) ← castToReg rewriter op x let (rewriter, yReg) ← castToReg rewriter op y - let (rewriter, ornOp) ← rewriter.createOp (.riscv .orn) #[RegisterType.mk] #[xReg, yReg] - #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, ornOp) := rewriter.createOp! (.riscv .orn) #[RegisterType.mk] #[xReg, yReg] + #[] #[] () (some $ .before op) replaceWithReg rewriter op (ornOp.getResult 0) -set_option warn.sorry false in /-- `xor x (not y)` -> `riscv.xnor x y`. The `not` may appear on either operand. -/ @@ -78,11 +75,10 @@ def xnor (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds | none => none) | return rewriter let (rewriter, xReg) ← castToReg rewriter op x let (rewriter, yReg) ← castToReg rewriter op y - let (rewriter, xnorOp) ← rewriter.createOp (.riscv .xnor) #[RegisterType.mk] #[xReg, yReg] - #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, xnorOp) := rewriter.createOp! (.riscv .xnor) #[RegisterType.mk] #[xReg, yReg] + #[] #[] () (some $ .before op) replaceWithReg rewriter op (xnorOp.getResult 0) -set_option warn.sorry false in /-- `sub (shl M (8 - Y)) (lshr M Y)` -> `riscv.orcb M`, where `M = and Z (0x0101_0101_0101_0101 <<< Y)` @@ -124,8 +120,8 @@ def orcb (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds if !(isMask mo0 || isMask mo1) then return rewriter let (rewriter, mReg) ← castToReg rewriter op m /- actual `riscv.orcb` -/ - let (rewriter, orcbOp) ← rewriter.createOp (.riscv .orcb) #[RegisterType.mk] #[mReg] - #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, orcbOp) := rewriter.createOp! (.riscv .orcb) #[RegisterType.mk] #[mReg] + #[] #[] () (some $ .before op) replaceWithReg rewriter op (orcbOp.getResult 0) @@ -146,7 +142,6 @@ def orcb (rewriter: PatternRewriter OpCode) (op: OperationPtr) (_ : op.InBounds https://github.com/llvm/llvm-project/blob/2e87cf8c2b8ec6453ccfa7e448d5b33f1d71a2ca/llvm/lib/Target/RISCV/RISCVInstrInfo.td#L1386-L1393 -/ -set_option warn.sorry false in /-- `OP x (const imm)` -> `OPi x imm`, when the op's result has width `width` and the immediate lies in `[lo, hi]`. Mirrors `PatGprImm`. The @@ -163,8 +158,8 @@ def selectBinopImm {α} (matchPair : OperationPtr → IRContext OpCode → Optio if imm.value < lo || imm.value > hi then return rewriter let (rewriter, xReg) ← castToReg rewriter op lhs let immProps := RISCVImmediateProperties.mk (IntegerAttr.mk imm.value (IntegerType.mk 64)) - let (rewriter, newOp) ← rewriter.createOp (.riscv dst) #[RegisterType.mk] #[xReg] - #[] #[] (cast h.symm immProps) (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv dst) #[RegisterType.mk] #[xReg] + #[] #[] (cast h.symm immProps) (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) /-- imm12 binops on i64: `add/or/and/xor x (const ∈ [-2048, 2047])` -> `addi/ori/andi/xori`. @@ -188,7 +183,6 @@ def slliw := selectBinopImm matchShl .slliw rfl 32 0 31 def srliw := selectBinopImm matchLshr .srliw rfl 32 0 31 def sraiw := selectBinopImm matchAshr .sraiw rfl 32 0 31 -set_option warn.sorry false in /-- `icmp slt x (const imm12)` -> `riscv.slti x imm`; `icmp ult x (const imm12)` -> `riscv.sltiu x imm`. @@ -207,13 +201,13 @@ def slti (rewriter : PatternRewriter OpCode) (op : OperationPtr) match prop.predicate with | .slt => let (rewriter, xReg) ← castToReg rewriter op lhs - let (rewriter, newOp) ← rewriter.createOp (.riscv .slti) #[RegisterType.mk] #[xReg] - #[] #[] immProps (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv .slti) #[RegisterType.mk] #[xReg] + #[] #[] immProps (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) | .ult => let (rewriter, xReg) ← castToReg rewriter op lhs - let (rewriter, newOp) ← rewriter.createOp (.riscv .sltiu) #[RegisterType.mk] #[xReg] - #[] #[] immProps (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv .sltiu) #[RegisterType.mk] #[xReg] + #[] #[] immProps (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) | _ => return rewriter @@ -231,7 +225,6 @@ def slti (rewriter : PatternRewriter OpCode) (op : OperationPtr) https://github.com/llvm/llvm-project/blob/2e87cf8c2b8ec6453ccfa7e448d5b33f1d71a2ca/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td#L529-L534 -/ -set_option warn.sorry false in /-- Single-bit immediate selection. `complement = false` for `bseti`/`binvi` (the immediate itself is a single set bit); `complement = true` for `bclri` (the @@ -252,15 +245,14 @@ def selectSingleBit {α} (matchPair : OperationPtr → IRContext OpCode → Opti let some n := singleSetBit (if complement then ~~~ bv else bv) | return rewriter let (rewriter, xReg) ← castToReg rewriter op lhs let immProps := RISCVImmediateProperties.mk (IntegerAttr.mk n (IntegerType.mk 64)) - let (rewriter, newOp) ← rewriter.createOp (.riscv dst) #[RegisterType.mk] #[xReg] - #[] #[] (cast h.symm immProps) (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv dst) #[RegisterType.mk] #[xReg] + #[] #[] (cast h.symm immProps) (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) def bseti := selectSingleBit matchOr .bseti rfl false def binvi := selectSingleBit matchXor .binvi rfl false def bclri := selectSingleBit matchAnd .bclri rfl true -set_option warn.sorry false in /-- `and (lshr x n) 1` -> `riscv.bexti x n` (`PatGprImm`-style single-bit extract). https://github.com/llvm/llvm-project/blob/2e87cf8c2b8ec6453ccfa7e448d5b33f1d71a2ca/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td#L536-L537 -/ def bexti (rewriter : PatternRewriter OpCode) (op : OperationPtr) @@ -276,11 +268,10 @@ def bexti (rewriter : PatternRewriter OpCode) (op : OperationPtr) if sh.value < 0 || sh.value > 63 then return rewriter let (rewriter, xReg) ← castToReg rewriter op x let immProps := RISCVImmediateProperties.mk (IntegerAttr.mk sh.value (IntegerType.mk 64)) - let (rewriter, newOp) ← rewriter.createOp (.riscv .bexti) #[RegisterType.mk] #[xReg] - #[] #[] immProps (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv .bexti) #[RegisterType.mk] #[xReg] + #[] #[] immProps (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) -set_option warn.sorry false in /-- `fshr x x (const)` on i32 is a constant word rotate-right -> `riscv.roriw` (i32 analogue of `fshrConst` -> `rori`; mirrors `PatGprImm`). https://github.com/llvm/llvm-project/blob/2e87cf8c2b8ec6453ccfa7e448d5b33f1d71a2ca/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td#L504 -/ @@ -294,11 +285,10 @@ def roriw (rewriter : PatternRewriter OpCode) (op : OperationPtr) let sh : Int := ((amtAttr.value % 32) + 32) % 32 let (rewriter, valReg) ← castToReg rewriter op a let immProps := RISCVImmediateProperties.mk (IntegerAttr.mk sh (IntegerType.mk 64)) - let (rewriter, newOp) ← rewriter.createOp (.riscv .roriw) #[RegisterType.mk] #[valReg] - #[] #[] immProps (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv .roriw) #[RegisterType.mk] #[valReg] + #[] #[] immProps (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) -set_option warn.sorry false in /-- `fshl x x (const)` on i32 is a constant word rotate-left. There is no `roliw`, so (like `fshlConst` at i64) it lowers to `riscv.roriw` with the negated immediate `(32 - amt) mod 32` (i32 analogue of `fshlConst`). -/ @@ -314,11 +304,10 @@ def roliw (rewriter : PatternRewriter OpCode) (op : OperationPtr) let imm : Int := (32 - sh) % 32 let (rewriter, valReg) ← castToReg rewriter op a let immProps := RISCVImmediateProperties.mk (IntegerAttr.mk imm (IntegerType.mk 64)) - let (rewriter, newOp) ← rewriter.createOp (.riscv .roriw) #[RegisterType.mk] #[valReg] - #[] #[] immProps (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv .roriw) #[RegisterType.mk] #[valReg] + #[] #[] immProps (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) -set_option warn.sorry false in /-- `shl (zext i32->i64 x) (const ∈ [0,31])` -> `riscv.slliuw x shamt` (Zba: `(i64 (shl (and GPR, 0xFFFFFFFF), uimm5)) -> SLLI_UW`; our `zext` is the mask). https://github.com/llvm/llvm-project/blob/2e87cf8c2b8ec6453ccfa7e448d5b33f1d71a2ca/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td#L821-L822 -/ @@ -335,11 +324,10 @@ def slliuw (rewriter : PatternRewriter OpCode) (op : OperationPtr) if srcT.bitwidth ≠ 32 then return rewriter let (rewriter, xReg) ← castToReg rewriter op x let immProps := RISCVImmediateProperties.mk (IntegerAttr.mk sh.value (IntegerType.mk 64)) - let (rewriter, newOp) ← rewriter.createOp (.riscv .slliuw) #[RegisterType.mk] #[xReg] - #[] #[] immProps (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, newOp) := rewriter.createOp! (.riscv .slliuw) #[RegisterType.mk] #[xReg] + #[] #[] immProps (some $ .before op) replaceWithReg rewriter op (newOp.getResult 0) -set_option warn.sorry false in /-- llvm.zext x i1 to i64 -> and x 1 -/ def zext_1 (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -349,16 +337,15 @@ def zext_1 (rewriter : PatternRewriter OpCode) (op : OperationPtr) let .integerType opType := (operand.getType! rewriter.ctx.raw).val | return rewriter if opType.bitwidth ≠ 1 then return rewriter /- First, cast the operand to registers -/ - let (rewriter, opCastOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) #[RegisterType.mk] #[operand] - #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, opCastOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[RegisterType.mk] #[operand] + #[] #[] () (some $ .before op) let imm := RISCVImmediateProperties.mk (IntegerAttr.mk 1 (IntegerType.mk 64)) - let (rewriter, andiOp) ← rewriter.createOp (.riscv .andi) #[RegisterType.mk] #[opCastOp.getResult 0] - #[] #[] imm (some $ .before op) sorry (by simp) (by simp) sorry - let (rewriter, castOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) #[t] #[andiOp.getResult 0] - #[] #[] () (some $ .before op) (by sorry) (by simp) (by simp) sorry - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + let (rewriter, andiOp) := rewriter.createOp! (.riscv .andi) #[RegisterType.mk] #[opCastOp.getResult 0] + #[] #[] imm (some $ .before op) + let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[t] #[andiOp.getResult 0] + #[] #[] () (some $ .before op) + return rewriter.replaceOp! op castOp -set_option warn.sorry false in /-- llvm.sext x i1 to i64 -> srai (slli x 63) 1 -/ def sext_1 (rewriter : PatternRewriter OpCode) (op : OperationPtr) (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do @@ -368,16 +355,16 @@ def sext_1 (rewriter : PatternRewriter OpCode) (op : OperationPtr) let .integerType opType := (operand.getType! rewriter.ctx.raw).val | return rewriter if opType.bitwidth ≠ 1 then return rewriter /- First, cast the operand to registers -/ - let (rewriter, opCastOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) #[RegisterType.mk] #[operand] - #[] #[] () (some $ .before op) sorry (by simp) (by simp) sorry + let (rewriter, opCastOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[RegisterType.mk] #[operand] + #[] #[] () (some $ .before op) let imm := RISCVImmediateProperties.mk (IntegerAttr.mk 63 (IntegerType.mk 64)) - let (rewriter, slliOp) ← rewriter.createOp (.riscv .slli) #[RegisterType.mk] #[opCastOp.getResult 0] - #[] #[] imm (some $ .before op) sorry (by simp) (by simp) sorry - let (rewriter, sraiOp) ← rewriter.createOp (.riscv .srai) #[RegisterType.mk] #[slliOp.getResult 0] - #[] #[] imm (some $ .before op) sorry (by simp) (by simp) sorry - let (rewriter, castOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) #[t] #[sraiOp.getResult 0] - #[] #[] () (some $ .before op) (by sorry) (by simp) (by simp) sorry - rewriter.replaceOp op castOp sorry sorry sorry sorry sorry + let (rewriter, slliOp) := rewriter.createOp! (.riscv .slli) #[RegisterType.mk] #[opCastOp.getResult 0] + #[] #[] imm (some $ .before op) + let (rewriter, sraiOp) := rewriter.createOp! (.riscv .srai) #[RegisterType.mk] #[slliOp.getResult 0] + #[] #[] imm (some $ .before op) + let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) #[t] #[sraiOp.getResult 0] + #[] #[] () (some $ .before op) + return rewriter.replaceOp! op castOp /-! # Pass implementation -/ diff --git a/Veir/Passes/ModArithToArith.lean b/Veir/Passes/ModArithToArith.lean index 013a030ff..eb9deea45 100644 --- a/Veir/Passes/ModArithToArith.lean +++ b/Veir/Passes/ModArithToArith.lean @@ -17,28 +17,25 @@ namespace Veir /-! ## Unrealized Conversion Casts -/ -set_option warn.sorry false in /-- Emit `unrealized_conversion_cast v : !mod_arith.int → iN`. -/ def castToStorage (rewriter : PatternRewriter OpCode) (v : ValuePtr) (ip : InsertPoint) : Option (PatternRewriter OpCode × ValuePtr) := do let .modArithType mt := (v.getType! rewriter.ctx.raw).val | none let storageType : TypeAttr := mt.modulus.type - let (rewriter, castOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) - #[storageType] #[v] #[] #[] () (some ip) sorry (by simp) (by simp) sorry + let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) + #[storageType] #[v] #[] #[] () (some ip) return (rewriter, (castOp.getResult 0 : ValuePtr)) -set_option warn.sorry false in /-- Emit `unrealized_conversion_cast x : iN → ty`, where `ty` is a `mod_arith` type. -/ def castToModArith (rewriter : PatternRewriter OpCode) (x : ValuePtr) (ty : ModArithType) (ip : InsertPoint) : Option (PatternRewriter OpCode × ValuePtr) := do - let (rewriter, castOp) ← rewriter.createOp (.builtin .unrealized_conversion_cast) - #[ty] #[x] #[] #[] () (some ip) sorry (by simp) (by simp) sorry + let (rewriter, castOp) := rewriter.createOp! (.builtin .unrealized_conversion_cast) + #[ty] #[x] #[] #[] () (some ip) return (rewriter, (castOp.getResult 0 : ValuePtr)) /-! ## Unpack / Pack ModArithType -/ -set_option warn.sorry false in /-- Unpack a `!mod_arith.int` value `v` into the IntegerType `intermediateType` -/ @@ -48,13 +45,12 @@ def unpackValue (rewriter : PatternRewriter OpCode) (v : ValuePtr) (intermediate let .integerType storageType := (stored.getType! rewriter.ctx.raw).val | none if intermediateType.bitwidth > storageType.bitwidth then - let (rewriter, ext) ← rewriter.createOp (.arith .extui) - #[intermediateType] #[stored] #[] #[] { nneg := false } (some ip) sorry (by simp) (by simp) sorry + let (rewriter, ext) := rewriter.createOp! (.arith .extui) + #[intermediateType] #[stored] #[] #[] { nneg := false } (some ip) return (rewriter, (ext.getResult 0 : ValuePtr)) else return (rewriter, stored) -set_option warn.sorry false in /-- Pack an IntegerType value `v` of IntegerType `intermediateType` into a value of `!mod_arith.int` type `ty`. -/ @@ -64,9 +60,9 @@ def packValue (rewriter : PatternRewriter OpCode) (v : ValuePtr) (ty : ModArithT | none let storageType := ty.modulus.type if intermediateType.bitwidth > storageType.bitwidth then - let (rewriter, narrowed) ← rewriter.createOp (.arith .trunci) + let (rewriter, narrowed) := rewriter.createOp! (.arith .trunci) #[storageType] #[v] #[] #[] { attr := { nsw := false, nuw := true } } - (some ip) sorry (by simp) (by simp) sorry + (some ip) castToModArith rewriter (narrowed.getResult 0 : ValuePtr) ty ip else castToModArith rewriter (v : ValuePtr) ty ip @@ -74,24 +70,22 @@ def packValue (rewriter : PatternRewriter OpCode) (v : ValuePtr) (ty : ModArithT /-! ## Arith Helpers -/ -set_option warn.sorry false in /-- Emit `arith.constant c : i`. Requires `c` to fit into width (unsigned) -/ def emitArithConstant (rewriter : PatternRewriter OpCode) (c : Int) (width : Nat) (ip : InsertPoint) : Option (PatternRewriter OpCode × ValuePtr) := do let ty : TypeAttr := IntegerType.mk width let props : ArithConstantProperties := { value := IntegerAttr.mk c (IntegerType.mk width) } - let (rewriter, c) ← rewriter.createOp (.arith .constant) - #[ty] #[] #[] #[] props (some ip) sorry (by simp) (by simp) sorry + let (rewriter, c) := rewriter.createOp! (.arith .constant) + #[ty] #[] #[] #[] props (some ip) return (rewriter, (c.getResult 0 : ValuePtr)) -set_option warn.sorry false in /-- Emit a binary Arith op `arithOp` on `a` and `b` -/ def emitArithBinOp (rewriter : PatternRewriter OpCode) (arithOp : Arith) (props : propertiesOf (.arith arithOp)) (a b : ValuePtr) (ip : InsertPoint) : Option (PatternRewriter OpCode × ValuePtr) := do let ty := a.getType! rewriter.ctx.raw - let (rewriter, r) ← rewriter.createOp (.arith arithOp) - #[ty] #[a, b] #[] #[] props (some ip) sorry (by simp) (by simp) sorry + let (rewriter, r) := rewriter.createOp! (.arith arithOp) + #[ty] #[a, b] #[] #[] props (some ip) return (rewriter, (r.getResult 0 : ValuePtr)) @@ -103,13 +97,12 @@ abbrev Builder := (ip : InsertPoint) → Option (PatternRewriter OpCode × ValuePtr) -set_option warn.sorry false in /-- Lower a binary `mod_arith` op `modOp`, using intermediate Type iM given storage type iM, with M = `widen` N, and using Builder `build` to determine the exact `arith` operations to emit -/ def lowerModArithBinOp (modOp : Mod_Arith) (widen : Nat → Nat) (build : Builder) (rewriter : PatternRewriter OpCode) (op : OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do -- match op and extract operands: let some (operands, _) := matchOp op rewriter.ctx (.mod_arith modOp) 2 | return rewriter @@ -128,8 +121,8 @@ def lowerModArithBinOp (modOp : Mod_Arith) (widen : Nat → Nat) (build : Builde let (rewriter, r) ← build rewriter a b q ip let (rewriter, r) ← emitArithBinOp rewriter .remui () r q ip let (rewriter, r) ← packValue rewriter r modArithType ip - let rewriter := rewriter.replaceValue (op.getResult 0) r sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) r + return rewriter.eraseOp! op /-! ## Binary op lowering Patterns -/ @@ -156,10 +149,9 @@ def lowerModArithSubOp := lowerModArithBinOp .sub (· + 1) buildSub /-! ## Constant lowering Pattern -/ -set_option warn.sorry false in /-- Lower `mod_arith.constant` to an `arith.constant` (assumes value is in `[0, q)` already). -/ def lowerModArithConstant (rewriter : PatternRewriter OpCode) (op : OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do -- match op and extract attribute: let some (_, props) := matchOp op rewriter.ctx (.mod_arith .constant) 0 | return rewriter @@ -172,8 +164,8 @@ def lowerModArithConstant (rewriter : PatternRewriter OpCode) (op : OperationPtr let ip := InsertPoint.before op let (rewriter, r) ← emitArithConstant rewriter c storageType.bitwidth ip let (rewriter, out) ← castToModArith rewriter (r : ValuePtr) modArithType ip - let rewriter := rewriter.replaceValue (op.getResult 0) out sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) out + return rewriter.eraseOp! op /-! ## Pass implementation -/ diff --git a/Veir/Passes/RISCVCombines/Combine.lean b/Veir/Passes/RISCVCombines/Combine.lean index 09027e384..aec1fe831 100644 --- a/Veir/Passes/RISCVCombines/Combine.lean +++ b/Veir/Passes/RISCVCombines/Combine.lean @@ -15,17 +15,16 @@ namespace Veir.RISCV added to the pattern list in `Combine.impl`. -/ -set_option warn.sorry false in /-- riscv.add x 0 -> x -/ def right_identity_zero_add (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (operands, _) := matchOp op rewriter.ctx (.riscv .add) 2 | return rewriter let lhs := operands[0]! let some liOp := getDefiningOp operands[1]! rewriter.ctx | return rewriter let some (_, cst) := matchOp liOp rewriter.ctx (.riscv .li) 0 | return rewriter if cst.value.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op def Combine.impl (ctx : WfIRContext OpCode) (op : OperationPtr) (_ : op.InBounds ctx.raw) : ExceptT String IO (WfIRContext OpCode) := do diff --git a/Veir/Passes/RISCVCombines/MIRCombinesVeir.lean b/Veir/Passes/RISCVCombines/MIRCombinesVeir.lean index e4b490583..1a9d52c65 100644 --- a/Veir/Passes/RISCVCombines/MIRCombinesVeir.lean +++ b/Veir/Passes/RISCVCombines/MIRCombinesVeir.lean @@ -6,156 +6,140 @@ import Veir.Passes.Matching namespace Veir.RISCV -set_option warn.sorry false in def sub_minus_one (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, op1, _props) := matchSub op rewriter.ctx | return rewriter let some cst := matchConstantIntVal lhs rewriter.ctx | return rewriter if cst.value ≠ -1 then return rewriter let .integerType ctype := (op1.getType! rewriter.ctx.raw).val | return rewriter let cstOpProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk (-1) ctype)) - let (rewriter, cstOp) ← rewriter.createOp (.llvm .mlir__constant) #[op1.getType! rewriter.ctx.raw] #[] - #[] #[] cstOpProp (some $ .before op) sorry sorry sorry sorry - let (rewriter, newOp) ← rewriter.createOp (.llvm .xor) #[op1.getType! rewriter.ctx.raw] #[op1, (cstOp.getResult 0)] - #[] #[] () (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, cstOp) := rewriter.createOp! (.llvm .mlir__constant) #[op1.getType! rewriter.ctx.raw] #[] + #[] #[] cstOpProp (some $ .before op) + let (rewriter, newOp) := rewriter.createOp! (.llvm .xor) #[op1.getType! rewriter.ctx.raw] #[op1, (cstOp.getResult 0)] + #[] #[] () (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def right_identity_zero_0 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, rhs, _props) := matchSub op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in def right_identity_zero_1 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, rhs, _props) := matchAdd op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in def right_identity_zero_2 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, rhs, _props) := matchOr op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in def right_identity_zero_3 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, rhs, _props) := matchXor op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in def right_identity_zero_4 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, rhs, _props) := matchShl op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in def right_identity_zero_5 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, rhs, _props) := matchAshr op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in def right_identity_zero_6 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (lhs, rhs, _props) := matchLshr op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) lhs sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) lhs + return rewriter.eraseOp! op -set_option warn.sorry false in def right_identity_one_int (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (x, rhs, _props) := matchMul op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ 1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) x sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) x + return rewriter.eraseOp! op -set_option warn.sorry false in def binop_same_val_0 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (src, src1, _) := matchAnd op rewriter.ctx | return rewriter if src != src1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) src sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) src + return rewriter.eraseOp! op -set_option warn.sorry false in def binop_same_val_1 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (src, src1, _props) := matchOr op rewriter.ctx | return rewriter if src != src1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) src sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) src + return rewriter.eraseOp! op -set_option warn.sorry false in def same_val_zero_0 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (x, x1, _props) := matchSub op rewriter.ctx | return rewriter if x != x1 then return rewriter let .integerType type := (x.getType! rewriter.ctx.raw).val | return rewriter let cstProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk (0) type)) - let (rewriter, newOp) ← rewriter.createOp (.llvm .mlir__constant) #[x.getType! rewriter.ctx.raw] #[] - #[] #[] cstProp (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .mlir__constant) #[x.getType! rewriter.ctx.raw] #[] + #[] #[] cstProp (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def same_val_zero_1 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (x, x1, _props) := matchXor op rewriter.ctx | return rewriter if x != x1 then return rewriter let .integerType type := (x.getType! rewriter.ctx.raw).val | return rewriter let cstProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk (0) type)) - let (rewriter, newOp) ← rewriter.createOp (.llvm .mlir__constant) #[x.getType! rewriter.ctx.raw] #[] - #[] #[] cstProp (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .mlir__constant) #[x.getType! rewriter.ctx.raw] #[] + #[] #[] cstProp (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def binop_right_to_zero (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do - let some (lhs, zero, _props) := matchMul op rewriter.ctx | return rewriter + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + let some (_, zero, _props) := matchMul op rewriter.ctx | return rewriter let some cst := matchConstantIntVal zero rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) zero sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) zero + return rewriter.eraseOp! op -set_option warn.sorry false in def mul_by_neg_one (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (x, rhs, _props) := matchMul op rewriter.ctx | return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ -1 then return rewriter let .integerType ctype := (x.getType! rewriter.ctx.raw).val | return rewriter let cstOpProp := LLVMConstantProperties.mk (.integer (IntegerAttr.mk (0) ctype)) - let (rewriter, cstOp) ← rewriter.createOp (.llvm .mlir__constant) #[x.getType! rewriter.ctx.raw] #[] - #[] #[] cstOpProp (some $ .before op) sorry sorry sorry sorry - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[x.getType! rewriter.ctx.raw] #[(cstOp.getResult 0), x] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, cstOp) := rewriter.createOp! (.llvm .mlir__constant) #[x.getType! rewriter.ctx.raw] #[] + #[] #[] cstOpProp (some $ .before op) + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[x.getType! rewriter.ctx.raw] #[(cstOp.getResult 0), x] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def or_and_xor_to_or (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (and, y, _props) := matchOr op rewriter.ctx | return rewriter let some defOp := getDefiningOp and rewriter.ctx | return rewriter let some (x, not, _) := matchAnd defOp rewriter.ctx | return rewriter @@ -164,13 +148,12 @@ def or_and_xor_to_or (rewriter: PatternRewriter OpCode) (op: OperationPtr) if y != y1 then return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ -1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .or) #[and.getType! rewriter.ctx.raw] #[x, y] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .or) #[and.getType! rewriter.ctx.raw] #[x, y] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def and_xor_or_to_and (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (or, y, _) := matchAnd op rewriter.ctx | return rewriter let some defOp := getDefiningOp or rewriter.ctx | return rewriter let some (x, not, _props) := matchOr defOp rewriter.ctx | return rewriter @@ -179,173 +162,159 @@ def and_xor_or_to_and (rewriter: PatternRewriter OpCode) (op: OperationPtr) if y != y1 then return rewriter let some cst := matchConstantIntVal rhs rewriter.ctx | return rewriter if cst.value ≠ -1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .and) #[or.getType! rewriter.ctx.raw] #[x, y] - #[] #[] () (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .and) #[or.getType! rewriter.ctx.raw] #[x, y] + #[] #[] () (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def add_sub_reg_0 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (x, tmp, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp tmp rewriter.ctx | return rewriter let some (src, x1, _props1) := matchSub defOp rewriter.ctx | return rewriter if x != x1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) src sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) src + return rewriter.eraseOp! op -set_option warn.sorry false in def add_sub_reg_1 (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (tmp, x, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp tmp rewriter.ctx | return rewriter let some (src, x1, _props1) := matchSub defOp rewriter.ctx | return rewriter if x != x1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) src sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) src + return rewriter.eraseOp! op -set_option warn.sorry false in def APlusBMinusCMinusB (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (add1, B, _props) := matchSub op rewriter.ctx | return rewriter let some defOp := getDefiningOp add1 rewriter.ctx | return rewriter let some (A, sub1, _props1) := matchAdd defOp rewriter.ctx | return rewriter let some defOp1 := getDefiningOp sub1 rewriter.ctx | return rewriter let some (B1, C, _props2) := matchSub defOp1 rewriter.ctx | return rewriter if B != B1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[add1.getType! rewriter.ctx.raw] #[A, C] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[add1.getType! rewriter.ctx.raw] #[A, C] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def AMinusBMinusCMinusC (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (sub2, C, _props) := matchSub op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub2 rewriter.ctx | return rewriter let some (A, sub1, _props1) := matchSub defOp rewriter.ctx | return rewriter let some defOp1 := getDefiningOp sub1 rewriter.ctx | return rewriter let some (B, C1, _props2) := matchSub defOp1 rewriter.ctx | return rewriter if C != C1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[sub2.getType! rewriter.ctx.raw] #[A, B] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[sub2.getType! rewriter.ctx.raw] #[A, B] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def ZeroMinusAPlusB (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (sub, B, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub rewriter.ctx | return rewriter let some (lhs, A, _props1) := matchSub defOp rewriter.ctx | return rewriter let some cst := matchConstantIntVal lhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[sub.getType! rewriter.ctx.raw] #[B, A] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[sub.getType! rewriter.ctx.raw] #[B, A] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def APlusZeroMinusB (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (A, sub, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub rewriter.ctx | return rewriter let some (lhs, B, _props1) := matchSub defOp rewriter.ctx | return rewriter let some cst := matchConstantIntVal lhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[A.getType! rewriter.ctx.raw] #[A, B] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[A.getType! rewriter.ctx.raw] #[A, B] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def APlusBMinusB (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (A, sub, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub rewriter.ctx | return rewriter let some (B, A1, _props1) := matchSub defOp rewriter.ctx | return rewriter if A != A1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) B sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) B + return rewriter.eraseOp! op -set_option warn.sorry false in def BMinusAPlusA (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (sub, A, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub rewriter.ctx | return rewriter let some (B, A1, _props1) := matchSub defOp rewriter.ctx | return rewriter if A != A1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) B sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) B + return rewriter.eraseOp! op -set_option warn.sorry false in def AMinusBPlusCMinusA (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (sub1, sub2, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub1 rewriter.ctx | return rewriter let some (A, B, _props1) := matchSub defOp rewriter.ctx | return rewriter let some defOp1 := getDefiningOp sub2 rewriter.ctx | return rewriter let some (C, A1, _props2) := matchSub defOp1 rewriter.ctx | return rewriter if A != A1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[sub1.getType! rewriter.ctx.raw] #[C, B] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[sub1.getType! rewriter.ctx.raw] #[C, B] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def AMinusBPlusBMinusC (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (sub1, sub2, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub1 rewriter.ctx | return rewriter let some (A, B, _props1) := matchSub defOp rewriter.ctx | return rewriter let some defOp1 := getDefiningOp sub2 rewriter.ctx | return rewriter let some (B1, C, _props2) := matchSub defOp1 rewriter.ctx | return rewriter if B != B1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[sub1.getType! rewriter.ctx.raw] #[A, C] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[sub1.getType! rewriter.ctx.raw] #[A, C] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def APlusBMinusAplusC (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (A, sub1, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub1 rewriter.ctx | return rewriter let some (B, add1, _props1) := matchSub defOp rewriter.ctx | return rewriter let some defOp1 := getDefiningOp add1 rewriter.ctx | return rewriter let some (A1, C, _props2) := matchAdd defOp1 rewriter.ctx | return rewriter if A != A1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[A.getType! rewriter.ctx.raw] #[B, C] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[A.getType! rewriter.ctx.raw] #[B, C] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def APlusBMinusCPlusA (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (A, sub1, _props) := matchAdd op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub1 rewriter.ctx | return rewriter let some (B, add1, _props1) := matchSub defOp rewriter.ctx | return rewriter let some defOp1 := getDefiningOp add1 rewriter.ctx | return rewriter let some (C, A1, _props2) := matchAdd defOp1 rewriter.ctx | return rewriter if A != A1 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .sub) #[A.getType! rewriter.ctx.raw] #[B, C] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .sub) #[A.getType! rewriter.ctx.raw] #[B, C] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def AMinusZeroMinusB (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (A, sub1, _props) := matchSub op rewriter.ctx | return rewriter let some defOp := getDefiningOp sub1 rewriter.ctx | return rewriter let some (lhs, B, _props1) := matchSub defOp rewriter.ctx | return rewriter let some cst := matchConstantIntVal lhs rewriter.ctx | return rewriter if cst.value ≠ 0 then return rewriter - let (rewriter, newOp) ← rewriter.createOp (.llvm .add) #[A.getType! rewriter.ctx.raw] #[A, B] - #[] #[] _props (some $ .before op) sorry sorry sorry sorry - rewriter.replaceOp op newOp sorry sorry sorry sorry sorry + let (rewriter, newOp) := rewriter.createOp! (.llvm .add) #[A.getType! rewriter.ctx.raw] #[A, B] + #[] #[] _props (some $ .before op) + return rewriter.replaceOp! op newOp -set_option warn.sorry false in def AMinusBMinusA (rewriter: PatternRewriter OpCode) (op: OperationPtr) - (opInBounds : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do + (_ : op.InBounds rewriter.ctx.raw) : Option (PatternRewriter OpCode) := do let some (A, add, _props) := matchSub op rewriter.ctx | return rewriter let some defOp := getDefiningOp add rewriter.ctx | return rewriter let some (A1, B, _props1) := matchSub defOp rewriter.ctx | return rewriter if A != A1 then return rewriter - let rewriter := rewriter.replaceValue (op.getResult 0) B sorry sorry sorry - rewriter.eraseOp op sorry sorry sorry + let rewriter := rewriter.replaceValue! (op.getResult 0) B + return rewriter.eraseOp! op def mir_pattern_combines := [sub_minus_one, diff --git a/Veir/PatternRewriter/Basic.lean b/Veir/PatternRewriter/Basic.lean index 16eae0c34..ba6396c3e 100644 --- a/Veir/PatternRewriter/Basic.lean +++ b/Veir/PatternRewriter/Basic.lean @@ -43,6 +43,8 @@ def Worklist.empty : Worklist where indexInStack := HashMap.emptyWithCapacity wf_index := by grind +instance : Inhabited Worklist := ⟨Worklist.empty⟩ + def Worklist.isEmpty (worklist: Worklist) : Bool := worklist.indexInStack.size = 0 @@ -108,6 +110,9 @@ structure PatternRewriter (OpInfo : Type) [HasOpInfo OpInfo] where hasDoneAction: Bool worklist: PatternRewriter.Worklist +instance : Inhabited (PatternRewriter OpInfo) := + ⟨{ ctx := default, hasDoneAction := false, worklist := default }⟩ + variable {rewriter : PatternRewriter OpInfo} namespace PatternRewriter @@ -203,6 +208,18 @@ def eraseOp (rewriter: PatternRewriter OpInfo) (op: OperationPtr) worklist := rewriter.worklist.remove op, } +/-- +Erase an operation, panicking if the operation is out of bounds, has regions, or has uses. +-/ +def eraseOp! (rewriter: PatternRewriter OpInfo) (op: OperationPtr) + : PatternRewriter OpInfo := + let newCtx := WfRewriter.eraseOp! rewriter.ctx op + { rewriter with + ctx := newCtx, + hasDoneAction := true, + worklist := rewriter.worklist.remove op, + } + def replaceOp (rewriter: PatternRewriter OpInfo) (oldOp newOp: OperationPtr) (opNe : oldOp ≠ newOp := by grind) (hpar : (oldOp.get! rewriter.ctx.raw).parent.isSome = true := by grind) @@ -221,6 +238,27 @@ def replaceOp (rewriter: PatternRewriter OpInfo) (oldOp newOp: OperationPtr) worklist := rewriter.worklist.remove oldOp |>.push newOp, } +/-- +Replace all results of an operation with the results of another, erasing the replaced operation. +Panics if the two operations are equal, if the old operation has no parent or has regions, if +either operation is out of bounds, or if the operations have different numbers of results. +-/ +def replaceOp! (rewriter: PatternRewriter OpInfo) (oldOp newOp: OperationPtr) + : PatternRewriter OpInfo := + if oldIn : oldOp.InBounds rewriter.ctx.raw then Id.run do + let mut rw : {r : PatternRewriter OpInfo // r.ctx = rewriter.ctx } := ⟨rewriter, by grind⟩ + for h : i in 0...(oldOp.getNumResults rewriter.ctx.raw oldIn) do + rw := ⟨rw.val.addUsersInWorklist (oldOp.getResult i) (by grind), by grind⟩ + let rewriter := rw.val + let newCtx := WfRewriter.replaceOp! rewriter.ctx oldOp newOp + return { rewriter with + ctx := newCtx, + hasDoneAction := true, + worklist := rewriter.worklist.remove oldOp |>.push newOp, + } + else + panic! "PatternRewriter.replaceOp! failed: old operation is out of bounds" + def replaceValue (rewriter: PatternRewriter OpInfo) (oldVal newVal: ValuePtr) (neValues : oldVal ≠ newVal := by grind) (oldIn: oldVal.InBounds rewriter.ctx.raw := by grind) @@ -230,6 +268,19 @@ def replaceValue (rewriter: PatternRewriter OpInfo) (oldVal newVal: ValuePtr) let ctx := WfRewriter.replaceValue rewriter.ctx oldVal newVal { rewriter with ctx, hasDoneAction := true} +/-- +Replace a value with another value, panicking if the two values are equal, or if either value +is out of bounds. +-/ +def replaceValue! (rewriter: PatternRewriter OpInfo) (oldVal newVal: ValuePtr) + : PatternRewriter OpInfo := + if oldIn : oldVal.InBounds rewriter.ctx.raw then + let rewriter := rewriter.addUsersInWorklist oldVal oldIn + let ctx := WfRewriter.replaceValue! rewriter.ctx oldVal newVal + { rewriter with ctx, hasDoneAction := true } + else + panic! "PatternRewriter.replaceValue! failed: old value is out of bounds" + def createBlock (rewriter: PatternRewriter OpInfo) (argTypes: Array TypeAttr) (insertPoint : Option BlockInsertPoint)