Pass backend-related ctx to TorchDynamo Optimize Context #201

Merged (3 commits) on May 5, 2022
torchdynamo/eval_frame.py (36 additions, 11 deletions):
```diff
@@ -1,3 +1,4 @@
+import contextlib
 import functools
 import logging
 import threading
@@ -24,26 +25,32 @@ def nothing():
     pass


+null_context = contextlib.nullcontext
+
 unset = object()

 compile_lock = threading.Lock()


 class _TorchDynamoContext:
-    def __init__(self, callback, on_enter=nothing):
+    def __init__(self, callback, on_enter=nothing, backend_ctx_ctor=null_context):
         super().__init__()
         assert callable(callback) or callback is False or callback is None
         self.callback = callback
         self.prior = unset
         self.on_enter = on_enter
+        self.extra_ctx_ctor = backend_ctx_ctor

     def __enter__(self):
         self.on_enter()
         self.prior = set_eval_frame(self.callback)
+        self.backend_ctx = self.extra_ctx_ctor()
+        self.backend_ctx.__enter__()

     def __exit__(self, exc_type, exc_val, exc_tb):
         set_eval_frame(self.prior)
         self.prior = unset
+        self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)

     def __call__(self, fn):
         assert callable(fn)
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):


 class OptimizeContext(_TorchDynamoContext):
-    def __init__(self, callback):
-        super().__init__(callback=callback, on_enter=install_generation_tagging_new)
+    def __init__(self, callback, backend_ctx_ctor):
+        super().__init__(
+            callback=callback,
+            on_enter=install_generation_tagging_new,
+            backend_ctx_ctor=backend_ctx_ctor,
+        )


 class RunOnlyContext(_TorchDynamoContext):
```
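Note the ordering these hunks establish: on entry, the frame-evaluation hook is installed before the backend context is entered; on exit, the hook is restored before the backend context is closed. Below is a rough standalone model of that behavior using contextlib (a sketch only; `set_frame_hook` is an illustrative stand-in for `set_eval_frame`, not code from this PR):

```python
import contextlib


@contextlib.contextmanager
def dynamo_context(set_frame_hook, callback, backend_ctx_ctor=contextlib.nullcontext):
    # Install the frame-evaluation callback first, as _TorchDynamoContext does.
    prior = set_frame_hook(callback)
    with backend_ctx_ctor():  # then enter the backend's context (e.g. an nvFuser scope)
        try:
            yield
        finally:
            # Restore the prior hook; the backend context closes afterwards,
            # matching the __exit__ order in the diff above.
            set_frame_hook(prior)
```

The remaining hunks thread `backend_ctx_ctor` through `_optimize_catch_errors` and the public `optimize` / `optimize_assert` entry points: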
```diff
@@ -107,8 +118,10 @@ def catch_errors(frame, cache_size):
     return catch_errors


-def _optimize_catch_errors(compile_fn):
-    return OptimizeContext(catch_errors_wrapper(compile_fn))
+def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
+    return OptimizeContext(
+        catch_errors_wrapper(compile_fn), backend_ctx_ctor=backend_ctx_ctor
+    )


 def optimize(backend, nopython=False):
@@ -117,10 +130,13 @@
         backend() to optimize extracted graphs.

     Args:
-        backend: One of two things:
-            - Either, a function taking a torch.fx.GraphModule and
+        backend: One of the two things:
+            - Either, a function/callable taking a torch.fx.GraphModule and
               example_inputs and returning a python callable that runs the
               graph faster.
+              One can also provide additional context for the backend, like
+              torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
+              See AOTAutogradMemoryEfficientFusionWithContext for the usage.
            - Or, a string backend name in `torchdynamo.list_backends()`
         nopython: If True, graph breaks will be errors and there will
             be a single whole-program graph.
@@ -136,16 +152,25 @@ def toy_example(a, b):
     with torchdynamo.optimize(my_compiler):
         ...
     """
+
+    backend_ctx_ctor = null_context
+    if hasattr(backend, "backend_ctx_ctor"):
+        backend_ctx_ctor = getattr(backend, "backend_ctx_ctor")
+
     if nopython:
-        return optimize_assert(backend)
-    return _optimize_catch_errors(convert_frame.convert_frame(backend))
+        return optimize_assert(backend, backend_ctx_ctor)
+    return _optimize_catch_errors(
+        convert_frame.convert_frame(backend), backend_ctx_ctor
+    )


-def optimize_assert(backend):
+def optimize_assert(backend, backend_ctx_ctor=null_context):
     """
     The same as `torchdynamo.optimize(backend, nopython=True)`
     """
-    return _optimize_catch_errors(convert_frame.convert_frame_assert(backend))
+    return _optimize_catch_errors(
+        convert_frame.convert_frame_assert(backend), backend_ctx_ctor
+    )


 def run(fn=None):
```
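As the updated docstring notes, a backend opts into extra context by exposing a `backend_ctx_ctor` attribute, which `optimize()` discovers via `hasattr`/`getattr`. A minimal usage sketch (assumes a working torchdynamo install; `my_backend` and `toy_example` are illustrative, not part of this PR):

```python
import torch
import torchdynamo


def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # A trivial "compiler": run the captured graph as-is.
    return gm.forward


# optimize() reads this attribute and enters the constructed context
# around compiled execution.
my_backend.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")


def toy_example(a, b):
    return (a + b).relu()


with torchdynamo.optimize(my_backend):
    toy_example(torch.randn(10), torch.randn(10))
```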
torchdynamo/optimizations/training.py (11 additions, 1 deletion):
```diff
@@ -143,4 +143,14 @@ def candidate(self):
         return BACKENDS["aot_autograd"](self.gm, self.example_inputs)


-aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusion.compile_fn
+class AOTAutogradMemoryEfficientFusionWithContext:
+    """Pass nvfuser context to TorchDynamo"""
+
+    def __init__(self):
+        self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
+
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+        return AOTAutogradMemoryEfficientFusion.compile_fn(gm, example_inputs)
+
+
+aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()
```
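Because `aot_autograd_speedup_strategy` is now a callable object rather than a bare function, existing call sites keep working unchanged, while `optimize()` can additionally pick up its `backend_ctx_ctor` and run compiled code under nvFuser. A hedged usage sketch (assumes a working torchdynamo install):

```python
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

# optimize() finds aot_autograd_speedup_strategy.backend_ctx_ctor and wraps
# execution in torch.jit.fuser("fuser2"), enabling the nvFuser backend.
with torchdynamo.optimize(aot_autograd_speedup_strategy):
    ...  # model forward/backward here
```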