Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add small bit of type hinting #1745

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/rez/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright Contributors to the Rez Project

import os
import re
import copy

from contextlib import contextmanager
from functools import lru_cache
from inspect import ismodule
import typing

from rez import __version__
from rez.utils.data_utils import AttrDictWrapper, RO_AttrDictWrapper, \
Expand All @@ -16,12 +24,6 @@
from rez.vendor import yaml
from rez.vendor.yaml.error import YAMLError
import rez.deprecations
from contextlib import contextmanager
from functools import lru_cache
from inspect import ismodule
import os
import re
import copy


class _Deprecation(object):
Expand Down Expand Up @@ -802,7 +804,7 @@ def _get_new_session_popen_args(self):

class _PluginConfigs(object):
"""Lazy config loading for plugins."""
def __init__(self, plugin_data):
def __init__(self, plugin_data: typing.Dict[str, typing.Any]):
self.__dict__['_data'] = plugin_data
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why we do this self.__dict__['_data'] accessing here. Perhaps there is a cleaner way to do this?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's to bypass __getattr__.


def __setattr__(self, attr, value):
Expand Down
62 changes: 33 additions & 29 deletions src/rez/plugin_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@
"""
Manages loading of all types of Rez plugins.
"""
import pkgutil
import os.path
import sys
from types import ModuleType
from typing import Dict, ItemsView, List, KeysView, Optional, Type
from zipimport import zipimporter

from rez.config import config, expand_system_vars, _load_config_from_filepaths
from rez.utils.formatting import columnise
from rez.utils.schema import dict_to_schema
from rez.utils.data_utils import LazySingleton, cached_property, deep_update
from rez.utils.logging_ import print_debug, print_warning
from rez.vendor.schema.schema import Schema
from rez.exceptions import RezPluginError
from zipimport import zipimporter
import pkgutil
import os.path
import sys


# modified from pkgutil standard library:
# this function is called from the __init__.py files of each plugin type inside
# the 'rezplugins' package.
def extend_path(path, name):
def extend_path(path: List[str], name: str):
"""Extend a package's path.

Intended use is to place the following code in a package's __init__.py:
Expand Down Expand Up @@ -85,23 +89,23 @@ class RezPluginType(object):
'type_name' must correspond with one of the source directories found under
the 'plugins' directory.
"""
type_name = None
type_name: Optional[str] = None

def __init__(self):
if self.type_name is None:
raise TypeError("Subclasses of RezPluginType must provide a "
"'type_name' attribute")
self.pretty_type_name = self.type_name.replace('_', ' ')
self.plugin_classes = {}
self.failed_plugins = {}
self.plugin_modules = {}
self.config_data = {}
self.pretty_type_name: str = self.type_name.replace('_', ' ')
self.plugin_classes: Dict[str, Type[object]] = {}
self.failed_plugins: Dict[str, str] = {}
self.plugin_modules: Dict[str, ModuleType] = {}
self.config_data: Dict[str, Dict] = {}
self.load_plugins()

def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self.plugin_classes.keys())

def register_plugin(self, plugin_name, plugin_class, plugin_module):
def register_plugin(self, plugin_name: str, plugin_class: Type[object], plugin_module: ModuleType):
# TODO: check plugin_class to ensure it is a sub-class of expected base-class?
# TODO: perhaps have a Plugin base class. This introduces multiple
# inheritance in Shell class though :/
Expand Down Expand Up @@ -205,15 +209,15 @@ def load_plugins(self):
data, _ = _load_config_from_filepaths([os.path.join(path, "rezconfig")])
deep_update(self.config_data, data)

def get_plugin_class(self, plugin_name):
def get_plugin_class(self, plugin_name: str) -> Type[object]:
"""Returns the class registered under the given plugin name."""
try:
return self.plugin_classes[plugin_name]
except KeyError:
raise RezPluginError("Unrecognised %s plugin: '%s'"
% (self.pretty_type_name, plugin_name))

def get_plugin_module(self, plugin_name):
def get_plugin_module(self, plugin_name: str) -> ModuleType:
"""Returns the module containing the plugin of the given name."""
try:
return self.plugin_modules[plugin_name]
Expand All @@ -222,7 +226,7 @@ def get_plugin_module(self, plugin_name):
% (self.pretty_type_name, plugin_name))

