From 87268acfc47acfc6c3a1ff6c1d57b2d95e7fd527 Mon Sep 17 00:00:00 2001 From: Congcong Cai Date: Fri, 6 Jun 2025 14:44:18 +0800 Subject: [PATCH 1/4] feat: implement tail call optimization --- src/passes/CMakeLists.txt | 1 + src/passes/TailCall.cpp | 82 +++++++++++++++++++++++++++++++++++++++ src/passes/pass.cpp | 2 + src/passes/passes.h | 1 + 4 files changed, 86 insertions(+) create mode 100644 src/passes/TailCall.cpp diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 16891caade3..96907e9565d 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -116,6 +116,7 @@ set(passes_SOURCES ReorderGlobals.cpp ReorderLocals.cpp ReReloop.cpp + TailCall.cpp TrapMode.cpp TypeGeneralizing.cpp TypeRefining.cpp diff --git a/src/passes/TailCall.cpp b/src/passes/TailCall.cpp new file mode 100644 index 00000000000..e32e5eab3c0 --- /dev/null +++ b/src/passes/TailCall.cpp @@ -0,0 +1,82 @@ + +#include "pass.h" +#include "wasm-traversal.h" +#include "wasm.h" +#include + +namespace wasm { + +namespace { + +struct Finder : PostWalker { + std::vector tailCalls; + std::vector tailCallIndirects; + void visitFunction(Function* curr) { checkTailCall(curr->body); } + void visitReturn(Return* curr) { checkTailCall(curr->value); } + +private: + void checkTailCall(Expression* expr) { + if (expr == nullptr) { + return; + } + if (auto* call = expr->dynCast()) { + if (!call->isReturn && call->type == getFunction()->getResults()) { + tailCalls.push_back(call); + } + return; + } + if (auto* call = expr->dynCast()) { + if (!call->isReturn && call->type == getFunction()->getResults()) { + tailCallIndirects.push_back(call); + } + return; + } + if (auto* block = expr->dynCast()) { + return checkTailCall(block->list); + } + if (auto* ifElse = expr->dynCast()) { + checkTailCall(ifElse->ifTrue); + checkTailCall(ifElse->ifFalse); + return; + } + } + void checkTailCall(ExpressionList const& exprs) { + if (exprs.empty()) { + return; + } + checkTailCall(exprs.back()); + return; + } +}; + +} // namespace + +struct TailCallOptimizer : public Pass { + bool isFunctionParallel() override { return true; } + std::unique_ptr create() override { + return std::make_unique(); + } + void runOnFunction(Module* module, Function* function) override { + if (!module->features.hasTailCall()) { + return; + } + Finder finder{}; + finder.walkFunctionInModule(function, module); + for (Call* call : finder.tailCalls) { + if (!call->isReturn) { + call->isReturn = true; + call->finalize(); + } + } + for (CallIndirect* call : finder.tailCallIndirects) { + if (!call->isReturn) { + call->isReturn = true; + call->finalize(); + } + } + } +}; + +Pass* createTailCallPass() { return new TailCallOptimizer(); } + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 5e98ef5d086..6b69fafdda1 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -558,6 +558,8 @@ void PassRegistry::registerPasses() { registerPass("strip-target-features", "strip the wasm target features section", createStripTargetFeaturesPass); + registerPass( + "tail-call", "transform call to return call", createTailCallPass); registerPass("translate-to-new-eh", "deprecated; same as translate-to-exnref", createTranslateToExnrefPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index e0c03bad8d7..fb5bfd5de36 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -178,6 +178,7 @@ Pass* createStripEHPass(); Pass* createStubUnsupportedJSOpsPass(); Pass* createSSAifyPass(); Pass* createSSAifyNoMergePass(); +Pass* createTailCallPass(); Pass* createTable64LoweringPass(); Pass* createTranslateToExnrefPass(); Pass* createTrapModeClamp(); From 5dd3bcaa3d8a6a71f84713e23eb6bae28ae60125 Mon Sep 17 00:00:00 2001 From: Congcong Cai Date: Mon, 16 Jun 2025 12:14:55 +0800 Subject: [PATCH 2/4] handle try catch --- src/passes/TailCall.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/passes/TailCall.cpp b/src/passes/TailCall.cpp index e32e5eab3c0..d0ac118af07 100644 --- a/src/passes/TailCall.cpp +++ b/src/passes/TailCall.cpp @@ -8,7 +8,7 @@ namespace wasm { namespace { -struct Finder : PostWalker { +struct Finder : TryDepthWalker { std::vector tailCalls; std::vector tailCallIndirects; void visitFunction(Function* curr) { checkTailCall(curr->body); } @@ -19,6 +19,10 @@ struct Finder : PostWalker { if (expr == nullptr) { return; } + if (tryDepth > 0) { + // We are in a try block, so we cannot optimize tail calls. + return; + } if (auto* call = expr->dynCast()) { if (!call->isReturn && call->type == getFunction()->getResults()) { tailCalls.push_back(call); From 981e052da44d146206f94c344bf14583aa2a75b8 Mon Sep 17 00:00:00 2001 From: Congcong Cai Date: Mon, 14 Jul 2025 22:43:40 +0800 Subject: [PATCH 3/4] use getImmediateFallthrough --- src/passes/TailCall.cpp | 74 +++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/src/passes/TailCall.cpp b/src/passes/TailCall.cpp index d0ac118af07..4a8ac4b4752 100644 --- a/src/passes/TailCall.cpp +++ b/src/passes/TailCall.cpp @@ -1,7 +1,9 @@ +#include "ir/properties.h" #include "pass.h" #include "wasm-traversal.h" #include "wasm.h" +#include #include namespace wasm { @@ -9,47 +11,53 @@ namespace wasm { namespace { struct Finder : TryDepthWalker { + explicit Finder(const PassOptions& passOptions) + : TryDepthWalker(), passOptions(passOptions) {} + const PassOptions& passOptions; std::vector tailCalls; std::vector tailCallIndirects; - void visitFunction(Function* curr) { checkTailCall(curr->body); } - void visitReturn(Return* curr) { checkTailCall(curr->value); } - -private: - void checkTailCall(Expression* expr) { - if (expr == nullptr) { + void visitFunction(Function* curr) { + if (passOptions.shrinkLevel > 0 && passOptions.optimizeLevel == 0) { + // When we more force on the binary size, add return_call will increase + // the code size. return; } + checkTailCall(curr->body); + } + void visitReturn(Return* curr) { if (tryDepth > 0) { - // We are in a try block, so we cannot optimize tail calls. - return; - } - if (auto* call = expr->dynCast()) { - if (!call->isReturn && call->type == getFunction()->getResults()) { - tailCalls.push_back(call); - } - return; - } - if (auto* call = expr->dynCast()) { - if (!call->isReturn && call->type == getFunction()->getResults()) { - tailCallIndirects.push_back(call); - } - return; - } - if (auto* block = expr->dynCast()) { - return checkTailCall(block->list); - } - if (auto* ifElse = expr->dynCast()) { - checkTailCall(ifElse->ifTrue); - checkTailCall(ifElse->ifFalse); + // (return (call ...)) is not equal to (return_call ...) in try block return; } + checkTailCall(curr->value); } - void checkTailCall(ExpressionList const& exprs) { - if (exprs.empty()) { - return; + +private: + void checkTailCall(Expression* const curr) { + std::stack workList{}; + workList.push(curr); + while (!workList.empty()) { + Expression* const target = workList.top(); + workList.pop(); + if (auto* call = target->dynCast()) { + if (!call->isReturn && call->type == getFunction()->getResults()) { + tailCalls.push_back(call); + } + } else if (auto* call = target->dynCast()) { + if (!call->isReturn && call->type == getFunction()->getResults()) { + tailCallIndirects.push_back(call); + } + } else if (auto* ifElse = target->dynCast()) { + workList.push(ifElse->ifTrue); + workList.push(ifElse->ifFalse); + } else { + Expression* const next = Properties::getImmediateFallthrough( + target, passOptions, *getModule()); + if (next != target) { + workList.push(next); + } + } } - checkTailCall(exprs.back()); - return; } }; @@ -64,7 +72,7 @@ struct TailCallOptimizer : public Pass { if (!module->features.hasTailCall()) { return; } - Finder finder{}; + Finder finder{getPassOptions()}; finder.walkFunctionInModule(function, module); for (Call* call : finder.tailCalls) { if (!call->isReturn) { From 945d3484d70515ce497ea5c7cefe48602d025310 Mon Sep 17 00:00:00 2001 From: Congcong Cai Date: Mon, 14 Jul 2025 23:35:37 +0800 Subject: [PATCH 4/4] add test --- src/passes/TailCall.cpp | 12 +- src/passes/pass.cpp | 5 +- test/lit/help/wasm-metadce.test | 2 + test/lit/help/wasm-opt.test | 2 + test/lit/help/wasm2js.test | 2 + test/lit/tail-call-optimization-eh.wast | 94 ++++++++++++ test/lit/tail-call-optimization-shrink.wast | 22 +++ test/lit/tail-call-optimization.wast | 157 ++++++++++++++++++++ 8 files changed, 292 insertions(+), 4 deletions(-) create mode 100644 test/lit/tail-call-optimization-eh.wast create mode 100644 test/lit/tail-call-optimization-shrink.wast create mode 100644 test/lit/tail-call-optimization.wast diff --git a/src/passes/TailCall.cpp b/src/passes/TailCall.cpp index 4a8ac4b4752..d9163252013 100644 --- a/src/passes/TailCall.cpp +++ b/src/passes/TailCall.cpp @@ -1,5 +1,6 @@ #include "ir/properties.h" +#include "ir/utils.h" #include "pass.h" #include "wasm-traversal.h" #include "wasm.h" @@ -50,6 +51,14 @@ struct Finder : TryDepthWalker { } else if (auto* ifElse = target->dynCast()) { workList.push(ifElse->ifTrue); workList.push(ifElse->ifFalse); + } else if (auto* tryy = target->dynCast()) { + for (Expression* catchBody : tryy->catchBodies) { + workList.push(catchBody); + } + } else if (auto* block = target->dynCast()) { + if (!block->list.empty()) { + workList.push(block->list.back()); + } } else { Expression* const next = Properties::getImmediateFallthrough( target, passOptions, *getModule()); @@ -77,15 +86,14 @@ struct TailCallOptimizer : public Pass { for (Call* call : finder.tailCalls) { if (!call->isReturn) { call->isReturn = true; - call->finalize(); } } for (CallIndirect* call : finder.tailCallIndirects) { if (!call->isReturn) { call->isReturn = true; - call->finalize(); } } + ReFinalize{}.walkFunctionInModule(function, module); } }; diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 6b69fafdda1..f6f27f383a5 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -558,8 +558,9 @@ void PassRegistry::registerPasses() { registerPass("strip-target-features", "strip the wasm target features section", createStripTargetFeaturesPass); - registerPass( - "tail-call", "transform call to return call", createTailCallPass); + registerPass("tail-call-optimization", + "transform call to return call", + createTailCallPass); registerPass("translate-to-new-eh", "deprecated; same as translate-to-exnref", createTranslateToExnrefPass); diff --git a/test/lit/help/wasm-metadce.test b/test/lit/help/wasm-metadce.test index 4c727bcbd1f..7bb5951920f 100644 --- a/test/lit/help/wasm-metadce.test +++ b/test/lit/help/wasm-metadce.test @@ -530,6 +530,8 @@ ;; CHECK-NEXT: ;; CHECK-NEXT: --table64-lowering alias for memory64-lowering ;; CHECK-NEXT: +;; CHECK-NEXT: --tail-call-optimization transform call to return call +;; CHECK-NEXT: ;; CHECK-NEXT: --trace-calls instrument the build with code ;; CHECK-NEXT: to intercept specific function ;; CHECK-NEXT: calls diff --git a/test/lit/help/wasm-opt.test b/test/lit/help/wasm-opt.test index a0d8d199f94..d546114945c 100644 --- a/test/lit/help/wasm-opt.test +++ b/test/lit/help/wasm-opt.test @@ -554,6 +554,8 @@ ;; CHECK-NEXT: ;; CHECK-NEXT: --table64-lowering alias for memory64-lowering ;; CHECK-NEXT: +;; CHECK-NEXT: --tail-call-optimization transform call to return call +;; CHECK-NEXT: ;; CHECK-NEXT: --trace-calls instrument the build with code ;; CHECK-NEXT: to intercept specific function ;; CHECK-NEXT: calls diff --git a/test/lit/help/wasm2js.test b/test/lit/help/wasm2js.test index 881e18950e7..f62457ead1c 100644 --- a/test/lit/help/wasm2js.test +++ b/test/lit/help/wasm2js.test @@ -494,6 +494,8 @@ ;; CHECK-NEXT: ;; CHECK-NEXT: --table64-lowering alias for memory64-lowering ;; CHECK-NEXT: +;; CHECK-NEXT: --tail-call-optimization transform call to return call +;; CHECK-NEXT: ;; CHECK-NEXT: --trace-calls instrument the build with code ;; CHECK-NEXT: to intercept specific function ;; CHECK-NEXT: calls diff --git a/test/lit/tail-call-optimization-eh.wast b/test/lit/tail-call-optimization-eh.wast new file mode 100644 index 00000000000..13e1f0b96d3 --- /dev/null +++ b/test/lit/tail-call-optimization-eh.wast @@ -0,0 +1,94 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. + +;; RUN: foreach %s %t wasm-opt --tail-call-optimization --enable-tail-call --enable-exception-handling --optimize-level 2 --shrink-level 0 -S -o - | filecheck %s + +;; Tests for tail call optimization with exception handling + +(module $exception + ;; CHECK: (type $0 (func (result i32))) + + ;; CHECK: (type $1 (func)) + + ;; CHECK: (tag $empty (type $1)) + (tag $empty) + ;; CHECK: (func $f (result i32) + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + (func $f (result i32) + i32.const 0 + ) + ;; CHECK: (func $in_try (result i32) + ;; CHECK-NEXT: (try (result i32) + ;; CHECK-NEXT: (do + ;; CHECK-NEXT: (call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (catch $empty + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $in_try (result i32) + try (result i32) + call $f + catch $empty + call $f + end + ) + ;; CHECK: (func $out_try (result i32) + ;; CHECK-NEXT: (try + ;; CHECK-NEXT: (do + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (catch $empty + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + (func $out_try (result i32) + try + catch $empty + end + call $f + ) + ;; CHECK: (func $in_catch (result i32) + ;; CHECK-NEXT: (try + ;; CHECK-NEXT: (do + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (catch $empty + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + (func $in_catch (result i32) + try + catch $empty + call $f + return + end + i32.const 0 + ) + ;; CHECK: (func $implicit_in_catch (result i32) + ;; CHECK-NEXT: (try (result i32) + ;; CHECK-NEXT: (do + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (catch $empty + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (catch_all + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $implicit_in_catch (result i32) + try (result i32) + i32.const 0 + catch $empty + call $f + catch_all + call $f + end + ) +) diff --git a/test/lit/tail-call-optimization-shrink.wast b/test/lit/tail-call-optimization-shrink.wast new file mode 100644 index 00000000000..5fca8e8b76b --- /dev/null +++ b/test/lit/tail-call-optimization-shrink.wast @@ -0,0 +1,22 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. + +;; RUN: foreach %s %t wasm-opt --tail-call-optimization --enable-tail-call --optimize-level 0 --shrink-level 2 -S -o - | filecheck %s + +;; Tests for tail call optimization + +(module + ;; CHECK: (type $0 (func (result i32))) + + ;; CHECK: (func $f (result i32) + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + (func $f (result i32) + i32.const 0 + ) + ;; CHECK: (func $implicit_return (result i32) + ;; CHECK-NEXT: (call $f) + ;; CHECK-NEXT: ) + (func $implicit_return (result i32) + call $f + ) +) diff --git a/test/lit/tail-call-optimization.wast b/test/lit/tail-call-optimization.wast new file mode 100644 index 00000000000..25305dd432b --- /dev/null +++ b/test/lit/tail-call-optimization.wast @@ -0,0 +1,157 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. + +;; RUN: foreach %s %t wasm-opt --tail-call-optimization --enable-tail-call --optimize-level 2 --shrink-level 0 -S -o - | filecheck %s + +;; Tests for tail call optimization + +(module + ;; CHECK: (type $0 (func (result i32))) + + ;; CHECK: (type $1 (func (param i32) (result i32))) + + ;; CHECK: (func $f (result i32) + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + (func $f (result i32) + i32.const 0 + ) + ;; CHECK: (func $implicit_return (result i32) + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + (func $implicit_return (result i32) + call $f + ) + ;; CHECK: (func $explicit_return (result i32) + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + (func $explicit_return (result i32) + call $f + ) + ;; CHECK: (func $return_through_tee (param $0 i32) (result i32) + ;; CHECK-NEXT: (local.tee $0 + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $return_through_tee (param $0 i32) (result i32) + call $f + local.tee $0 + ) + ;; CHECK: (func $return_through_block (result i32) + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + (func $return_through_block (result i32) + block (result i32) + call $f + end + ) + ;; CHECK: (func $return_through_loop (result i32) + ;; CHECK-NEXT: (loop + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $return_through_loop (result i32) + loop (result i32) + call $f + end + ) + ;; CHECK: (func $return_through_if_then (param $0 i32) (result i32) + ;; CHECK-NEXT: (if (result i32) + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (then + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (else + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $return_through_if_then (param $0 i32) (result i32) + local.get $0 + if (result i32) + call $f + else + i32.const 0 + end + ) + ;; CHECK: (func $return_through_if_else (param $0 i32) (result i32) + ;; CHECK-NEXT: (if (result i32) + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (then + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (else + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $return_through_if_else (param $0 i32) (result i32) + local.get $0 + if (result i32) + i32.const 0 + else + call $f + end + ) + ;; CHECK: (func $return_through_if_both (param $0 i32) (result i32) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (then + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (else + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $return_through_if_both (param $0 i32) (result i32) + local.get $0 + if (result i32) + call $f + else + call $f + end + ) + ;; CHECK: (func $return_through_br_if (param $0 i32) (result i32) + ;; CHECK-NEXT: (block $block + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (return_call $f) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $return_through_br_if (param $0 i32) (result i32) + block (result i32) + call $f + i32.const 1 + br_if 0 + end + ) +) + +(module $NYI + ;; CHECK: (type $0 (func (result i32))) + + ;; CHECK: (type $1 (func (param i32) (result i32))) + + ;; CHECK: (func $f (result i32) + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + (func $f (result i32) + i32.const 0 + ) + ;; CHECK: (func $return_through_br (param $0 i32) (result i32) + ;; CHECK-NEXT: (block $block (result i32) + ;; CHECK-NEXT: (br $block + ;; CHECK-NEXT: (call $f) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $return_through_br (param $0 i32) (result i32) + block (result i32) + call $f + br 0 + end + ) +)