Skip to content

Commit 35ad2d8

Browse files
committed
Revert "[jit] fix tuple alias analysis (pytorch#41992)"
This reverts commit 8aa878f.
1 parent 994b37b commit 35ad2d8

File tree

3 files changed

+11
-19
lines changed

3 files changed

+11
-19
lines changed

test/jit/test_freezing.py

+4
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,9 @@ def forward(self, x):
794794
expected = m_s.forward(inp)
795795
self.assertEqual(out, expected)
796796

797+
# Check attribute a is preserved. Alias analysis detects that 'a' has output writers.
798+
# In this example, 'a' is not mutated. However, we do not track which sub
799+
# values of a composite ivalue is mutated.
797800
def test_freeze_module_with_aliased_attr2(self):
798801
class FreezeMe(nn.Module):
799802
def __init__(self):
@@ -812,6 +815,7 @@ def forward(self, x):
812815
m_s = torch.jit.script(m)
813816
m_s.eval()
814817
m_f = torch._C._freeze_module(m_s._c)
818+
self.assertTrue(m_f.hasattr('a'))
815819
inp = torch.tensor([5])
816820
out = m_f.forward(inp)
817821
expected = m.forward(inp)

torch/csrc/jit/ir/alias_analysis.cpp

+7-18
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,6 @@ void AliasDb::analyzeImpl(Node* node) {
498498
case prim::tolist:
499499
return analyzeCreator(node);
500500
case prim::TupleConstruct:
501-
return analyzeTupleConstruct(node);
502501
case prim::DictConstruct:
503502
case prim::ListConstruct:
504503
return analyzeContainerConstruct(node);
@@ -865,30 +864,20 @@ void AliasDb::analyzeConservative(Node* node) {
865864
}
866865
}
867866

868-
void AliasDb::analyzeTupleConstruct(Node* node) {
869-
TORCH_INTERNAL_ASSERT(node->kind() == prim::TupleConstruct);
870-
// tuples which contain immutable types are immutable
871-
if (!isMutableTypeInternal(node->output())) {
872-
return;
873-
}
874-
875-
giveFreshAlias(node->output());
876-
877-
for (const auto& input : node->inputs()) {
878-
if (isMutableTypeInternal(input)) {
879-
addToContainedElements(input, node->output());
880-
}
881-
}
882-
}
883-
884867
// List or dict or tuple: construct: create an aliasing element for the actual
885868
// container, then mark all inputs as wildcards, since they've gone inside the
886869
// container. Then, add the wildcard sets of appropriate type to the contained
887870
// elements of the container.
888871
void AliasDb::analyzeContainerConstruct(Node* node) {
889872
TORCH_INTERNAL_ASSERT(
890873
node->kind() == prim::ListConstruct ||
891-
node->kind() == prim::DictConstruct);
874+
node->kind() == prim::DictConstruct ||
875+
node->kind() == prim::TupleConstruct);
876+
877+
// tuples which contain immutable types are immutable
878+
if (!isMutableTypeInternal(node->output())) {
879+
return;
880+
}
892881

893882
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
894883
auto container = node->output();

torch/csrc/jit/ir/alias_analysis.h

-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,6 @@ class AliasDb {
194194
void analyzeSetAttr(Node* node);
195195
void analyzeConservative(Node* node);
196196
void analyzeContainerConstruct(Node* node);
197-
void analyzeTupleConstruct(Node* node);
198197
bool tryRegisteredAnalysis(Node* node);
199198

200199
/**

0 commit comments

Comments
 (0)