Skip to content

Commit a033df6

Browse files
jatayloanijain2305
andauthored
[SWDEV-554558] [dynamo] Use BINARY_SUBSCR for pre-graph bytecode for regular dict ac… (#2669)
…cesses (pytorch#155727) vLLM profiler sets with_stack=True that shows the dict_getitem on the profiler, both inflating the numbers and confusing compile users. This PR keeps BINARY_SUBSCR for regular dicts, while using `dict.__getitem__` only for dict subclasses. Using binary_subscr is little bit faster, but not enough to make any major latency improvements. Pull Request resolved: pytorch#155727 Approved by: https://github.com/zou3519, https://github.com/StrongerXi, https://github.com/jansel (cherry picked from commit a9d5157) Fixes #ISSUE_NUMBER Co-authored-by: Animesh Jain <[email protected]>
1 parent dcd8e22 commit a033df6

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

test/dynamo/test_dicts.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from typing import Any, Optional, Tuple
1515

1616
import torch
17-
import torch._dynamo.config
1817
import torch._dynamo.test_case
1918
import torch._dynamo.testing
2019
import torch._functorch.config

torch/_dynamo/guards.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
ConstDictKeySource,
100100
DefaultsSource,
101101
DictGetItemSource,
102+
DictSubclassGetItemSource,
102103
FlattenScriptObjectSource,
103104
FloatTensorSource,
104105
FSDPNNModuleSource,
@@ -1079,7 +1080,7 @@ def get_guard_manager_from_source(self, source):
10791080
example_value=example_value,
10801081
guard_manager_enum=guard_manager_enum,
10811082
)
1082-
elif istype(source, DictGetItemSource):
1083+
elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)):
10831084
assert base_guard_manager # to make mypy happy
10841085
assert isinstance(base_example_value, (dict, collections.OrderedDict))
10851086
if isinstance(base_guard_manager, DictGuardManager):

torch/_dynamo/source.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,44 @@ def __post_init__(self):
601601
def guard_source(self):
602602
return self.base.guard_source()
603603

604-
def reconstruct(self, codegen):
604+
def reconstruct(self, codegen: "PyCodegen"):
605+
# Load dict
606+
codegen(self.base)
607+
608+
# Load key
609+
if isinstance(self.index, Source):
610+
codegen(self.index)
611+
else:
612+
codegen.append_output(codegen.create_load_const(self.index))
613+
codegen.append_output(create_instruction("BINARY_SUBSCR"))
614+
615+
def name(self):
616+
if isinstance(self.index, ConstDictKeySource):
617+
return f"{self.base.name()}[{self.index.name()}]"
618+
else:
619+
return f"{self.base.name()}[{self.index!r}]"
620+
621+
622+
# Same as DictGetItemSource but used for dict.__getitem__ calls to ensure that
623+
# torch.compile does not run the overridden __getitem__ method
624+
@dataclasses.dataclass(frozen=True)
625+
class DictSubclassGetItemSource(ChainedSource):
626+
# Key to access in the dictionary. It can be one of the the following types
627+
# 1) ConstDictKeySource
628+
# 2) constant - like string, integer
629+
index: Any
630+
631+
def __post_init__(self):
632+
from .variables import ConstantVariable
633+
634+
assert isinstance(
635+
self.index, ConstDictKeySource
636+
) or ConstantVariable.is_literal(self.index)
637+
638+
def guard_source(self):
639+
return self.base.guard_source()
640+
641+
def reconstruct(self, codegen: "PyCodegen"):
605642
# reconstruct dict.__getitem__(dct, key)
606643

607644
# Load dict.__getitem__

torch/_dynamo/variables/builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
ConstDictKeySource,
9393
ConvertIntSource,
9494
DictGetItemSource,
95+
DictSubclassGetItemSource,
9596
FloatTensorSource,
9697
GetItemSource,
9798
GradSource,
@@ -1274,8 +1275,8 @@ def build_key_value(i, k, v):
12741275
source_key = ConstDictKeySource(self.get_source(), i)
12751276
key = LazyVariableTracker.create(k, source_key)
12761277

1277-
source_value = DictGetItemSource(self.get_source(), source_key)
1278-
value = LazyVariableTracker.create(v, source_value)
1278+
source_value = DictSubclassGetItemSource(base, source_key)
1279+
res_value = LazyVariableTracker.create(v, source_value)
12791280

12801281
return key, value
12811282

0 commit comments

Comments
 (0)