[muti_backend] #294

Open
wants to merge 32 commits into base: master

Conversation

Galaxy1458 (Collaborator) commented Nov 16, 2024

PR Category

Other, support muti_backend

Type of Change

New Feature

Description

  1. Importing flag_gems collects the global device information for later use:
     import flag_gems

  2. Each vendor maintains its own autotune config, and operators call the get_op_tune_config interface directly when passing parameters:

@triton.autotune(
    configs=runtime.get_op_tune_config("bmm"),
    key=["M", "N", "K"],
)

  3. Use runtime.device_guard uniformly in place of device guards such as torch.cuda.device:
    with runtime.device_guard(x.device):

  4. Specified operators can be disabled at initialization:
     flag_gems.enable(unused=["add", "mm"])
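
As a rough illustration of the vendor-maintained autotune configs described above (not the PR's implementation), a per-vendor lookup could be organized as below. The table `VENDOR_TUNE_CONFIGS`, its contents, the `vendor_name` parameter, and the plain-dict config format are all hypothetical; the real `runtime.get_op_tune_config` presumably returns objects that `triton.autotune` accepts.

```python
from typing import Dict, List

# Hypothetical per-vendor tuning tables; plain dicts stand in for triton.Config.
VENDOR_TUNE_CONFIGS: Dict[str, Dict[str, List[dict]]] = {
    "nvidia": {
        "bmm": [
            {"META": {"TILE_M": 64, "TILE_N": 64}, "num_warps": 4, "num_stages": 2},
            {"META": {"TILE_M": 128, "TILE_N": 64}, "num_warps": 8, "num_stages": 3},
        ],
    },
}


def get_op_tune_config(op_name: str, vendor_name: str = "nvidia") -> List[dict]:
    """Return the tuning candidates a vendor registered for one operator."""
    vendor_table = VENDOR_TUNE_CONFIGS.get(vendor_name)
    if vendor_table is None:
        raise KeyError(f"no tune configs registered for vendor {vendor_name!r}")
    if op_name not in vendor_table:
        raise KeyError(f"vendor {vendor_name!r} has no tune config for {op_name!r}")
    return vendor_table[op_name]
```

Keeping the table keyed by vendor first means an operator only needs to know its own name; the vendor is resolved once at import time.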

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance


vendors = vendors
AUTOGRAD = AUTOGRAD
Autograd = Autograd
Collaborator

what do these mean?

Collaborator Author

for future use.


__version__ = "2.1"

device = runtime.device.device_instance.device_name
Collaborator

is device used here?

global register_instance
if not register_instance:
    register_instance = Register(*args, **kargs)
return register_instance

Collaborator

what is to_register used for?

Collaborator Author

for global instance, will be used in next pr.

def get_vendor_unused_op(self):
    if self.device.vendor != backend.vendors.NVIDIA:
        return backend.get_curent_device_unused_op(self.device.vendor_name)
    return {}

Collaborator

user might specify unused ops when device is NVIDIA.

Collaborator Author

The unused_list consists of two parts: one specified by the vendor, and the other specified by the user through enable(unused=[xxx]) or use_gems(unused=[xxx]).
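
A minimal sketch of how the two sources could be combined, assuming both are simple collections of operator names. The helper names `collect_unused_ops` and `ops_to_register`, and the sample operator lists, are hypothetical and not taken from the PR.

```python
from typing import Iterable, List, Set


def collect_unused_ops(
    vendor_unused: Iterable[str], user_unused: Iterable[str] = ()
) -> Set[str]:
    """Union of the ops disabled by the vendor and the ops disabled by the user."""
    return set(vendor_unused) | set(user_unused)


def ops_to_register(all_ops: Iterable[str], unused: Set[str]) -> List[str]:
    """Keep the registration order, skipping every disabled op."""
    return [op for op in all_ops if op not in unused]


# The vendor disables one op, the user disables two more.
unused = collect_unused_ops(vendor_unused=["cumsum"], user_unused=["add", "mm"])
registered = ops_to_register(["add", "bmm", "cumsum", "mm", "softmax"], unused)
```

Because the result is a union, a user-supplied list still takes effect when the vendor list is empty, which is the NVIDIA case raised in the review.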

self.lib.impl(key, fn, device_key)

def close(self):
    self.lib

Collaborator

I don't understand.

Collaborator Author

will remove close.

return get_vendor_module(vendor_name, query).device.get_vendor_info()
global vendor_module
get_vendor_module(vendor_name)
return vendor_module.device.get_vendor_info()
Collaborator

ditto. I didn't get why set query as an argument.

Collaborator Author

ditto

from .. import backend, error


class device_ctx:
Collaborator

what's it used for?

Collaborator Author

load the device information.

single_config["META"],
num_warps=current_config["num_warps"],
num_stages=current_config["num_stages"],
num_ctas=current_config["num_ctas"],
Collaborator

It's not necessary to fill in every field of triton.Config. triton.Autotuner will automatically fill in the default values.

Collaborator Author

All fields are passed in explicitly because it is not certain which default values were changed in the calculation above.

def to_gen_config(self, gen_config):
    param_config = gen_config["param_map"]
    meta_config = param_config["META"]
    iteration_keys = list(meta_config) + list(param_config)

Collaborator

is there overlap between meta_config and param_config?

Collaborator Author

No.

current_config["META"][param_key] = single_value
else:
current_config[param_key] = single_value
config_item = self._gen_impl(
Collaborator

I suggest using a loop instead of recursion.

Collaborator Author

Got it.
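
One way the loop-based version could look is sketched below, using `itertools.product` to enumerate combinations instead of recursing. It assumes a `param_map` layout in which every entry is a list of candidate values and the kernel meta-parameters are nested under `"META"`; the function name `expand_configs` and the sample values are hypothetical.

```python
import itertools
from typing import Any, Dict, List


def expand_configs(param_map: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand lists of candidate values into one dict per combination."""
    meta_map = param_map.get("META", {})
    top_map = {k: v for k, v in param_map.items() if k != "META"}
    meta_keys = list(meta_map)
    top_keys = list(top_map)
    value_lists = [meta_map[k] for k in meta_keys] + [top_map[k] for k in top_keys]

    configs = []
    # product() yields every combination; the last list varies fastest.
    for combo in itertools.product(*value_lists):
        meta_values = combo[: len(meta_keys)]
        top_values = combo[len(meta_keys):]
        config: Dict[str, Any] = {"META": dict(zip(meta_keys, meta_values))}
        config.update(zip(top_keys, top_values))
        configs.append(config)
    return configs


configs = expand_configs(
    {"META": {"BLOCK_M": [32, 64]}, "num_warps": [4, 8], "num_stages": [2]}
)
```

Excluding the `"META"` key from the top-level keys also sidesteps the overlap question raised earlier in the thread, since each key is iterated exactly once.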



class device_ctx:
    def __init__(self, vendor_name=None):

Collaborator

Add necessary documentation to help understanding the code.

    self.device_info = self.get_vendor(vendor_name)
    self.vendor_name, self.device_name, self.cmd, self.vendor = self.device_info

def get_vendor(self, vendor_name=None) -> tuple:

Collaborator

I suggest using typing to add type hints, and using hints that also work on older versions of Python, for example typing.Tuple. Also, to be clear: a Tuple of what?

def get_vendor(self, vendor_name=None) -> tuple:
    if vendor_name is not None:
        return backend.get_vendor_info(vendor_name)
    vendor_from_evn = self._get_vendor_from_evn()

Collaborator

If you mean environment, please use env, which is a more accepted abbreviation for it.

Collaborator Author

Got it.

def get_module(vendor_name):
    current_file_path = os.path.abspath(__file__)
    current_dir_path = os.path.dirname(current_file_path)
    sys.path.append(current_dir_path)

Collaborator

Adding this path to sys.path increases the risk of name conflicts with other packages.
Are there alternative ways to do this?
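
One possible alternative, sketched here under the assumption that the vendor directories are regular subpackages of the backend package, is a package-relative import through `importlib.import_module`, which leaves `sys.path` untouched. The helper name `load_submodule` is hypothetical, and the example runs against the standard-library `json` package only so that it is self-contained.

```python
import importlib
from types import ModuleType


def load_submodule(package: str, name: str) -> ModuleType:
    """Import `package.name` through the package system, without editing sys.path."""
    return importlib.import_module(f".{name}", package=package)


# Demonstrated with the standard library. A backend loader might instead pass
# package=__package__ and a vendor directory name (hypothetical layout).
decoder = load_submodule("json", "decoder")
```

Because the module is resolved relative to its parent package, a vendor directory named like an installed top-level library cannot shadow that library.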

class device:
    @staticmethod
    def get_vendor_info():
        return ("nvidia", "cuda", "nvidia-smi")

Collaborator

What does this tuple mean?
If the fields have specific meaning, consider using a namedtuple or a dataclass to represent it.

Collaborator Author

annotated
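
For illustration, the tuple could be given named, typed fields with `typing.NamedTuple`, which also works on older Python versions. The class name `VendorInfo` is hypothetical; the field names follow the unpacking order used in `device_ctx`, and the meanings in the comments are inferred from the values rather than confirmed by the PR.

```python
from typing import NamedTuple


class VendorInfo(NamedTuple):
    """Hypothetical record replacing the bare tuple from get_vendor_info()."""

    vendor_name: str  # assumed: vendor identifier, e.g. "nvidia"
    device_name: str  # assumed: device string, e.g. "cuda"
    cmd: str  # assumed: command used to probe for the device, e.g. "nvidia-smi"


def get_vendor_info() -> VendorInfo:
    return VendorInfo("nvidia", "cuda", "nvidia-smi")


info = get_vendor_info()
vendor_name, device_name, cmd = info  # still unpacks like a plain tuple
```

Existing call sites that unpack the tuple keep working, while new code can read `info.device_name` instead of relying on position.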

@@ -0,0 +1,3 @@
from .error import * # noqa: F403

__all__ = ["*"]
Collaborator

If the module is simple, do not wrap it into a package. A simple module is enough.

from .ops import * # noqa: F403


class device:
Collaborator

What's the relationship between this device and device_ctx?

If the only purpose of this class is to provide a static method, why bother?

Collaborator Author

device is stored in each vendor's directory and maintained by that vendor. device_ctx is used to obtain the device information of each vendor.

Comment on lines 47 to 51
def _get_vendor_from_lib(self):
    try:
        return triton.get_vendor_info()
    except Exception:
        return torch.get_vendor_info()

Collaborator

Is this valid code?

I checked triton and torch, neither has get_vendor_info.

Collaborator Author

This reserves the corresponding interface for triton or torch, although neither implements it yet.

    device_from_evn = os.environ.get("GEMS_VENDOR")
    return None if device_from_evn not in self.vendor_list else device_from_evn

def _get_vendor_from_sys(self):

Collaborator

This function detects whether a device is available simply by running a command and expecting a return code of zero. Is this a reasonable way to detect all devices?

Collaborator Author

There is no other way at the gems level. We will push more chip vendors to provide APIs at the torch or triton level in the future.
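
The command-based detection discussed here can be sketched as follows. This is an illustration of the general technique, not the PR's code: the function names, the `query_cmds` mapping, and the injectable `probe` parameter are hypothetical. It checks that the command exists before running it and treats timeouts and launch failures as "not detected".

```python
import shutil
import subprocess
from typing import Callable, Dict, Optional


def command_succeeds(cmd: str, timeout: float = 10.0) -> bool:
    """True if `cmd` is on PATH and exits with return code 0."""
    if shutil.which(cmd) is None:
        return False
    try:
        result = subprocess.run(
            [cmd],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


def detect_vendor(
    query_cmds: Dict[str, str],
    probe: Callable[[str], bool] = command_succeeds,
) -> Optional[str]:
    """Return the first vendor whose query command succeeds, else None."""
    for vendor_name, cmd in query_cmds.items():
        if probe(cmd):
            return vendor_name
    return None
```

Passing the probe as a parameter keeps the detection order testable on machines that have none of the vendor tools installed.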


global configer, device
configer = Config()
device = device_ctx()
Collaborator

If this is intended to be a singleton, make it a real singleton.

See also the device in device.py
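
One lightweight way to get a single shared instance, shown here as a sketch rather than a recommendation for this codebase, is a cached accessor function. `DeviceContext` is a stand-in for `device_ctx`, and its attribute is a placeholder for the real detection logic.

```python
import functools


class DeviceContext:
    """Stand-in for device_ctx; the attribute below is a placeholder."""

    def __init__(self) -> None:
        self.vendor_name = "nvidia"  # placeholder for real vendor detection


@functools.lru_cache(maxsize=None)
def get_device_context() -> DeviceContext:
    """Create the context on the first call and return the same object afterwards."""
    return DeviceContext()


first = get_device_context()
second = get_device_context()
```

Callers go through `get_device_context()` instead of reassigning a module-level global, so there is one obvious place where the instance is created.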

@@ -0,0 +1,15 @@
def backend_not_support(device_name, backend_list):
    raise RuntimeError(f"The {device_name} device is not support currently. ")

Collaborator

supported


def device_not_found():
    raise RuntimeError(
        "No devices were detected on your machine ! \n "

Collaborator

No is followed by singular form.


def register_error(e):
    raise RuntimeError(
        e, "An error was encountered while registering the triton operator."

Collaborator

Is this a conventional way to initialize a RuntimeError?
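
For comparison, a more conventional pattern passes a single message string and attaches the original exception with `raise ... from`. The sketch below reuses the function name and message from the excerpt; the inner `ValueError` is a made-up stand-in for a failing registration.

```python
def register_error(e: Exception) -> None:
    """Raise a RuntimeError with one message string, chained to the original error."""
    raise RuntimeError(
        "An error was encountered while registering the triton operator."
    ) from e


try:
    try:
        raise ValueError("bad schema")  # stand-in for a failing registration
    except ValueError as original:
        register_error(original)
except RuntimeError as wrapped:
    message = str(wrapped)
    cause = wrapped.__cause__  # the original ValueError is preserved here
```

With chaining, the traceback shows both errors, and `str(wrapped)` stays a readable message rather than the repr of a two-element tuple.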



@dataclass
class vendor_info_base:
Collaborator

Please follow our code style and use CamelCase for class names.

Collaborator

Is this class designed to be a base class?
I don't see it being subclassed.

Also, add documentation to clarify the intention of this class.


from .ops import * # noqa: F403

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
Collaborator

Don't append other paths to sys.path, since it may cause name conflicts with other packages.

Collaborator

Also, why do you need this?
It may be a sign that the dependencies between modules or packages are not properly organized.

@@ -0,0 +1,3 @@
from .add import add

__all__ = ["add"]
Collaborator

What is the purpose of this package?

For backend-specific operators?

try:
    exec(compiled_code, globals())
except Exception as e:
    RuntimeError(e)

Collaborator

Why not just raise e?

Is constructing a RuntimeError the desired behavior?


class Config:
    def __init__(self):
        self.config = self.get_vendor_tune_config()

Collaborator

If it is just raw config that needs to be preprocessed, do not expose it, or at least do not give it such a plain name, which is confusing.

from .device import device


class Config:
Collaborator

I don't see the point of this class.

  1. It is actually not a config. It is a config loader.

@@ -0,0 +1,52 @@
import torch
Collaborator

what's the meaning of _ in _nvidia dir?

Collaborator Author

nvidia is an existing library

Bowen12992 (Collaborator) left a comment

I have some of the same questions as @StrongSpoon and @iclementine, so it would be better if you could answer their questions in the conversation to help me understand~


4 participants