Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config torch to avoid graph breaks caused by logger #6999

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 10 additions & 31 deletions deepspeed/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import logging
import sys
import os
from deepspeed.runtime.compiler import is_compile_supported, is_compiling
import torch
from deepspeed.utils.torch import required_torch_version

log_levels = {
"debug": logging.DEBUG,
Expand All @@ -20,31 +21,6 @@

class LoggerFactory:

def create_warning_filter(logger):
warn = False

def warn_once(record):
nonlocal warn
if is_compile_supported() and is_compiling() and not warn:
warn = True
logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to"
" disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1")
return True

return warn_once

@staticmethod
def logging_decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_compiling():
return
else:
return func(*args, **kwargs)

return wrapper

@staticmethod
def create_logger(name=None, level=logging.INFO):
"""create a logger
Expand All @@ -70,12 +46,15 @@ def create_logger(name=None, level=logging.INFO):
ch.setLevel(level)
ch.setFormatter(formatter)
logger_.addHandler(ch)
if os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
for method in ['info', 'debug', 'error', 'warning', 'critical', 'exception']:
if required_torch_version(min_version=2.6) and os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
excluded_set = {
item.strip()
for item in os.getenv("LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE", "").split(",")
}
ignore_set = {'info', 'debug', 'error', 'warning', 'critical', 'exception', 'isEnabledFor'} - excluded_set
for method in ignore_set:
original_logger = getattr(logger_, method)
setattr(logger_, method, LoggerFactory.logging_decorator(original_logger))
else:
logger_.addFilter(LoggerFactory.create_warning_filter(logger_))
torch._dynamo.config.ignore_logger_methods.add(original_logger)
return logger_


Expand Down
Loading