@cached_property
def config_schema(self):
def config_schema(self) -> Schema:
"""Returns the merged configuration data schema for this plugin
type."""
from rez.config import _plugin_config_dict
Expand All @@ -235,9 +239,9 @@ def config_schema(self):
deep_update(d, d_)
return dict_to_schema(d, required=True, modifier=expand_system_vars)

def create_instance(self, plugin, **instance_kwargs):
def create_instance(self, plugin_name: str, **instance_kwargs) -> object:
BryceGattis marked this conversation as resolved.
Show resolved Hide resolved
"""Create and return an instance of the given plugin."""
return self.get_plugin_class(plugin)(**instance_kwargs)
return self.get_plugin_class(plugin_name)(**instance_kwargs)


class RezPluginManager(object):
Expand Down Expand Up @@ -294,7 +298,7 @@ def register_plugin():
'rezplugins' is always found first.
"""
def __init__(self):
self._plugin_types = {}
self._plugin_types: Dict[str, LazySingleton] = {}

@cached_property
def rezplugins_module_paths(self):
Expand Down Expand Up @@ -329,53 +333,53 @@ def rezplugins_module_paths(self):

# -- plugin types

def _get_plugin_type(self, plugin_type):
def _get_plugin_type(self, plugin_type: str) -> RezPluginType:
try:
return self._plugin_types[plugin_type]()
except KeyError:
raise RezPluginError("Unrecognised plugin type: '%s'"
% plugin_type)

def register_plugin_type(self, type_class):
def register_plugin_type(self, type_class: Type[RezPluginType]):
if not issubclass(type_class, RezPluginType):
raise TypeError("'type_class' must be a RezPluginType sub class")
if type_class.type_name is None:
raise TypeError("Subclasses of RezPluginType must provide a "
"'type_name' attribute")
self._plugin_types[type_class.type_name] = LazySingleton(type_class)

def get_plugin_types(self):
def get_plugin_types(self) -> KeysView[str]:
"""Return a list of the registered plugin types."""
return self._plugin_types.keys()

# -- plugins

def get_plugins(self, plugin_type):
def get_plugins(self, plugin_type: str) -> KeysView[str]:
"""Return a list of the registered names available for the given plugin
type."""
return self._get_plugin_type(plugin_type).plugin_classes.keys()

def get_plugin_class(self, plugin_type, plugin_name):
def get_plugin_class(self, plugin_type: str, plugin_name: str) -> Type[object]:
"""Return the class registered under the given plugin name."""
plugin = self._get_plugin_type(plugin_type)
return plugin.get_plugin_class(plugin_name)

def get_plugin_module(self, plugin_type, plugin_name):
def get_plugin_module(self, plugin_type: str, plugin_name: str) -> ModuleType:
"""Return the module defining the class registered under the given
plugin name."""
plugin = self._get_plugin_type(plugin_type)
return plugin.get_plugin_module(plugin_name)

def get_plugin_config_data(self, plugin_type):
def get_plugin_config_data(self, plugin_type: str) -> Dict[str, Dict]:
"""Return the merged configuration data for the plugin type."""
plugin = self._get_plugin_type(plugin_type)
return plugin.config_data

def get_plugin_config_schema(self, plugin_type):
def get_plugin_config_schema(self, plugin_type: str) -> Schema:
plugin = self._get_plugin_type(plugin_type)
return plugin.config_schema

def get_failed_plugins(self, plugin_type):
def get_failed_plugins(self, plugin_type: str) -> ItemsView[str, str]:
"""Return a list of plugins for the given type that failed to load.

Returns:
Expand All @@ -385,12 +389,12 @@ def get_failed_plugins(self, plugin_type):
"""
return self._get_plugin_type(plugin_type).failed_plugins.items()

def create_instance(self, plugin_type, plugin_name, **instance_kwargs):
def create_instance(self, plugin_type: str, plugin_name: str, **instance_kwargs) -> object:
BryceGattis marked this conversation as resolved.
Show resolved Hide resolved
"""Create and return an instance of the given plugin."""
plugin_type = self._get_plugin_type(plugin_type)
return plugin_type.create_instance(plugin_name, **instance_kwargs)

def get_summary_string(self):
def get_summary_string(self) -> str:
"""Get a formatted string summarising the plugins that were loaded."""
rows = [["PLUGIN TYPE", "NAME", "DESCRIPTION", "STATUS"],
["-----------", "----", "-----------", "------"]]
Expand Down
Loading