-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[muti_backend] #294
base: master
Are you sure you want to change the base?
[muti_backend] #294
Conversation
|
||
vendors = vendors | ||
AUTOGRAD = AUTOGRAD | ||
Autograd = Autograd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what do these mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for future use.
src/flag_gems/__init__.py
Outdated
|
||
__version__ = "2.1" | ||
|
||
device = runtime.device.device_instance.device_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is device used here?
global register_instance | ||
if not register_instance: | ||
register_instance = Register(*args, **kargs) | ||
return register_instance |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is to_register used for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for global instance, will be used in next pr.
def get_vendor_unused_op(self): | ||
if self.device.vendor != backend.vendors.NVIDIA: | ||
return backend.get_curent_device_unused_op(self.device.vendor_name) | ||
return {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
user might specify unused ops when device is NVIDIA.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The unused_list consists of two parts, one specified by the vendor and the other specified by the user in enable(unused=[xxx]) or use_gems(unused=[xxx])
self.lib.impl(key, fn, device_key) | ||
|
||
def close(self): | ||
self.lib |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will remove close.
return get_vendor_module(vendor_name, query).device.get_vendor_info() | ||
global vendor_module | ||
get_vendor_module(vendor_name) | ||
return vendor_module.device.get_vendor_info() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto. I didn't get why set query as an argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
from .. import backend, error | ||
|
||
|
||
class device_ctx: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's it used for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
load the device information.
single_config["META"], | ||
num_warps=current_config["num_warps"], | ||
num_stages=current_config["num_stages"], | ||
num_ctas=current_config["num_ctas"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's not necessary to fill all fields in triton.Config. triton.Autotuner will automatically complement the default value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All are passed in because it is not certain which default values were changed in the above calculation
def to_gen_config(self, gen_config): | ||
param_config = gen_config["param_map"] | ||
meta_config = param_config["META"] | ||
iteration_keys = list(meta_config) + list(param_config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there overlap between meta_config and param_config?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No.
current_config["META"][param_key] = single_value | ||
else: | ||
current_config[param_key] = single_value | ||
config_item = self._gen_impl( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest using loop instead of recursion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Get.
|
||
|
||
class device_ctx: | ||
def __init__(self, vendor_name=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add necessary documentation to help understanding the code.
self.device_info = self.get_vendor(vendor_name) | ||
self.vendor_name, self.device_name, self.cmd, self.vendor = self.device_info | ||
|
||
def get_vendor(self, vendor_name=None) -> tuple: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest use typing to add type hints and use type hints that works for older versions of python, for example typing.Tuple
. Also, if it is possible to be clear, Tuple
of what?
def get_vendor(self, vendor_name=None) -> tuple: | ||
if vendor_name is not None: | ||
return backend.get_vendor_info(vendor_name) | ||
vendor_from_evn = self._get_vendor_from_evn() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you mean environment, please use env, which is a more accepted abbreviation for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Get it.
def get_module(vendor_name): | ||
current_file_path = os.path.abspath(__file__) | ||
current_dir_path = os.path.dirname(current_file_path) | ||
sys.path.append(current_dir_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding this path to sys.path
may increase the risk to cause name conflicts with other packages.
Is there alternative ways to do this?
class device: | ||
@staticmethod | ||
def get_vendor_info(): | ||
return ("nvidia", "cuda", "nvidia-smi") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this tuple mean?
If the fields have specific meaning, consider using a namedtuple or a dataclass to represent it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
annotated
@@ -0,0 +1,3 @@ | |||
from .error import * # noqa: F403 | |||
|
|||
__all__ = ["*"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the module is simple, do not wrap it into a package. A simple module is enough.
from .ops import * # noqa: F403 | ||
|
||
|
||
class device: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the relationship of this device
and device_ctx
.
If the only purpose of this class is to provide a static method, why bother?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
device is stored in each vendor's directory and maintained by each vendor. device_ctx is used to obtain device information of each vendor
def _get_vendor_from_lib(self): | ||
try: | ||
return triton.get_vendor_info() | ||
except Exception: | ||
return torch.get_vendor_info() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this valid code?
I checked triton and torch, neither has get_vendor_info
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reserve the associated interface for triton or torch, although they are not implemented yet.
device_from_evn = os.environ.get("GEMS_VENDOR") | ||
return None if device_from_evn not in self.vendor_list else device_from_evn | ||
|
||
def _get_vendor_from_sys(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function simply runs a detection to detect if device is available by try a command and expect return code to be zero. Is this a reasonable way to detect all devices?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no other way at the gems level. will push more chip vendors to provide apis at the torch or triton level in the future
src/flag_gems/runtime/__init__.py
Outdated
|
||
global configer, device | ||
configer = Config() | ||
device = device_ctx() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is intended to be a singleton, make it a real singleton.
See also the device in device.py
src/flag_gems/runtime/error.py
Outdated
@@ -0,0 +1,15 @@ | |||
def backend_not_support(device_name, backend_list): | |||
raise RuntimeError(f"The {device_name} device is not support currently. ") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
supported
src/flag_gems/runtime/error.py
Outdated
|
||
def device_not_found(): | ||
raise RuntimeError( | ||
"No devices were detected on your machine ! \n " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No
is followed by singular form.
|
||
def register_error(e): | ||
raise RuntimeError( | ||
e, "An error was encountered while registering the triton operator." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a conventional way to initialize an RuntimeError?
|
||
|
||
@dataclass | ||
class vendor_info_base: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please follow our code style and use CamaelCase for class name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this class designed to be a base class?
I don't see it bing subclassed?
Also, add documentation to clarify the intention of this class.
|
||
from .ops import * # noqa: F403 | ||
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dont's append other path to sys.path
since it may cause name conflicts with other packages.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, why do you need this?
It may be a sign that the dependencies between modules or packages are not properly organized.
@@ -0,0 +1,3 @@ | |||
from .add import add | |||
|
|||
__all__ = ["add"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the purpose of this package?
For backend-specific operators?
try: | ||
exec(compiled_code, globals()) | ||
except Exception as e: | ||
RuntimeError(e) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just raise e?
Is Constructing a RuntimeError the desired behavior?
src/flag_gems/runtime/config.py
Outdated
|
||
class Config: | ||
def __init__(self): | ||
self.config = self.get_vendor_tune_config() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it is just some raw config that need to be preprocessed, do not expose it or do not use such a PLAIN name, which is confusing.
src/flag_gems/runtime/config.py
Outdated
from .device import device | ||
|
||
|
||
class Config: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see the point of this class.
- It is actually not a config. It is a config loader.
@@ -0,0 +1,52 @@ | |||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the meaning of _
in _nvidia
dir?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nvidia is an existing library
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have some same questions with @StrongSpoon and @iclementine, so it would be better if you could answer their questions in the conversation to help me understand~
PR Category
Other, support muti_backend
Type of Change
New Feature
Description
After this code is executed, the global device information is obtained for future use
import flag_gems
The vendor maintains its own auto_tune config and calls the get_op_tune_config interface directly when passing parameters
Use
runtime.device_guard
uniformly fordevice_guard
such astorch.cuda.device
with
runtime.device_guard(x.device)
The specified operator can be disabled at initialization
flag_gems.enable(unused=["add", "mm"])
Issue
Progress
Performance