Skip to content

Handle exactness in MinimizeRecGroups #7555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scripts/test/fuzzing.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@
'type-merging-exact.wast',
'type-refining-exact.wast',
'type-refining-gufa-exact.wast',
'mimimize-rec-groups-exact.wast',
'mimimize-rec-groups-ignore-exact.wast',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two lines have typos: mimimize

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no 😭

'public-exact.wast',
# TODO: fuzzer support for custom descriptors
'custom-descriptors.wast',
Expand Down
39 changes: 22 additions & 17 deletions src/passes/MinimizeRecGroups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ struct GroupClassInfo {
static std::vector<std::vector<Index>> initSubtypeGraph(RecGroupInfo& info);
GroupClassInfo(RecGroupInfo& info);

void advance() {
void advance(FeatureSet features) {
++orders;
if (orders == orders.end()) {
advanceBrand();
advanceBrand(features);
}
}

void advanceBrand() {
void advanceBrand(FeatureSet features) {
if (brand) {
++*brand;
} else {
Expand All @@ -231,8 +231,8 @@ struct GroupClassInfo {
}
}
// Make sure the brand is not the same as the real type.
if (singletonType &&
RecGroupShape({**brand}) == RecGroupShape({*singletonType})) {
if (singletonType && RecGroupShape({**brand}, features) ==
RecGroupShape({*singletonType}, features)) {
++*brand;
}
// Start back at the initial permutation with the new brand.
Expand Down Expand Up @@ -370,9 +370,13 @@ struct MinimizeRecGroups : Pass {
// whose shapes we need to check for uniqueness to avoid deep recursions.
std::vector<Index> shapesToUpdate;

// The comparison of rec group shapes depends on the features.
FeatureSet features;

void run(Module* module) override {
features = module->features;
// There are no recursion groups to minimize if GC is not enabled.
if (!module->features.hasGC()) {
if (!features.hasGC()) {
return;
}

Expand Down Expand Up @@ -402,7 +406,7 @@ struct MinimizeRecGroups : Pass {
for (auto group : publicGroups) {
publicGroupTypes.emplace_back(group.begin(), group.end());
[[maybe_unused]] auto [_, inserted] = groupShapeIndices.insert(
{RecGroupShape(publicGroupTypes.back()), PublicGroupIndex});
{RecGroupShape(publicGroupTypes.back(), features), PublicGroupIndex});
assert(inserted);
}

Expand Down Expand Up @@ -452,8 +456,8 @@ struct MinimizeRecGroups : Pass {
}

void updateShape(Index group) {
auto [it, inserted] =
groupShapeIndices.insert({RecGroupShape(groups[group].group), group});
auto [it, inserted] = groupShapeIndices.insert(
{RecGroupShape(groups[group].group, features), group});
if (inserted) {
// This shape was unique. We're done.
return;
Expand Down Expand Up @@ -509,7 +513,7 @@ struct MinimizeRecGroups : Pass {
// We have everything we need to generate the next permutation of this
// group.
auto& classInfo = *groups[groupRep].classInfo;
classInfo.advance();
classInfo.advance(features);
classInfo.permute(groupInfo);
shapesToUpdate.push_back(group);
return;
Expand Down Expand Up @@ -538,7 +542,7 @@ struct MinimizeRecGroups : Pass {

// Move to the next permutation after advancing the type brand to skip
// further repeated shapes.
classInfo.advanceBrand();
classInfo.advanceBrand(features);
classInfo.permute(groupInfo);

shapesToUpdate.push_back(group);
Expand All @@ -556,7 +560,7 @@ struct MinimizeRecGroups : Pass {
// conflict.
if (groups[groupRep].classInfo && groups[otherRep].classInfo) {
auto& classInfo = *groups[groupRep].classInfo;
classInfo.advance();
classInfo.advance(features);
classInfo.permute(groupInfo);
shapesToUpdate.push_back(group);
return;
Expand All @@ -578,7 +582,7 @@ struct MinimizeRecGroups : Pass {
// same shape. Advance `group` to the next permutation.
otherInfo.classInfo = std::nullopt;
otherInfo.permutation = groupInfo.permutation;
classInfo.advance();
classInfo.advance(features);
classInfo.permute(groupInfo);

shapesToUpdate.push_back(group);
Expand All @@ -600,7 +604,7 @@ struct MinimizeRecGroups : Pass {
// permutation.
groupInfo.classInfo = std::nullopt;
groupInfo.permutation = otherInfo.permutation;
classInfo.advance();
classInfo.advance(features);
classInfo.permute(groupInfo);

shapesToUpdate.push_back(group);
Expand Down Expand Up @@ -754,9 +758,10 @@ struct MinimizeRecGroups : Pass {
// shapes to lists of automorphically equivalent root types.
std::map<ComparableRecGroupShape, std::vector<HeapType>> typeClasses;
for (const auto& order : dfsOrders) {
ComparableRecGroupShape shape(order, [this](HeapType a, HeapType b) {
return this->typeIndices.at(a) < this->typeIndices.at(b);
});
ComparableRecGroupShape shape(
order, features, [this](HeapType a, HeapType b) {
return this->typeIndices.at(a) < this->typeIndices.at(b);
});
typeClasses[shape].push_back(order[0]);
}

Expand Down
6 changes: 3 additions & 3 deletions src/tools/wasm-fuzz-types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ void Fuzzer::checkRecGroupShapes() {
};

for (size_t i = 0; i < groups.size(); ++i) {
ComparableRecGroupShape shape(groups[i], less);
ComparableRecGroupShape shape(groups[i], FeatureSet::All, less);
// A rec group should compare equal to itself.
if (shape != shape) {
Fatal() << "Rec group shape " << i << " not equal to itself";
Expand All @@ -556,7 +556,7 @@ void Fuzzer::checkRecGroupShapes() {

// Check how it compares to other groups.
for (size_t j = i + 1; j < groups.size(); ++j) {
ComparableRecGroupShape other(groups[j], less);
ComparableRecGroupShape other(groups[j], FeatureSet::All, less);
bool isLess = shape < other;
bool isEq = shape == other;
bool isGreater = shape > other;
Expand Down Expand Up @@ -598,7 +598,7 @@ void Fuzzer::checkRecGroupShapes() {

if (j + 1 < groups.size()) {
// Check transitivity.
RecGroupShape third(groups[j + 1]);
RecGroupShape third(groups[j + 1], FeatureSet::All);
if ((isLess && other <= third && shape >= third) ||
(isEq && other == third && shape != third) ||
(isGreater && other >= third && shape <= third)) {
Expand Down
12 changes: 10 additions & 2 deletions src/wasm-type-shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <functional>
#include <vector>

#include "wasm-features.h"
#include "wasm-type.h"

namespace wasm {
Expand All @@ -35,7 +36,13 @@ namespace wasm {
struct RecGroupShape {
const std::vector<HeapType>& types;

RecGroupShape(const std::vector<HeapType>& types) : types(types) {}
// Depending on the feature set, some types may be generalized when they are
// written out. Take the features into account to ensure our comparisons
// account for the rec groups that will ultimately be written.
const FeatureSet features;

RecGroupShape(const std::vector<HeapType>& types, const FeatureSet features)
: types(types), features(features) {}

bool operator==(const RecGroupShape& other) const;
bool operator!=(const RecGroupShape& other) const {
Expand All @@ -51,8 +58,9 @@ struct ComparableRecGroupShape : RecGroupShape {
std::function<bool(HeapType, HeapType)> less;

ComparableRecGroupShape(const std::vector<HeapType>& types,
FeatureSet features,
std::function<bool(HeapType, HeapType)> less)
: RecGroupShape(types), less(less) {}
: RecGroupShape(types, features), less(less) {}

bool operator<(const RecGroupShape& other) const;
bool operator>(const RecGroupShape& other) const;
Expand Down
13 changes: 13 additions & 0 deletions src/wasm/wasm-type-shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ namespace {
enum Comparison { EQ, LT, GT };

template<typename CompareTypes> struct RecGroupComparator {
FeatureSet features;
std::unordered_map<HeapType, Index> indicesA;
std::unordered_map<HeapType, Index> indicesB;
CompareTypes compareTypes;

RecGroupComparator(CompareTypes compareTypes) : compareTypes(compareTypes) {}

Comparison compare(const RecGroupShape& a, const RecGroupShape& b) {
assert(a.features == b.features);
features = a.features;
if (a.types.size() != b.types.size()) {
return a.types.size() < b.types.size() ? LT : GT;
}
Expand Down Expand Up @@ -147,6 +150,11 @@ template<typename CompareTypes> struct RecGroupComparator {
return compare(a.getTuple(), b.getTuple());
}
assert(a.isRef() && b.isRef());
// Only consider exactness if custom descriptors are enabled. Otherwise, it
// will be erased when the types are written, so we ignore it here, too.
if (features.hasCustomDescriptors() && a.isExact() != b.isExact()) {
return a.isExact() < b.isExact() ? LT : GT;
}
if (a.isNullable() != b.isNullable()) {
return a.isNullable() < b.isNullable() ? LT : GT;
}
Expand Down Expand Up @@ -201,9 +209,11 @@ template<typename CompareTypes>
RecGroupComparator(CompareTypes) -> RecGroupComparator<CompareTypes>;

struct RecGroupHasher {
FeatureSet features;
std::unordered_map<HeapType, Index> typeIndices;

size_t hash(const RecGroupShape& shape) {
features = shape.features;
for (auto type : shape.types) {
typeIndices.insert({type, typeIndices.size()});
}
Expand Down Expand Up @@ -285,6 +295,9 @@ struct RecGroupHasher {
return digest;
}
assert(type.isRef());
if (features.hasCustomDescriptors()) {
wasm::rehash(digest, type.isExact());
}
wasm::rehash(digest, type.isNullable());
hash_combine(digest, hash(type.getHeapType()));
return digest;
Expand Down
19 changes: 19 additions & 0 deletions test/lit/passes/minimize-rec-groups-exact.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.
;; RUN: wasm-opt %s -all --minimize-rec-groups -S -o - | filecheck %s

(module
;; CHECK: (type $foo (struct))
(type $foo (struct))
;; CHECK: (type $exact (struct (field (ref (exact $foo)))))
(type $exact (struct (field (ref (exact $foo)))))
;; CHECK: (type $inexact (struct (field (ref $foo))))
(type $inexact (struct (field (ref $foo))))

;; If we didn't differentiate between exact and inexact types, there would be
;; an assertion failure on adding these public types to the set of public
;; shapes.
;; CHECK: (import "" "exact" (global $exact (ref null $exact)))
(import "" "exact" (global $exact (ref null $exact)))
;; CHECK: (import "" "inexact" (global $inexact (ref null $inexact)))
(import "" "inexact" (global $inexact (ref null $inexact)))
)
53 changes: 53 additions & 0 deletions test/lit/passes/minimize-rec-groups-ignore-exact.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.

;; Check that we take exactness into account correctly depending on the
;; features. It differentiates shapes only when custom descriptors is enabled.

;; RUN: wasm-opt %s -all --minimize-rec-groups -S -o - | filecheck %s
;; RUN: wasm-opt %s -all --disable-custom-descriptors --minimize-rec-groups -S -o - | filecheck %s --check-prefix=NO_CD

(module
(rec
(type $foo (struct))

;; This SCC has only one distinct permutation.
;; CHECK: (rec
;; CHECK-NEXT: (type $b-inexact (struct (field (ref $a-inexact))))

;; CHECK: (type $a-inexact (struct (field (ref $b-inexact))))
;; NO_CD: (rec
;; NO_CD-NEXT: (type $b-inexact (struct (field (ref $a-inexact))))

;; NO_CD: (type $a-inexact (struct (field (ref $b-inexact))))
(type $a-inexact (struct (field (ref $b-inexact))))
(type $b-inexact (struct (field (ref $a-inexact))))

;; This SCC is only different because of exactness. It needs a brand only if
;; custom descriptors is disabled.
;; CHECK: (rec
;; CHECK-NEXT: (type $b-exact (struct (field (ref (exact $a-exact)))))

;; CHECK: (type $a-exact (struct (field (ref (exact $b-exact)))))
;; NO_CD: (rec
;; NO_CD-NEXT: (type $2 (struct))

;; NO_CD: (type $b-exact (struct (field (ref (exact $a-exact)))))

;; NO_CD: (type $a-exact (struct (field (ref (exact $b-exact)))))
(type $a-exact (struct (field (ref (exact $b-exact)))))
(type $b-exact (struct (field (ref (exact $a-exact)))))
)

;; CHECK: (global $a-inexact (ref null $a-inexact) (ref.null none))
;; NO_CD: (global $a-inexact (ref null $a-inexact) (ref.null none))
(global $a-inexact (ref null $a-inexact) (ref.null none))
;; CHECK: (global $b-inexact (ref null $b-inexact) (ref.null none))
;; NO_CD: (global $b-inexact (ref null $b-inexact) (ref.null none))
(global $b-inexact (ref null $b-inexact) (ref.null none))
;; CHECK: (global $a-exact (ref null $a-exact) (ref.null none))
;; NO_CD: (global $a-exact (ref null $a-exact) (ref.null none))
(global $a-exact (ref null $a-exact) (ref.null none))
;; CHECK: (global $b-exact (ref null $b-exact) (ref.null none))
;; NO_CD: (global $b-exact (ref null $b-exact) (ref.null none))
(global $b-exact (ref null $b-exact) (ref.null none))
)
Loading