Skip to content

Commit 3ab0ba3

Browse files
authored
Fix sharding prop cache clear (#257)
* Fix sharding prop cache clear
* lint
1 parent 9b17c20 commit 3ab0ba3

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
pip uninstall -y torch
4141
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
4242
pip install --quiet .
43-
pytest tests --deselect=tests/test_dtensor.py::ImplicitRegistrationTest::test_implicit_registration
43+
pytest tests
4444
python examples/example_autoparallel.py
4545
python examples/example_llama3.py
4646
python examples/example_dcp.py

autoparallel/dtensor_util/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
generate_redistribute_costs,
2323
is_tensor_shardable,
2424
)
25+
from torch.distributed.tensor.debug import (
26+
_clear_fast_path_sharding_prop_cache,
27+
_clear_python_sharding_prop_cache,
28+
)
2529
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
2630

2731
try:
@@ -82,7 +86,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
8286
del propagator.op_to_schema_info[op_overload]
8387
else:
8488
propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
85-
propagator.propagate_op_sharding.cache.cache_clear()
89+
_clear_fast_path_sharding_prop_cache()
90+
_clear_python_sharding_prop_cache()
8691

8792

8893
# -------------define universal op strategy-------------

0 commit comments

Comments (0)