diff --git a/medcat-v2/README.md b/medcat-v2/README.md index b22972c52..aed7c75ce 100644 --- a/medcat-v2/README.md +++ b/medcat-v2/README.md @@ -90,6 +90,40 @@ pip install "medcat[deid]" # for DeID models pip install "medcat[spacy,meta-cat,deid,rel-cat,dict-ner]" # for all of the above ``` +### Installing plugins + +MedCAT v2 supports **external plugins** that can provide new components (e.g. alternative NER models, addons, tokenizers) via Python entry points. + +- **Curated plugins**: The `medcat.plugins.catalog` module ships with a curated plugin catalog that can be updated from a remote JSON file. +- **Installer**: The `medcat.plugins.installer.PluginInstallationManager` wraps a `pip`-based installer and knows how to resolve a compatible plugin version for your current MedCAT version. +- **CLI**: You can install curated plugins directly from the command line: + +```bash +python -m medcat plugins install medcat-gliner +``` + +This will: + +- look up `medcat-gliner` in the curated catalog, +- resolve a version compatible with your installed MedCAT, +- and install it using `pip`. + +You can also: + +- pass `--dry-run` to show what would be installed without making changes: + + ```bash + python -m medcat plugins install --dry-run medcat-gliner + ``` + +- override the version/ref explicitly (e.g. when testing a branch or tag): + + ```bash + python -m medcat plugins install medcat-gliner --force-version main + ``` + +If a plugin requires authentication (for example, private Git repositories), MedCAT will log a warning and the installer will surface pip’s error messages if credentials are missing or incorrect. + ### Version / update checking MedCAT now has the ability to check for newer versions of itself on PyPI (or a local mirror of it). diff --git a/medcat-v2/docs/architecture.md b/medcat-v2/docs/architecture.md index b268de2e2..f4c13ebf7 100644 --- a/medcat-v2/docs/architecture.md +++ b/medcat-v2/docs/architecture.md @@ -28,6 +28,16 @@ All components are registered in a central registry. This means you can: ### Plugins **Plugins** are external Python packages that provide new component implementations or other functionality. They integrate with MedCAT through Python entry points, allowing automatic discovery and registration without modifying MedCAT's core code. +MedCAT v2 also includes a **curated plugin catalog** and an **installer**: + +- `medcat.plugins.catalog.PluginCatalog` maintains a list of known plugins, their metadata, and MedCAT compatibility rules (e.g. “this plugin supports `>=2.5.0,<3.0.0`”). +- `medcat.plugins.installer.PluginInstallationManager` uses that catalog to select a compatible version and install it (currently via `pip`), with support for: + - PyPI packages + - Git repositories (including subdirectories such as monorepo layouts) + - Direct URLs (e.g. wheels or tarballs) + +The curated catalog can be updated from a remote JSON file, and plugins can be installed either programmatically or via the `python -m medcat plugins install ...` CLI. + --- ## Working with Core Components diff --git a/medcat-v2/medcat/__main__.py b/medcat-v2/medcat/__main__.py index c0fc985b1..c76c8e40e 100644 --- a/medcat-v2/medcat/__main__.py +++ b/medcat-v2/medcat/__main__.py @@ -1,9 +1,11 @@ import sys from medcat.utils.download_scripts import main as __download_scripts +from medcat.plugins.cli import install_plugins_command as __install_plugins _COMMANDS = { - "download-scripts": __download_scripts + "download-scripts": __download_scripts, + "install-plugins": __install_plugins, } diff --git a/medcat-v2/medcat/plugins/catalog.py b/medcat-v2/medcat/plugins/catalog.py new file mode 100644 index 000000000..9b9f61098 --- /dev/null +++ b/medcat-v2/medcat/plugins/catalog.py @@ -0,0 +1,242 @@ +"""Management of the curated plugin catalog.""" + +import json +import logging +from typing import Optional +import importlib.resources +import requests + +from packaging.specifiers import SpecifierSet +from packaging.version import Version +from pydantic import BaseModel, Field + +from .downloadable import PluginSourceSpec + +logger = logging.getLogger(__name__) + + +LOCAL_CATALOG_PATH = ( + importlib.resources.files('medcat.plugins.data') / + 'plugin_catalog.json' +) + +class PluginCompatibility(BaseModel): + medcat_version: str + plugin_version: str + + +class PluginInfo(BaseModel): + name: str + display_name: str + description: str + source_spec: PluginSourceSpec + homepage: str + compatibility: list[PluginCompatibility] + requires_auth: bool = False + + def can_merge(self, other: 'PluginInfo') -> bool: + """Checks if 2 plugin infos can be merged. + + This checks to make sure the name and the source spec is the same. + In that case the two objects likely refer to the same plugin. But + one might have updated information. + + Args: + other (PluginInfo): The other plugin info. + + Returns: + bool: Whether they can be merged. + """ + return ( + self.name == other.name and + self.source_spec == other.source_spec) + + def merge(self, other: 'PluginInfo', prefer_other: bool = True) -> None: + """Merge other plugin info into this one. + + Normally it is likely the "other" plugin info is newer so we want to + prefer its data if/when possible. + + Args: + other (PluginInfo): The other plugin info. + prefer_other (bool): Whether to prefer other. Defaults to True. + + Raises: + UnmergablePluginInfo: If the infos cannot be merged. + """ + if not self.can_merge(other): + raise UnmergablePluginInfo(self, other) + if prefer_other: + self.display_name = other.display_name + self.description = other.description + self.homepage = other.homepage + self.requires_auth = other.requires_auth + existing_plugin_versions = {cur.plugin_version for cur in self.compatibility} + for other_comp in other.compatibility: + if other_comp.plugin_version not in existing_plugin_versions: + self.compatibility.append(other_comp) + elif prefer_other: + prev_index = [idx for idx, cur in enumerate(self.compatibility) + if cur.plugin_version == other_comp.plugin_version][0] + self.compatibility[prev_index] = other_comp + + +class CatalogModel(BaseModel): + """Pydantic model for the top-level catalog JSON.""" + plugins: dict[str, PluginInfo] = Field(default_factory=dict) + version: str + last_updated: str + + def merge(self, other: 'CatalogModel', prefer_other: bool = True) -> None: + """Merge another catalog into this one. + + Args: + other (CatalogModel): The other catalog to merge. + prefer_other (bool): Whether to prefer other. Defaults to True. + """ + if prefer_other: + self.version = other.version + self.last_updated = other.last_updated + for plugin_name, info in other.plugins.items(): + if plugin_name not in self.plugins: + self.plugins[plugin_name] = info + elif prefer_other: + self.plugins[plugin_name].merge(info, prefer_other=prefer_other) + + +class PluginCatalog: + """Manages the catalog of curated plugins.""" + + REMOTE_CATALOG_URL = ( + "https://raw.githubusercontent.com/CogStack/cogstack-nlp/main/" + "medcat-v2/medcat/plugins/data/plugin_catalog.json" + ) + + def __init__(self, use_remote: bool = True): + """ + Initialize the plugin catalog. + + Args: + use_remote: Whether to attempt fetching the remote catalog + """ + self._catalog: CatalogModel = CatalogModel( + version="N/A", last_updated='N/A', plugins={}) + self._load_local_catalog() + if use_remote: + try: + self._update_from_remote() + except Exception as e: + logger.debug(f"Could not fetch remote catalog: {e}") + + def _load_local_catalog(self): + """Load the catalog from the packaged JSON file.""" + try: + catalog_data = LOCAL_CATALOG_PATH.read_text() + self._parse_catalog(json.loads(catalog_data)) + logger.debug("Loaded local plugin catalog") + except Exception as e: + logger.warning(f"Could not load local catalog: {e}") + + def _update_from_remote(self, timeout: int = 5): + """Fetch and update from the remote catalog.""" + response = requests.get(self.REMOTE_CATALOG_URL, timeout=timeout) + response.raise_for_status() + + self._parse_catalog(response.json()) + logger.info("Updated plugin catalog from remote source") + + def _parse_catalog(self, data: dict): + """Parse catalog JSON data into PluginInfo objects. + + This uses Pydantic models for schema validation and forward compatibility, + so that adding fields to the JSON does not require rewriting this method. + """ + payload = CatalogModel.model_validate(data) + self._catalog.merge(payload) + + + def get_plugin(self, name: str) -> Optional[PluginInfo]: + """Get plugin info by name.""" + plugin = self._catalog.plugins.get(name) + if plugin: + return plugin + # try lower case and with "-" instead of "_" + return self._catalog.plugins.get(name.lower().replace("_", "-")) + + + def list_plugins(self) -> list[PluginInfo]: + """List all available plugins.""" + return list(self._catalog.plugins.values()) + + def is_curated(self, name: str) -> bool: + """Check if a plugin is in the curated catalog.""" + return name in self._catalog.plugins + + def get_compatible_version( + self, + plugin_name: str, + medcat_version: str + ) -> str: + """ + Get compatible plugin version for given MedCAT version. + + Args: + plugin_name: Name of the plugin + medcat_version: MedCAT version string + + Raises: + NoSuchPluginException: If the plugin wasn't found / known. + NoCompatibleSpecException: If compatibility spec was unable to be met. + + Returns: + Compatible version specifier + """ + plugin = self.get_plugin(plugin_name) + if not plugin: + raise NoSuchPluginException(plugin_name) + + medcat_ver = Version(medcat_version) + + for compat in plugin.compatibility: + spec = SpecifierSet(compat.medcat_version) + if medcat_ver in spec: + return compat.plugin_version + + raise NoCompatibleSpecException(plugin, medcat_ver) + + +# Global catalog instance +_catalog: Optional[PluginCatalog] = None + + +def get_catalog() -> PluginCatalog: + """Get the global plugin catalog instance.""" + global _catalog + if _catalog is None: + _catalog = PluginCatalog() + return _catalog + + +class NoSuchPluginException(ValueError): + + def __init__(self, plugin_name: str) -> None: + super().__init__( + f"No plugin by the name '{plugin_name}' is known to MedCAT") + + +class NoCompatibleSpecException(ValueError): + + def __init__(self, plugin: PluginInfo, medcat_ver: Version) -> None: + super().__init__( + f"Was unable to find a version of the plugin {plugin.name} " + f"that was compatible with MedCAT version {medcat_ver}. " + f"Plugin details: {plugin}") + + +class UnmergablePluginInfo(ValueError): + + def __init__(self, info1: PluginInfo, info2: PluginInfo) -> None: + super().__init__( + "The two plugin infos cannot be merged:\n" + f"One:\n{info1}\nand two:\n{info2}" + ) diff --git a/medcat-v2/medcat/plugins/cli.py b/medcat-v2/medcat/plugins/cli.py new file mode 100644 index 000000000..c3ff82a8b --- /dev/null +++ b/medcat-v2/medcat/plugins/cli.py @@ -0,0 +1,31 @@ +"""CLI entrypoint for MedCAT commands.""" + +import sys + +from medcat.plugins.installer import PluginInstallationManager + + +# TODO: plugin listing and stuff like that + + +def install_plugins_command(*args: str): + opts = [arg for arg in args if arg.startswith("--")] + plugins = [arg for arg in args if arg not in opts] + dry_run = "--dry-run" in opts + + manager = PluginInstallationManager() + + if not plugins: + print("Error: No plugins specified", file=sys.stderr) + return 1 + + results = manager.install_multiple(plugins, dry_run=dry_run) + + failed = [name for name, success in results.items() if not success] + + if failed: + print(f"Failed to install: {', '.join(failed)}", file=sys.stderr) + return 1 + + print(f"Successfully installed: {', '.join(results.keys())}") + return 0 diff --git a/medcat-v2/medcat/plugins/data/__init__.py b/medcat-v2/medcat/plugins/data/__init__.py new file mode 100644 index 000000000..ca62f3e85 --- /dev/null +++ b/medcat-v2/medcat/plugins/data/__init__.py @@ -0,0 +1,2 @@ +"""Packaged resources for MedCAT plugins (e.g. curated plugin catalog).""" + diff --git a/medcat-v2/medcat/plugins/data/plugin_catalog.json b/medcat-v2/medcat/plugins/data/plugin_catalog.json new file mode 100644 index 000000000..031f9172c --- /dev/null +++ b/medcat-v2/medcat/plugins/data/plugin_catalog.json @@ -0,0 +1,24 @@ +{ + "version": "1.0", + "last_updated": "2026-01-28", + "plugins": { + "medcat-gliner": { + "name": "medcat-gliner", + "display_name": "MedCAT-gliner", + "description": "Gliner based NER for MedCAT", + "source_spec": { + "source": "git@github.com:CogStack/cogstack-ops.git", + "source_type": "github_ssh_subdir", + "subdirectory": "medcat-gliner" + }, + "homepage": "https://github.com/CogStack/cogstack-ops/tree/main/medcat-gliner", + "compatibility": [ + { + "medcat_version": ">=2.5.0,<3.0.0", + "plugin_version": "main" + } + ], + "requires_auth": true + } + } +} diff --git a/medcat-v2/medcat/plugins/downloadable.py b/medcat-v2/medcat/plugins/downloadable.py new file mode 100644 index 000000000..89000405d --- /dev/null +++ b/medcat-v2/medcat/plugins/downloadable.py @@ -0,0 +1,129 @@ +"""Protocol definitions for plugin installation backends.""" + +from typing import Protocol, Optional +from pydantic import BaseModel +import re + + +class PluginSourceSpec(BaseModel): + """Where and how to obtain a plugin.""" + source: str # PyPI package name, GitHub URL, SSH URL, etc. + source_type: str + # this includes: "pypi", "github", "github_subdir", + # "github_ssh", "github_ssh_subdir", "url" + subdirectory: Optional[str] = None # Path within repo, e.g., "plugins/negation" + + +class PluginInstallSpec(BaseModel): + """Specification for installing a plugin.""" + name: str + version_spec: str # e.g., ">=1.0.0,<2.0.0" or git ref like "main", "v1.2.3" + source_spec: PluginSourceSpec + + def to_pip_spec(self) -> str: + """Convert to pip-installable spec.""" + src = self.source_spec + if src.source_type == "pypi": + return f"{src.source}{self.version_spec}" + elif src.source_type == "github": + # Standard GitHub install + return f"git+{src.source}@{self.version_spec}" + elif src.source_type == "github_subdir": + # GitHub with subdirectory + # Format: git+https://github.com/user/repo.git@ref#subdirectory=path/to/plugin + base_url = src.source.rstrip('/') + if not base_url.endswith('.git'): + base_url += '.git' + + spec = f"git+{base_url}@{self.version_spec}" + if src.subdirectory: + spec += f"#subdirectory={src.subdirectory}" + return spec + elif src.source_type == "github_ssh": + # GitHub SSH install + # Format: git+ssh://git@github.com/user/repo.git@ref + ssh_url = self._normalize_ssh_url(src.source) + return f"git+{ssh_url}@{self.version_spec}" + elif src.source_type == "github_ssh_subdir": + # GitHub SSH with subdirectory + # Format: git+ssh://git@github.com/user/repo.git@ref#subdirectory=path + ssh_url = self._normalize_ssh_url(src.source) + spec = f"git+{ssh_url}@{self.version_spec}" + if src.subdirectory: + spec += f"#subdirectory={src.subdirectory}" + return spec + elif src.source_type == "url": + # Direct URL (could be a tarball, wheel, etc.) + return src.source + else: + raise ValueError(f"Unknown source_type: {src.source_type}") + + @staticmethod + def _normalize_ssh_url(url: str) -> str: + """ + Normalize SSH URL to the format pip expects. + + Handles various SSH URL formats: + - git@github.com:user/repo.git + - ssh://git@github.com/user/repo.git + - git@github.com:user/repo + + Returns: ssh://git@github.com/user/repo.git + """ + # Already in ssh:// format + if url.startswith("ssh://"): + if not url.endswith('.git'): + url += '.git' + return url + + # Convert git@github.com:user/repo.git to ssh://git@github.com/user/repo.git + if '@' in url and ':' in url: + # Pattern: git@github.com:user/repo.git + match = re.match(r'^git@([^:]+):(.+?)(?:\.git)?$', url) + if match: + host, path = match.groups() + return f"ssh://git@{host}/{path}.git" + + # If we can't parse it, return as-is and let pip handle it + return url + + +class PluginInstaller(Protocol): + """Protocol for plugin installation backends.""" + + def install(self, spec: PluginInstallSpec, dry_run: bool = False) -> bool: + """ + Install a plugin. + + Args: + spec: Plugin installation specification + dry_run: If True, only check what would be installed + + Returns: + True if successful, False otherwise + """ + pass + + def is_available(self) -> bool: + """Check if this installer is available in the environment.""" + pass + + def get_name(self) -> str: + """Get the name of this installer (e.g., 'pip', 'uv').""" + pass + + +class CredentialProvider(Protocol): + """Protocol for providing credentials for private repositories.""" + + def get_credentials(self, source: str) -> Optional[dict]: + """ + Get credentials for a given source. + + Args: + source: The source URL or identifier + + Returns: + Dictionary with credentials (e.g., {'token': '...'}) or None + """ + pass diff --git a/medcat-v2/medcat/plugins/installer.py b/medcat-v2/medcat/plugins/installer.py new file mode 100644 index 000000000..443c52a3d --- /dev/null +++ b/medcat-v2/medcat/plugins/installer.py @@ -0,0 +1,187 @@ +"""Plugin installation functionality.""" + +import sys +import subprocess +import logging +from typing import Optional + +from .downloadable import PluginInstaller, PluginInstallSpec +from .catalog import get_catalog +import medcat + +logger = logging.getLogger(__name__) + + +class PipInstaller: + """Plugin installer using pip.""" + + def install( + self, + spec: PluginInstallSpec, + dry_run: bool = False + ) -> bool: + """Install a plugin using pip.""" + cmd = [ + sys.executable, "-m", "pip", "install", + spec.to_pip_spec() + ] + + if dry_run: + cmd.insert(3, "--dry-run") + + logger.info(f"Installing {spec.name}: {' '.join(cmd)}") + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + logger.debug(f"Install output: {result.stdout}") + return True + + except subprocess.CalledProcessError as e: + logger.error(f"Installation failed: {e.stderr}") + return False + + def is_available(self) -> bool: + """Check if pip is available.""" + try: + subprocess.run( + [sys.executable, "-m", "pip", "--version"], + capture_output=True, + check=True + ) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + def get_name(self) -> str: + return "pip" + + +class PluginInstallationManager: + """Manages plugin installation.""" + + def __init__(self, installer: Optional[PluginInstaller] = None): + """ + Initialize the installation manager. + + Args: + installer: Plugin installer to use (defaults to PipInstaller) + """ + self.installer = installer or PipInstaller() + self.catalog = get_catalog() + + def install_plugin( + self, + plugin_name: str, + dry_run: bool = False, + force_version: Optional[str] = None + ) -> bool: + """ + Install a curated plugin. + + Args: + plugin_name: Name of the plugin to install + dry_run: If True, only check what would be installed + force_version: Specific version/ref to install (overrides compatibility) + + Returns: + True if installation succeeded + + Raises: + ValueError: If plugin is not in curated catalog + RuntimeError: If no compatible version found + """ + plugin_info = self.catalog.get_plugin(plugin_name) + + if not plugin_info: + plugins = ', '.join(p.name for p in self.catalog.list_plugins()) + raise ValueError( + f"Plugin '{plugin_name}' is not in the curated catalog.\n" + f"Available plugins: {plugins}" + ) + + # Warn about authentication if needed + if plugin_info.requires_auth: + logger.warning( + f"Plugin '{plugin_name}' requires authentication.\n" + "Ensure you have configured Git credentials for " + f"{plugin_info.source_spec.source}" + ) + + # Determine version/ref to install + if force_version: + version_spec = force_version + else: + version_spec = self.catalog.get_compatible_version( + plugin_name, + medcat.__version__ + ) + + if not version_spec: + raise RuntimeError( + f"No compatible version of '{plugin_name}' found for " + f"MedCAT {medcat.__version__}.\n" + f"Visit {plugin_info.homepage} for more information." + ) + + spec = PluginInstallSpec( + name=plugin_name, + version_spec=version_spec, + source_spec=plugin_info.source_spec, + ) + + logger.info( + f"Installing {plugin_info.display_name} " + f"({plugin_name}{version_spec})" + ) + + if plugin_info.source_spec.subdirectory: + logger.info(f" From subdirectory: {plugin_info.source_spec.subdirectory}") + + try: + return self.installer.install(spec, dry_run=dry_run) + except subprocess.CalledProcessError as e: + # Provide helpful error messages + if "subdirectory" in spec.to_pip_spec(): + logger.error( + "Installation failed. This plugin is in a subdirectory.\n" + "Common issues:\n" + " - The subdirectory path might be incorrect\n" + f" - The git ref '{version_spec}' might not exist\n" + " - setup.py/pyproject.toml might be missing in the subdirectory" + ) + + if plugin_info.requires_auth: + logger.error( + "Authentication might be required.\n" + "Configure git credentials with:\n" + " git config --global credential.helper store" + ) + + raise + + + def install_multiple( + self, + plugin_names: list[str], + dry_run: bool = False + ) -> dict: + """ + Install multiple plugins. + + Returns: + Dictionary mapping plugin names to success status + """ + results = {} + for name in plugin_names: + try: + results[name] = self.install_plugin(name, dry_run=dry_run) + except Exception as e: + logger.error(f"Failed to install {name}: {e}") + results[name] = False + + return results diff --git a/medcat-v2/medcat/utils/exceptions.py b/medcat-v2/medcat/utils/exceptions.py index 655503ec1..c55694a2f 100644 --- a/medcat-v2/medcat/utils/exceptions.py +++ b/medcat-v2/medcat/utils/exceptions.py @@ -1,6 +1,9 @@ from typing import TypedDict +from medcat.plugins.catalog import get_catalog + + class MissingPluginInfo(TypedDict): name: str provides: list[tuple[str, str]] @@ -19,6 +22,7 @@ def __init__(self, missing_plugins: list[MissingPluginInfo], super().__init__(message) def _generate_message(self) -> str: + catalog = get_catalog() msg = "The following required plugins are missing:\n" for plugin in self.missing_plugins: msg += f" - Plugin: {plugin['name']}\n" @@ -29,6 +33,9 @@ def _generate_message(self) -> str: msg += f" Author: {plugin['author']}\n" if plugin['url']: msg += f" URL: {plugin['url']}\n" + if catalog.get_plugin(plugin['name']) is not None: + msg += "\n NB: You should be able to install this plugin using:\n" + msg += f" python -m medcat install-plugins {plugin['name']}\n" msg += "\n" msg += "Please install the missing plugins to load this model pack." return msg diff --git a/medcat-v2/pyproject.toml b/medcat-v2/pyproject.toml index bd18db56b..a46e0076c 100644 --- a/medcat-v2/pyproject.toml +++ b/medcat-v2/pyproject.toml @@ -140,6 +140,9 @@ package-dir = { "medcat" = "medcat" } [tool.setuptools.packages.find] include = ["medcat*"] +[tool.setuptools.package-data] +"medcat.plugins.data" = ["plugin_catalog.json"] + [tool.setuptools_scm] # look for .git folder in root of repo root = ".." diff --git a/medcat-v2/tests/plugins/test_catalog.py b/medcat-v2/tests/plugins/test_catalog.py new file mode 100644 index 000000000..1546483d5 --- /dev/null +++ b/medcat-v2/tests/plugins/test_catalog.py @@ -0,0 +1,237 @@ +import unittest +from types import SimpleNamespace +from unittest.mock import patch +from copy import deepcopy + +import medcat.plugins.catalog as catalog_module +from medcat.plugins.catalog import ( + NoCompatibleSpecException, + NoSuchPluginException, + PluginCatalog, + PluginCompatibility, +) + + +class IncludedCatalogSchemaTests(unittest.TestCase): + + def test_is_correct_format(self): + with open(catalog_module.LOCAL_CATALOG_PATH) as f: + text = f.read() + catalog_module.CatalogModel.model_validate_json(text) + + +class CatalogMergeTests(unittest.TestCase): + v1 = "v1" + update1 = "at noon" + v2 = "v2" + update2 = "at night" + v3 = "v3" + update3 = "next morning" + pl1_comp1 = catalog_module.PluginCompatibility( + plugin_version="v0.1.0", medcat_version=">=2.5") + pl1_comp2 = catalog_module.PluginCompatibility( + plugin_version="v0.1.1", medcat_version=">=2.6") + pl1 = catalog_module.PluginInfo( + name='pl1', display_name='Plugin 1', + description="Plugin 1 ...", + source_spec=catalog_module.PluginSourceSpec( + source="github", source_type="github"), + homepage="www.google.com", requires_auth=False, + compatibility=[pl1_comp1] + ) + pl1_alt = catalog_module.PluginInfo( + name='pl1', display_name='Plugin 1', + description="Plugin 1 - UPDATED ...", + source_spec=catalog_module.PluginSourceSpec( + source="github", source_type="github"), + homepage="www.google.com", requires_auth=False, + compatibility=[pl1_comp1, pl1_comp2] + ) + pl2_comp1 = catalog_module.PluginCompatibility( + plugin_version="v0.2.0", medcat_version=">=2.3") + pl2 = catalog_module.PluginInfo( + name='pl2', display_name='Plugin 2', + description="Plugin two ...", + source_spec=catalog_module.PluginSourceSpec( + source="github", source_type="github"), + homepage="www.amazon.com", requires_auth=False, + compatibility=[pl2_comp1] + ) + plugins1 = { + "pl1": pl1, + } + plugins2 = { + "pl2": pl2, + } + plugins3 = { + "pl1": pl1_alt, + "pl2": pl2, + } + + def _copy_catalog(self, catalog: catalog_module.CatalogModel) -> catalog_module.CatalogModel: + model_dump = deepcopy(catalog.model_dump()) + return catalog_module.CatalogModel.model_validate(model_dump) + + def setUp(self) -> None: + self.catalog1 = self._copy_catalog(catalog_module.CatalogModel( + version=self.v1, last_updated=self.update1, plugins=self.plugins1 + )) + self.catalog2 = self._copy_catalog(catalog_module.CatalogModel( + version=self.v2, last_updated=self.update2, plugins=self.plugins2 + )) + self.catalog3_update = self._copy_catalog(catalog_module.CatalogModel( + version=self.v3, last_updated=self.update3, plugins=self.plugins3 + )) + self.merged_2_into_1_po = self._copy_catalog(self.catalog1) + self.merged_2_into_1_po.merge(self.catalog2, prefer_other=True) + self.merged_2_into_1_ps = self._copy_catalog(self.catalog1) + self.merged_2_into_1_ps.merge(self.catalog2, prefer_other=False) + self.merged_3_into_1_po = self._copy_catalog(self.catalog1) + self.merged_3_into_1_po.merge(self.catalog3_update, prefer_other=True) + self.merged_3_into_1_ps = self._copy_catalog(self.catalog1) + self.merged_3_into_1_ps.merge(self.catalog3_update, prefer_other=False) + + def test_catalog_merge_updates_version(self): + assert self.merged_2_into_1_po.version == self.v2 + + def test_catalog_merge_updates_update_date(self): + assert self.merged_2_into_1_po.last_updated == self.update2 + + def test_catalog_merge_can_leave_version(self): + assert self.merged_2_into_1_ps.version == self.v1 + + def test_catalog_merge_can_leave_date(self): + assert self.merged_2_into_1_ps.last_updated == self.update1 + + def assert_has_merged( + self, part1: catalog_module.CatalogModel, part2: catalog_module, + merged: catalog_module.CatalogModel): + assert len(merged.plugins) >= len(part1.plugins) + assert len(merged.plugins) >= len(part2.plugins) + merged_plugins = set(merged.plugins.keys()) + downstream_plugins = set(part1.plugins.keys()) | set(part2.plugins.keys()) + assert merged_plugins == downstream_plugins + + def test_merge_adds_plugins_2_to_1_other(self): + self.assert_has_merged(self.catalog1, self.catalog2, self.merged_2_into_1_po) + + def test_merge_adds_plugins_2_to_1_self(self): + self.assert_has_merged(self.catalog1, self.catalog2, self.merged_2_into_1_ps) + + def test_merge_adds_plugins_3_to_1_other(self): + self.assert_has_merged(self.catalog1, self.catalog3_update, self.merged_3_into_1_po) + + def test_merge_adds_plugins_3_to_1_self(self): + self.assert_has_merged(self.catalog1, self.catalog3_update, self.merged_3_into_1_ps) + + def test_keeps_updated_info(self): + merged_pl = self.merged_3_into_1_po.plugins[self.pl1.name] + cat1_pl = self.catalog1.plugins[self.pl1.name] + cat3_pl = self.catalog3_update.plugins[self.pl1.name] + assert merged_pl.compatibility == cat3_pl.compatibility + assert merged_pl.compatibility != cat1_pl.compatibility + assert merged_pl.description == cat3_pl.description + assert merged_pl.description != cat1_pl.description + + def test_can_keep_prior_info(self): + merged_pl = self.merged_3_into_1_ps.plugins[self.pl1.name] + cat1_pl = self.catalog1.plugins[self.pl1.name] + cat3_pl = self.catalog3_update.plugins[self.pl1.name] + assert merged_pl.compatibility != cat3_pl.compatibility + assert merged_pl.compatibility == cat1_pl.compatibility + assert merged_pl.description != cat3_pl.description + assert merged_pl.description == cat1_pl.description + + +class TestPluginCatalogParsingAndQueries(unittest.TestCase): + EXAMPLE_PLUGIN_NAME = 'example-plugin' + + def setUp(self): + # Avoid network access in tests by not touching the remote catalog. + self.catalog = PluginCatalog(use_remote=False) + # Reset any data that might have been loaded in __init__ + self.catalog._catalog.plugins.clear() + + # Populate the catalog with a simple in-memory definition + self.catalog._parse_catalog( + { + "version": "0.0test", + "last_updated": "test-time", + "plugins": { + self.EXAMPLE_PLUGIN_NAME: { + "name": "example-plugin", + "display_name": "Example Plugin", + "description": "Test plugin", + "source_spec": { + "source": self.EXAMPLE_PLUGIN_NAME, + "source_type": "pypi", + "subdirectory": "plugins/example" + }, + "homepage": "https://example.com/example-plugin", + "requires_auth": True, + "compatibility": [ + { + "medcat_version": ">=1.0.0,<2.0.0", + "plugin_version": "==1.2.3", + }, + { + "medcat_version": ">=2.0.0", + "plugin_version": "==2.0.0", + }, + ], + } + } + } + ) + + def test_get_plugin_and_is_curated(self): + plugin = self.catalog.get_plugin(self.EXAMPLE_PLUGIN_NAME) + self.assertIsNotNone(plugin) + self.assertTrue(self.catalog.is_curated(self.EXAMPLE_PLUGIN_NAME)) + self.assertEqual(plugin.display_name, "Example Plugin") + self.assertEqual(plugin.source_spec.subdirectory, "plugins/example") + self.assertTrue(plugin.requires_auth) + + def test_list_plugins_returns_all(self): + plugins = self.catalog.list_plugins() + self.assertEqual(len(plugins), 1) + self.assertEqual(plugins[0].name, self.EXAMPLE_PLUGIN_NAME) + + def test_get_compatible_version_success_first_spec(self): + version = self.catalog.get_compatible_version(self.EXAMPLE_PLUGIN_NAME, "1.5.0") + self.assertEqual(version, "==1.2.3") + + def test_get_compatible_version_success_second_spec(self): + version = self.catalog.get_compatible_version(self.EXAMPLE_PLUGIN_NAME, "2.1.0") + self.assertEqual(version, "==2.0.0") + + def test_get_compatible_version_no_such_plugin_raises(self): + with self.assertRaises(NoSuchPluginException): + self.catalog.get_compatible_version("missing-plugin", "1.0.0") + + def test_get_compatible_version_no_compatible_spec_raises(self): + with self.assertRaises(NoCompatibleSpecException): + self.catalog.get_compatible_version("example-plugin", "0.5.0") + + +class TestGetCatalogSingleton(unittest.TestCase): + + def tearDown(self): + # Reset the module-level singleton between tests + catalog_module._catalog = None + + @patch.object(catalog_module, "PluginCatalog") + def test_get_catalog_returns_singleton(self, mock_catalog_cls): + fake_instance = SimpleNamespace() + mock_catalog_cls.return_value = fake_instance + + first = catalog_module.get_catalog() + second = catalog_module.get_catalog() + + self.assertIs(first, second) + mock_catalog_cls.assert_called_once() + + +if __name__ == "__main__": + unittest.main() + diff --git a/medcat-v2/tests/plugins/test_cli.py b/medcat-v2/tests/plugins/test_cli.py new file mode 100644 index 000000000..4485027e7 --- /dev/null +++ b/medcat-v2/tests/plugins/test_cli.py @@ -0,0 +1,94 @@ +import io +import unittest +from contextlib import redirect_stderr, redirect_stdout +from unittest.mock import MagicMock, patch + +from medcat.plugins.cli import install_plugins_command + + +class TestInstallPluginsCommand(unittest.TestCase): + + @patch("medcat.plugins.cli.PluginInstallationManager") + def test_no_plugins_returns_error_and_message(self, mock_manager_cls): + # Even when no plugins are provided, the CLI constructs the manager + mock_manager_cls.return_value = MagicMock() + + out = io.StringIO() + err = io.StringIO() + with redirect_stdout(out), redirect_stderr(err): + code = install_plugins_command() + + self.assertEqual(code, 1) + self.assertIn("No plugins specified", err.getvalue()) + mock_manager_cls.assert_called_once() + + @patch("medcat.plugins.cli.PluginInstallationManager") + def test_successful_install_prints_success_and_returns_zero(self, mock_manager_cls): + manager = MagicMock() + manager.install_multiple.return_value = { + "plugin-a": True, + "plugin-b": True, + } + mock_manager_cls.return_value = manager + + out = io.StringIO() + err = io.StringIO() + with redirect_stdout(out), redirect_stderr(err): + code = install_plugins_command("plugin-a", "plugin-b") + + self.assertEqual(code, 0) + manager.install_multiple.assert_called_once_with( + ["plugin-a", "plugin-b"], dry_run=False + ) + self.assertIn( + "Successfully installed: plugin-a, plugin-b", + out.getvalue(), + ) + self.assertEqual("", err.getvalue()) + + @patch("medcat.plugins.cli.PluginInstallationManager") + def test_failed_install_prints_failures_and_returns_non_zero( + self, mock_manager_cls + ): + manager = MagicMock() + manager.install_multiple.return_value = { + "plugin-a": True, + "plugin-b": False, + } + mock_manager_cls.return_value = manager + + out = io.StringIO() + err = io.StringIO() + with redirect_stdout(out), redirect_stderr(err): + code = install_plugins_command("plugin-a", "plugin-b") + + self.assertEqual(code, 1) + manager.install_multiple.assert_called_once_with( + ["plugin-a", "plugin-b"], dry_run=False + ) + self.assertIn("Failed to install: plugin-b", err.getvalue()) + + @patch("medcat.plugins.cli.PluginInstallationManager") + def test_dry_run_flag_is_passed_to_manager(self, mock_manager_cls): + manager = MagicMock() + manager.install_multiple.return_value = { + "plugin-a": True, + } + mock_manager_cls.return_value = manager + + out = io.StringIO() + err = io.StringIO() + with redirect_stdout(out), redirect_stderr(err): + code = install_plugins_command("--dry-run", "plugin-a") + + self.assertEqual(code, 0) + manager.install_multiple.assert_called_once_with( + ["plugin-a"], dry_run=True + ) + self.assertIn("Successfully installed: plugin-a", out.getvalue()) + self.assertEqual("", err.getvalue()) + + +if __name__ == "__main__": + unittest.main() + diff --git a/medcat-v2/tests/plugins/test_downloadable.py b/medcat-v2/tests/plugins/test_downloadable.py new file mode 100644 index 000000000..cef2cc0e7 --- /dev/null +++ b/medcat-v2/tests/plugins/test_downloadable.py @@ -0,0 +1,98 @@ +import unittest + +from medcat.plugins.downloadable import PluginInstallSpec, PluginSourceSpec + + +class TestPluginInstallSpec(unittest.TestCase): + + def test_to_pip_spec_pypi(self): + spec = PluginInstallSpec( + name="example-plugin", + version_spec="==1.2.3", + source_spec=PluginSourceSpec( + source="example-plugin", + source_type="pypi", + ), + ) + + self.assertEqual(spec.to_pip_spec(), "example-plugin==1.2.3") + + def test_to_pip_spec_github(self): + spec = PluginInstallSpec( + name="example-plugin", + version_spec="v1.0.0", + source_spec=PluginSourceSpec( + source="https://github.com/example/example-plugin", + source_type="github", + ), + ) + + self.assertEqual( + spec.to_pip_spec(), + "git+https://github.com/example/example-plugin@v1.0.0", + ) + + def test_to_pip_spec_github_subdir_adds_git_and_subdir(self): + spec = PluginInstallSpec( + name="example-plugin", + version_spec="main", + source_spec=PluginSourceSpec( + source="https://github.com/example/example-plugin", + source_type="github_subdir", + subdirectory="plugins/example", + ), + ) + + self.assertEqual( + spec.to_pip_spec(), + "git+https://github.com/example/example-plugin.git" + "@main#subdirectory=plugins/example", + ) + + def test_to_pip_spec_github_subdir_without_subdir(self): + spec = PluginInstallSpec( + name="example-plugin", + version_spec="main", + source_spec=PluginSourceSpec( + source="https://github.com/example/example-plugin", + source_type="github_subdir", + ), + ) + + self.assertEqual( + spec.to_pip_spec(), + "git+https://github.com/example/example-plugin.git@main", + ) + + def test_to_pip_spec_url_source_type(self): + spec = PluginInstallSpec( + name="example-plugin", + version_spec="", # ignored for url + source_spec=PluginSourceSpec( + source="https://example.com/example-plugin-1.0.0.whl", + source_type="url", + ), + ) + + self.assertEqual( + spec.to_pip_spec(), + "https://example.com/example-plugin-1.0.0.whl", + ) + + def test_to_pip_spec_unknown_source_type_raises(self): + spec = PluginInstallSpec( + name="example-plugin", + version_spec="==1.0.0", + source_spec=PluginSourceSpec( + source="example-plugin", + source_type="unknown-source", + ), + ) + + with self.assertRaises(ValueError): + spec.to_pip_spec() + + +if __name__ == "__main__": + unittest.main() + diff --git a/medcat-v2/tests/plugins/test_installer.py b/medcat-v2/tests/plugins/test_installer.py new file mode 100644 index 000000000..2ec838883 --- /dev/null +++ b/medcat-v2/tests/plugins/test_installer.py @@ -0,0 +1,278 @@ +import subprocess +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from medcat.plugins.downloadable import PluginInstallSpec, PluginSourceSpec +from medcat.plugins.installer import PipInstaller, PluginInstallationManager + + +class TestPipInstaller(unittest.TestCase): + + @patch("medcat.plugins.installer.subprocess.run") + def test_install_success(self, mock_run): + mock_run.return_value = SimpleNamespace(stdout="ok") + installer = PipInstaller() + spec = PluginInstallSpec( + name="example-plugin", + version_spec="==1.0.0", + source_spec=PluginSourceSpec( + source="example-plugin", + source_type="pypi", + ), + ) + + result = installer.install(spec) + + self.assertTrue(result) + mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + # basic sanity checks on the constructed pip command + self.assertIn("pip", cmd) + self.assertIn(spec.to_pip_spec(), cmd) + + @patch("medcat.plugins.installer.subprocess.run") + def test_install_dry_run_includes_flag(self, mock_run): + mock_run.return_value = SimpleNamespace(stdout="ok") + installer = PipInstaller() + spec = PluginInstallSpec( + name="example-plugin", + version_spec="==1.0.0", + source_spec=PluginSourceSpec( + source="example-plugin", + source_type="pypi", + ), + ) + + result = installer.install(spec, dry_run=True) + + self.assertTrue(result) + mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + self.assertIn("--dry-run", cmd) + + @patch( + "medcat.plugins.installer.subprocess.run", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd=["pip"], stderr="boom" + ), + ) + def test_install_failure_returns_false(self, mock_run): + installer = PipInstaller() + spec = PluginInstallSpec( + name="example-plugin", + version_spec="==1.0.0", + source_spec=PluginSourceSpec( + source="example-plugin", + source_type="pypi", + ), + ) + + result = installer.install(spec) + + self.assertFalse(result) + mock_run.assert_called_once() + + @patch("medcat.plugins.installer.subprocess.run") + def test_is_available_true(self, mock_run): + mock_run.return_value = SimpleNamespace(stdout="pip 23.0") + installer = PipInstaller() + + self.assertTrue(installer.is_available()) + mock_run.assert_called_once() + + @patch( + "medcat.plugins.installer.subprocess.run", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd=["pip"], stderr="boom" + ), + ) + def test_is_available_false_on_error(self, mock_run): + installer = PipInstaller() + + self.assertFalse(installer.is_available()) + mock_run.assert_called_once() + + @patch( + "medcat.plugins.installer.subprocess.run", + side_effect=FileNotFoundError("python not found"), + ) + def test_is_available_false_on_missing_python(self, mock_run): + installer = PipInstaller() + + self.assertFalse(installer.is_available()) + mock_run.assert_called_once() + + def test_get_name(self): + installer = PipInstaller() + self.assertEqual(installer.get_name(), "pip") + + +class TestPluginInstallationManager(unittest.TestCase): + + @patch("medcat.plugins.installer.get_catalog") + def test_install_plugin_unknown_raises_value_error(self, mock_get_catalog): + fake_catalog = MagicMock() + fake_catalog.get_plugin.return_value = None + fake_catalog.list_plugins.return_value = [ + SimpleNamespace(name="other-plugin", display_name="Other Plugin") + ] + mock_get_catalog.return_value = fake_catalog + + manager = PluginInstallationManager(installer=MagicMock()) + + with self.assertRaises(ValueError) as ctx: + manager.install_plugin("missing-plugin") + + msg = str(ctx.exception) + self.assertIn("missing-plugin", msg) + self.assertIn("Available plugins", msg) + self.assertIn("other-plugin", msg) + + @patch("medcat.plugins.installer.get_catalog") + @patch("medcat.plugins.installer.medcat") + def test_install_plugin_no_compatible_version_raises_runtime_error( + self, + mock_medcat, + mock_get_catalog, + ): + mock_medcat.__version__ = "2.5.0" + + plugin_info = SimpleNamespace( + name="example-plugin", + display_name="Example Plugin", + description="", + source_spec=PluginSourceSpec( + source="https://github.com/example/example-plugin", + source_type="github_subdir", + subdirectory="plugins/example", + ), + homepage="https://example.com/example-plugin", + compatibility=[], + requires_auth=False, + ) + fake_catalog = MagicMock() + fake_catalog.get_plugin.return_value = plugin_info + fake_catalog.get_compatible_version.return_value = None + mock_get_catalog.return_value = fake_catalog + + manager = PluginInstallationManager(installer=MagicMock()) + + with self.assertRaises(RuntimeError) as ctx: + manager.install_plugin("example-plugin") + + msg = str(ctx.exception) + self.assertIn("No compatible version of 'example-plugin' found", msg) + self.assertIn("2.5.0", msg) + + @patch("medcat.plugins.installer.get_catalog") + def test_install_plugin_uses_force_version_over_compat(self, mock_get_catalog): + plugin_info = SimpleNamespace( + name="example-plugin", + display_name="Example Plugin", + description="", + source_spec=PluginSourceSpec( + source="example-plugin", + source_type="pypi", + subdirectory=None, + ), + homepage="https://example.com/example-plugin", + compatibility=[], + requires_auth=False, + ) + fake_catalog = MagicMock() + fake_catalog.get_plugin.return_value = plugin_info + fake_catalog.get_compatible_version.return_value = "==0.1.0" + mock_get_catalog.return_value = fake_catalog + + fake_installer = MagicMock() + fake_installer.install.return_value = True + manager = PluginInstallationManager(installer=fake_installer) + + result = manager.install_plugin( + "example-plugin", dry_run=True, force_version="==9.9.9" + ) + + self.assertTrue(result) + fake_catalog.get_compatible_version.assert_not_called() + fake_installer.install.assert_called_once() + spec = fake_installer.install.call_args[0][0] + self.assertIsInstance(spec, PluginInstallSpec) + self.assertEqual(spec.version_spec, "==9.9.9") + + @patch("medcat.plugins.installer.get_catalog") + @patch("medcat.plugins.installer.medcat") + def test_install_plugin_uses_compatible_version_when_no_force( + self, + mock_medcat, + mock_get_catalog, + ): + mock_medcat.__version__ = "2.5.0" + + plugin_info = SimpleNamespace( + name="example-plugin", + display_name="Example Plugin", + description="", + source_spec=PluginSourceSpec( + source="example-plugin", + source_type="pypi", + subdirectory=None, + ), + homepage="https://example.com/example-plugin", + compatibility=[], + requires_auth=False, + ) + fake_catalog = MagicMock() + fake_catalog.get_plugin.return_value = plugin_info + fake_catalog.get_compatible_version.return_value = "==1.2.3" + mock_get_catalog.return_value = fake_catalog + + fake_installer = MagicMock() + fake_installer.install.return_value = True + manager = PluginInstallationManager(installer=fake_installer) + + result = manager.install_plugin("example-plugin", dry_run=False) + + self.assertTrue(result) + fake_catalog.get_compatible_version.assert_called_once_with( + "example-plugin", "2.5.0" + ) + fake_installer.install.assert_called_once() + spec = fake_installer.install.call_args[0][0] + self.assertEqual(spec.version_spec, "==1.2.3") + self.assertEqual(spec.source_spec.source, "example-plugin") + self.assertEqual(spec.source_spec.source_type, "pypi") + + @patch("medcat.plugins.installer.get_catalog") + def test_install_multiple_collects_results_and_handles_exceptions( + self, + mock_get_catalog, + ): + # We don't care about catalog behaviour here, only that + # exceptions from install_plugin are converted into False. + mock_get_catalog.return_value = MagicMock() + + manager = PluginInstallationManager(installer=MagicMock()) + + with patch.object( + manager, + "install_plugin", + side_effect=[True, Exception("boom"), False], + ): + results = manager.install_multiple( + ["plugin-a", "plugin-b", "plugin-c"], dry_run=False + ) + + self.assertEqual( + results, + { + "plugin-a": True, + "plugin-b": False, + "plugin-c": False, + }, + ) + + +if __name__ == "__main__": + unittest.main() +