diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ec278f50..7b58e938 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install ".[dev]" + python -m pip install ".[dev,server]" - name: Run tests if: ${{ matrix.coverage != true }} diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml index 6db9140e..8fbd5b66 100644 --- a/.github/workflows/publish_to_pypi.yml +++ b/.github/workflows/publish_to_pypi.yml @@ -2,7 +2,7 @@ name: Publish to PyPI on: release: - types: [created] # Run when you click “Publish release” + types: [created] # Run when you click "Publish release" workflow_dispatch: # ... or run it manually from the Actions tab permissions: @@ -38,7 +38,7 @@ jobs: name: dist path: dist/ -# Publish to PyPI (only if “dist/” succeeded) +# Publish to PyPI (only if "dist/" succeeded) pypi-publish: needs: release-build runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 949bc013..4aa5f0e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -122,9 +122,12 @@ repos: pytest-asyncio, pytest-mock, python-dotenv, + 'sentry-sdk[fastapi]', slowapi, starlette>=0.40.0, + strenum; python_version < '3.11', tiktoken>=0.7.0, + typing_extensions>= 4.0.0; python_version < '3.10', uvicorn>=0.11.7, ] @@ -144,9 +147,12 @@ repos: pytest-asyncio, pytest-mock, python-dotenv, + 'sentry-sdk[fastapi]', slowapi, starlette>=0.40.0, + strenum; python_version < '3.11', tiktoken>=0.7.0, + typing_extensions>= 4.0.0; python_version < '3.10', uvicorn>=0.11.7, ] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0cc77b6b..4ea7f24a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,7 +32,7 @@ If you ever get stuck, reach out on [Discord](https://discord.com/invite/zerRaGK ```bash python -m venv .venv source .venv/bin/activate - pip install -e ".[dev]" + 
pip install -e ".[dev,server]" pre-commit install ``` diff --git a/README.md b/README.md index dc89ca22..501753e2 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,14 @@ You can install it using `pip`: pip install gitingest ``` +or + +```bash +pip install gitingest[server] +``` + +to include server dependencies for self-hosting. + However, it might be a good idea to use `pipx` to install it. You can install `pipx` using your preferred package manager. diff --git a/pyproject.toml b/pyproject.toml index 3095339b..334140dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,17 +6,14 @@ readme = {file = "README.md", content-type = "text/markdown" } requires-python = ">= 3.8" dependencies = [ "click>=8.0.0", - "fastapi[standard]>=0.109.1", # Minimum safe release (https://osv.dev/vulnerability/PYSEC-2024-38) "httpx", "pathspec>=0.12.1", "pydantic", "python-dotenv", - "slowapi", "starlette>=0.40.0", # Minimum safe release (https://osv.dev/vulnerability/GHSA-f96h-pmfr-66vw) + "strenum; python_version < '3.11'", "tiktoken>=0.7.0", # Support for o200k_base encoding "typing_extensions>= 4.0.0; python_version < '3.10'", - "uvicorn>=0.11.7", # Minimum safe release (https://osv.dev/vulnerability/PYSEC-2020-150) - "prometheus-client", ] license = {file = "LICENSE"} @@ -46,6 +43,14 @@ dev = [ "pytest-mock", ] +server = [ + "fastapi[standard]>=0.109.1", # Minimum safe release (https://osv.dev/vulnerability/PYSEC-2024-38) + "prometheus-client", + "sentry-sdk[fastapi]", + "slowapi", + "uvicorn>=0.11.7", # Minimum safe release (https://osv.dev/vulnerability/PYSEC-2020-150) +] + [project.scripts] gitingest = "gitingest.__main__:main" diff --git a/src/gitingest/__init__.py b/src/gitingest/__init__.py index 0248ad0e..75f3ea41 100644 --- a/src/gitingest/__init__.py +++ b/src/gitingest/__init__.py @@ -1,8 +1,5 @@ """Gitingest: A package for ingesting data from Git repositories.""" -from gitingest.clone import clone_repo from gitingest.entrypoint import ingest, ingest_async -from 
gitingest.ingestion import ingest_query -from gitingest.query_parser import parse_query -__all__ = ["clone_repo", "ingest", "ingest_async", "ingest_query", "parse_query"] +__all__ = ["ingest", "ingest_async"] diff --git a/src/gitingest/clone.py b/src/gitingest/clone.py index 1f091486..6ccf599b 100644 --- a/src/gitingest/clone.py +++ b/src/gitingest/clone.py @@ -8,13 +8,15 @@ from gitingest.config import DEFAULT_TIMEOUT from gitingest.utils.git_utils import ( check_repo_exists, + checkout_partial_clone, create_git_auth_header, create_git_command, ensure_git_installed, is_github_host, + resolve_commit, run_command, ) -from gitingest.utils.os_utils import ensure_directory +from gitingest.utils.os_utils import ensure_directory_exists_or_create from gitingest.utils.timeout_wrapper import async_timeout if TYPE_CHECKING: @@ -45,71 +47,42 @@ async def clone_repo(config: CloneConfig, *, token: str | None = None) -> None: # Extract and validate query parameters url: str = config.url local_path: str = config.local_path - commit: str | None = config.commit - branch: str | None = config.branch - tag: str | None = config.tag partial_clone: bool = config.subpath != "/" - # Create parent directory if it doesn't exist - await ensure_directory(Path(local_path).parent) + await ensure_git_installed() + await ensure_directory_exists_or_create(Path(local_path).parent) - # Check if the repository exists if not await check_repo_exists(url, token=token): msg = "Repository not found. Make sure it is public or that you have provided a valid token." 
raise ValueError(msg) + commit = await resolve_commit(config, token=token) + clone_cmd = ["git"] if token and is_github_host(url): clone_cmd += ["-c", create_git_auth_header(token, url=url)] - clone_cmd += ["clone", "--single-branch"] - - if config.include_submodules: - clone_cmd += ["--recurse-submodules"] - + clone_cmd += ["clone", "--single-branch", "--no-checkout", "--depth=1"] if partial_clone: clone_cmd += ["--filter=blob:none", "--sparse"] - # Shallow clone unless a specific commit is requested - if not commit: - clone_cmd += ["--depth=1"] - - # Prefer tag over branch when both are provided - if tag: - clone_cmd += ["--branch", tag] - elif branch and branch.lower() not in ("main", "master"): - clone_cmd += ["--branch", branch] - clone_cmd += [url, local_path] # Clone the repository - await ensure_git_installed() await run_command(*clone_cmd) # Checkout the subpath if it is a partial clone if partial_clone: - await _checkout_partial_clone(config, token) + await checkout_partial_clone(config, token=token) - # Checkout the commit if it is provided - if commit: - checkout_cmd = create_git_command(["git"], local_path, url, token) - await run_command(*checkout_cmd, "checkout", commit) + git = create_git_command(["git"], local_path, url, token) + # Ensure the commit is locally available + await run_command(*git, "fetch", "--depth=1", "origin", commit) -async def _checkout_partial_clone(config: CloneConfig, token: str | None) -> None: - """Configure sparse-checkout for a partially cloned repository. + # Write the work-tree at that commit + await run_command(*git, "checkout", commit) - Parameters - ---------- - config : CloneConfig - The configuration for cloning the repository, including subpath and blob flag. - token : str | None - GitHub personal access token (PAT) for accessing private repositories. - - """ - subpath = config.subpath.lstrip("/") - if config.blob: - # Remove the file name from the subpath when ingesting from a file url (e.g. 
blob/branch/path/file.txt) - subpath = str(Path(subpath).parent.as_posix()) - checkout_cmd = create_git_command(["git"], config.local_path, config.url, token) - await run_command(*checkout_cmd, "sparse-checkout", "set", subpath) + # Update submodules + if config.include_submodules: + await run_command(*git, "submodule", "update", "--init", "--recursive", "--depth=1") diff --git a/src/gitingest/entrypoint.py b/src/gitingest/entrypoint.py index 026a27ec..321e1b3e 100644 --- a/src/gitingest/entrypoint.py +++ b/src/gitingest/entrypoint.py @@ -3,19 +3,30 @@ from __future__ import annotations import asyncio +import errno import shutil +import stat import sys import warnings from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator +from typing import TYPE_CHECKING, AsyncGenerator, Callable +from urllib.parse import urlparse from gitingest.clone import clone_repo from gitingest.config import MAX_FILE_SIZE from gitingest.ingestion import ingest_query -from gitingest.query_parser import IngestionQuery, parse_query +from gitingest.query_parser import parse_local_dir_path, parse_remote_repo from gitingest.utils.auth import resolve_token +from gitingest.utils.compat_func import removesuffix from gitingest.utils.ignore_patterns import load_ignore_patterns +from gitingest.utils.pattern_utils import process_patterns +from gitingest.utils.query_parser_utils import KNOWN_GIT_HOSTS + +if TYPE_CHECKING: + from types import TracebackType + + from gitingest.schemas import IngestionQuery async def ingest_async( @@ -74,13 +85,23 @@ async def ingest_async( """ token = resolve_token(token) - query: IngestionQuery = await parse_query( - source=source, - max_file_size=max_file_size, - from_web=False, + source = removesuffix(source.strip(), ".git") + + # Determine the parsing method based on the source type + if urlparse(source).scheme in ("https", "http") or any(h in source for h in KNOWN_GIT_HOSTS): + # We either have a full URL or a domain-less 
slug + query = await parse_remote_repo(source, token=token) + query.include_submodules = include_submodules + _override_branch_and_tag(query, branch=branch, tag=tag) + + else: + # Local path scenario + query = parse_local_dir_path(source) + + query.max_file_size = max_file_size + query.ignore_patterns, query.include_patterns = process_patterns( + exclude_patterns=exclude_patterns, include_patterns=include_patterns, - ignore_patterns=exclude_patterns, - token=token, ) if query.url: @@ -235,17 +256,49 @@ async def _clone_repo_if_remote(query: IngestionQuery, *, token: str | None) -> GitHub personal access token (PAT) for accessing private repositories. """ + kwargs = {} + if sys.version_info >= (3, 12): + kwargs["onexc"] = _handle_remove_readonly + else: + kwargs["onerror"] = _handle_remove_readonly + if query.url: clone_config = query.extract_clone_config() await clone_repo(clone_config, token=token) try: yield finally: - shutil.rmtree(query.local_path.parent) + shutil.rmtree(query.local_path.parent, **kwargs) else: yield +def _handle_remove_readonly( + func: Callable, + path: str, + exc_info: BaseException | tuple[type[BaseException], BaseException, TracebackType], +) -> None: + """Handle permission errors raised by ``shutil.rmtree()``. + + * Makes the target writable (removes the read-only attribute). + * Retries the original operation (``func``) once. + + """ + # 'onerror' passes a (type, value, tb) tuple; 'onexc' passes the exception + if isinstance(exc_info, tuple): # 'onerror' (Python <3.12) + exc: BaseException = exc_info[1] + else: # 'onexc' (Python 3.12+) + exc = exc_info + + # Handle only 'Permission denied' and 'Operation not permitted' + if not isinstance(exc, OSError) or exc.errno not in {errno.EACCES, errno.EPERM}: + raise exc + + # Make the target writable + Path(path).chmod(stat.S_IWRITE) + func(path) + + async def _write_output(tree: str, content: str, target: str | None) -> None: """Write combined output to ``target`` (``"-"`` ⇒ stdout). 
diff --git a/src/gitingest/ingestion.py b/src/gitingest/ingestion.py index 2990a875..489a41a4 100644 --- a/src/gitingest/ingestion.py +++ b/src/gitingest/ingestion.py @@ -11,7 +11,7 @@ from gitingest.utils.ingestion_utils import _should_exclude, _should_include if TYPE_CHECKING: - from gitingest.query_parser import IngestionQuery + from gitingest.schemas import IngestionQuery def ingest_query(query: IngestionQuery) -> tuple[str, str, str]: diff --git a/src/gitingest/output_formatter.py b/src/gitingest/output_formatter.py index 94bbee62..8a5b4135 100644 --- a/src/gitingest/output_formatter.py +++ b/src/gitingest/output_formatter.py @@ -10,7 +10,7 @@ from gitingest.utils.compat_func import readlink if TYPE_CHECKING: - from gitingest.query_parser import IngestionQuery + from gitingest.schemas import IngestionQuery _TOKEN_THRESHOLDS: list[tuple[int, str]] = [ (1_000_000, "M"), @@ -82,11 +82,14 @@ def _create_summary_prefix(query: IngestionQuery, *, single_file: bool = False) # Local scenario parts.append(f"Directory: {query.slug}") - if query.commit: - parts.append(f"Commit: {query.commit}") + if query.tag: + parts.append(f"Tag: {query.tag}") elif query.branch and query.branch not in ("main", "master"): parts.append(f"Branch: {query.branch}") + if query.commit: + parts.append(f"Commit: {query.commit}") + if query.subpath != "/" and not single_file: parts.append(f"Subpath: {query.subpath}") diff --git a/src/gitingest/query_parser.py b/src/gitingest/query_parser.py index 5fabb226..65b3f065 100644 --- a/src/gitingest/query_parser.py +++ b/src/gitingest/query_parser.py @@ -2,99 +2,25 @@ from __future__ import annotations -import re import uuid import warnings from pathlib import Path -from urllib.parse import unquote, urlparse +from typing import Literal from gitingest.config import TMP_BASE_PATH from gitingest.schemas import IngestionQuery -from gitingest.utils.exceptions import InvalidPatternError -from gitingest.utils.git_utils import check_repo_exists, 
fetch_remote_branches_or_tags -from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS +from gitingest.utils.git_utils import fetch_remote_branches_or_tags, resolve_commit from gitingest.utils.query_parser_utils import ( - KNOWN_GIT_HOSTS, + PathKind, + _fallback_to_root, _get_user_and_repo_from_path, _is_valid_git_commit_hash, - _is_valid_pattern, - _validate_host, - _validate_url_scheme, + _normalise_source, ) -async def parse_query( - source: str, - *, - max_file_size: int, - from_web: bool, - include_patterns: set[str] | str | None = None, - ignore_patterns: set[str] | str | None = None, - token: str | None = None, -) -> IngestionQuery: - """Parse the input source to extract details for the query and process the include and ignore patterns. - - Parameters - ---------- - source : str - The source URL or file path to parse. - max_file_size : int - The maximum file size in bytes to include. - from_web : bool - Flag indicating whether the source is a web URL. - include_patterns : set[str] | str | None - Patterns to include. Can be a set of strings or a single string. - ignore_patterns : set[str] | str | None - Patterns to ignore. Can be a set of strings or a single string. - token : str | None - GitHub personal access token (PAT) for accessing private repositories. - - Returns - ------- - IngestionQuery - A dataclass object containing the parsed details of the repository or file path. 
- - """ - # Determine the parsing method based on the source type - if from_web or urlparse(source).scheme in ("https", "http") or any(h in source for h in KNOWN_GIT_HOSTS): - # We either have a full URL or a domain-less slug - query = await _parse_remote_repo(source, token=token) - else: - # Local path scenario - query = _parse_local_dir_path(source) - - # Combine default ignore patterns + custom patterns - ignore_patterns_set = DEFAULT_IGNORE_PATTERNS.copy() - if ignore_patterns: - ignore_patterns_set.update(_parse_patterns(ignore_patterns)) - - # Process include patterns and override ignore patterns accordingly - if include_patterns: - parsed_include = _parse_patterns(include_patterns) - # Override ignore patterns with include patterns - ignore_patterns_set = set(ignore_patterns_set) - set(parsed_include) - else: - parsed_include = None - - return IngestionQuery( - user_name=query.user_name, - repo_name=query.repo_name, - url=query.url, - subpath=query.subpath, - local_path=query.local_path, - slug=query.slug, - id=query.id, - type=query.type, - branch=query.branch, - commit=query.commit, - max_file_size=max_file_size, - ignore_patterns=ignore_patterns_set, - include_patterns=parsed_include, - ) - - -async def _parse_remote_repo(source: str, token: str | None = None) -> IngestionQuery: - """Parse a repository URL into a structured query dictionary. +async def parse_remote_repo(source: str, token: str | None = None) -> IngestionQuery: + """Parse a repository URL and return an ``IngestionQuery`` object. If source is: - A fully qualified URL ('https://gitlab.com/...'), parse & verify that domain @@ -114,116 +40,117 @@ async def _parse_remote_repo(source: str, token: str | None = None) -> Ingestion A dictionary containing the parsed details of the repository. 
""" - source = unquote(source) - - # Attempt to parse - parsed_url = urlparse(source) - - if parsed_url.scheme: - _validate_url_scheme(parsed_url.scheme) - _validate_host(parsed_url.netloc.lower()) - - else: # Will be of the form 'host/user/repo' or 'user/repo' - tmp_host = source.split("/")[0].lower() - if "." in tmp_host: - _validate_host(tmp_host) - else: - # No scheme, no domain => user typed "user/repo", so we'll guess the domain. - host = await try_domains_for_user_and_repo(*_get_user_and_repo_from_path(source), token=token) - source = f"{host}/{source}" - - source = "https://" + source - parsed_url = urlparse(source) - - host = parsed_url.netloc.lower() - user_name, repo_name = _get_user_and_repo_from_path(parsed_url.path) + parsed_url = await _normalise_source(source, token=token) + host = parsed_url.netloc + user, repo = _get_user_and_repo_from_path(parsed_url.path) _id = str(uuid.uuid4()) - slug = f"{user_name}-{repo_name}" + slug = f"{user}-{repo}" local_path = TMP_BASE_PATH / _id / slug - url = f"https://{host}/{user_name}/{repo_name}" + url = f"https://{host}/{user}/{repo}" - parsed = IngestionQuery( - user_name=user_name, - repo_name=repo_name, + query = IngestionQuery( + host=host, + user_name=user, + repo_name=repo, url=url, local_path=local_path, slug=slug, id=_id, ) - remaining_parts = parsed_url.path.strip("/").split("/")[2:] + path_parts = parsed_url.path.strip("/").split("/")[2:] - if not remaining_parts: - return parsed + # main branch + if not path_parts: + return await _fallback_to_root(query, token=token) - possible_type = remaining_parts.pop(0) # e.g. 
'issues', 'pull', 'tree', 'blob' + kind = PathKind(path_parts.pop(0)) # may raise ValueError + query.type = kind - # If no extra path parts, just return - if not remaining_parts: - return parsed - - # If this is an issues page or pull requests, return early without processing subpath # TODO: Handle issues and pull requests - if remaining_parts and possible_type in {"issues", "pull"}: + if query.type in {PathKind.ISSUES, PathKind.PULL}: msg = f"Warning: Issues and pull requests are not yet supported: {url}. Returning repository root." - warnings.warn(msg, RuntimeWarning, stacklevel=2) - return parsed + return await _fallback_to_root(query, token=token, warn_msg=msg) - if possible_type not in {"tree", "blob"}: - # TODO: Handle other types - msg = f"Warning: Type '{possible_type}' is not yet supported: {url}. Returning repository root." - warnings.warn(msg, RuntimeWarning, stacklevel=2) - return parsed + # If no extra path parts, just return + if not path_parts: + msg = f"Warning: No extra path parts: {url}. Returning repository root." + return await _fallback_to_root(query, token=token, warn_msg=msg) - parsed.type = possible_type # 'tree' or 'blob' + if query.type not in {PathKind.TREE, PathKind.BLOB}: + # TODO: Handle other types + msg = f"Warning: Type '{query.type}' is not yet supported: {url}. Returning repository root." 
+ return await _fallback_to_root(query, token=token, warn_msg=msg) # Commit, branch, or tag - commit_or_branch_or_tag = remaining_parts[0] - if _is_valid_git_commit_hash(commit_or_branch_or_tag): # Commit - parsed.commit = commit_or_branch_or_tag - remaining_parts.pop(0) # Consume the commit hash + ref = path_parts[0] + + if _is_valid_git_commit_hash(ref): # Commit + query.commit = ref + path_parts.pop(0) # Consume the commit hash else: # Branch or tag # Try to resolve a tag - parsed.tag = await _configure_branch_or_tag( - remaining_parts, + query.tag = await _configure_branch_or_tag( + path_parts, url=url, ref_type="tags", token=token, ) # If no tag found, try to resolve a branch - if not parsed.tag: - parsed.branch = await _configure_branch_or_tag( - remaining_parts, + if not query.tag: + query.branch = await _configure_branch_or_tag( + path_parts, url=url, ref_type="branches", token=token, ) # Only configure subpath if we have identified a commit, branch, or tag. - if remaining_parts and (parsed.commit or parsed.branch or parsed.tag): - parsed.subpath += "/".join(remaining_parts) + if path_parts and (query.commit or query.branch or query.tag): + query.subpath += "/".join(path_parts) + + query.commit = await resolve_commit(query.extract_clone_config(), token=token) + + return query - return parsed + +def parse_local_dir_path(path_str: str) -> IngestionQuery: + """Parse the given file path into a structured query dictionary. + + Parameters + ---------- + path_str : str + The file path to parse. + + Returns + ------- + IngestionQuery + A dictionary containing the parsed details of the file path. + + """ + path_obj = Path(path_str).resolve() + slug = path_obj.name if path_str == "." 
else path_str.strip("/") + return IngestionQuery(local_path=path_obj, slug=slug, id=str(uuid.uuid4())) async def _configure_branch_or_tag( - remaining_parts: list[str], + path_parts: list[str], *, url: str, - ref_type: str, + ref_type: Literal["branches", "tags"], token: str | None = None, ) -> str | None: """Configure the branch or tag based on the remaining parts of the URL. Parameters ---------- - remaining_parts : list[str] - The remaining parts of the URL path. + path_parts : list[str] + The path parts of the URL. url : str The URL of the repository. - ref_type : str + ref_type : Literal["branches", "tags"] The type of reference to configure. Can be "branches" or "tags". token : str | None GitHub personal access token (PAT) for accessing private repositories. @@ -233,16 +160,7 @@ async def _configure_branch_or_tag( str | None The branch or tag name if found, otherwise ``None``. - Raises - ------ - ValueError - If the ``ref_type`` parameter is not "branches" or "tags". - """ - if ref_type not in ("branches", "tags"): - msg = f"Invalid reference type: {ref_type}" - raise ValueError(msg) - _ref_type = "tags" if ref_type == "tags" else "branches" try: @@ -252,113 +170,18 @@ async def _configure_branch_or_tag( # If remote discovery fails, we optimistically treat the first path segment as the branch/tag. 
msg = f"Warning: Failed to fetch {_ref_type}: {exc}" warnings.warn(msg, RuntimeWarning, stacklevel=2) - return remaining_parts.pop(0) if remaining_parts else None + return path_parts.pop(0) if path_parts else None # Iterate over the path components and try to find a matching branch/tag candidate_parts: list[str] = [] - for part in remaining_parts: + for part in path_parts: candidate_parts.append(part) candidate_name = "/".join(candidate_parts) if candidate_name in branches_or_tags: # We found a match — now consume exactly the parts that form the branch/tag - del remaining_parts[: len(candidate_parts)] + del path_parts[: len(candidate_parts)] return candidate_name - # No match found; leave remaining_parts intact + # No match found; leave path_parts intact return None - - -def _parse_patterns(pattern: set[str] | str) -> set[str]: - """Parse and validate file/directory patterns for inclusion or exclusion. - - Takes either a single pattern string or set of pattern strings and processes them into a normalized list. - Patterns are split on commas and spaces, validated for allowed characters, and normalized. - - Parameters - ---------- - pattern : set[str] | str - Pattern(s) to parse - either a single string or set of strings - - Returns - ------- - set[str] - A set of normalized patterns. - - Raises - ------ - InvalidPatternError - If any pattern contains invalid characters. Only alphanumeric characters, - dash (-), underscore (_), dot (.), forward slash (/), plus (+), and - asterisk (*) are allowed. 
- - """ - patterns = pattern if isinstance(pattern, set) else {pattern} - - parsed_patterns: set[str] = set() - for p in patterns: - parsed_patterns = parsed_patterns.union(set(re.split(",| ", p))) - - # Remove empty string if present - parsed_patterns = parsed_patterns - {""} - - # Normalize Windows paths to Unix-style paths - parsed_patterns = {p.replace("\\", "/") for p in parsed_patterns} - - # Validate and normalize each pattern - for p in parsed_patterns: - if not _is_valid_pattern(p): - raise InvalidPatternError(p) - - return parsed_patterns - - -def _parse_local_dir_path(path_str: str) -> IngestionQuery: - """Parse the given file path into a structured query dictionary. - - Parameters - ---------- - path_str : str - The file path to parse. - - Returns - ------- - IngestionQuery - A dictionary containing the parsed details of the file path. - - """ - path_obj = Path(path_str).resolve() - slug = path_obj.name if path_str == "." else path_str.strip("/") - return IngestionQuery(local_path=path_obj, slug=slug, id=str(uuid.uuid4())) - - -async def try_domains_for_user_and_repo(user_name: str, repo_name: str, token: str | None = None) -> str: - """Attempt to find a valid repository host for the given ``user_name`` and ``repo_name``. - - Parameters - ---------- - user_name : str - The username or owner of the repository. - repo_name : str - The name of the repository. - token : str | None - GitHub personal access token (PAT) for accessing private repositories. - - Returns - ------- - str - The domain of the valid repository host. - - Raises - ------ - ValueError - If no valid repository host is found for the given ``user_name`` and ``repo_name``. - - """ - for domain in KNOWN_GIT_HOSTS: - candidate = f"https://{domain}/{user_name}/{repo_name}" - if await check_repo_exists(candidate, token=token if domain.startswith("github.") else None): - return domain - - msg = f"Could not find a valid repository host for '{user_name}/{repo_name}'." 
- raise ValueError(msg) diff --git a/src/gitingest/schemas/__init__.py b/src/gitingest/schemas/__init__.py index efe2dd70..db5cb12f 100644 --- a/src/gitingest/schemas/__init__.py +++ b/src/gitingest/schemas/__init__.py @@ -1,6 +1,7 @@ """Module containing the schemas for the Gitingest package.""" +from gitingest.schemas.cloning import CloneConfig from gitingest.schemas.filesystem import FileSystemNode, FileSystemNodeType, FileSystemStats -from gitingest.schemas.ingestion import CloneConfig, IngestionQuery +from gitingest.schemas.ingestion import IngestionQuery __all__ = ["CloneConfig", "FileSystemNode", "FileSystemNodeType", "FileSystemStats", "IngestionQuery"] diff --git a/src/gitingest/schemas/cloning.py b/src/gitingest/schemas/cloning.py new file mode 100644 index 00000000..085afb3f --- /dev/null +++ b/src/gitingest/schemas/cloning.py @@ -0,0 +1,42 @@ +"""Schema for the cloning process.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class CloneConfig(BaseModel): # pylint: disable=too-many-instance-attributes + """Configuration for cloning a Git repository. + + This model holds the necessary parameters for cloning a repository to a local path, including + the repository's URL, the target local path, and optional parameters for a specific commit, branch, or tag. + + Attributes + ---------- + url : str + The URL of the Git repository to clone. + local_path : str + The local directory where the repository will be cloned. + commit : str | None + The specific commit hash to check out after cloning. + branch : str | None + The branch to clone. + tag : str | None + The tag to clone. + subpath : str + The subpath to clone from the repository (default: ``"/"``). + blob : bool + Whether the repository is a blob (default: ``False``). + include_submodules : bool + Whether to clone submodules (default: ``False``). 
+ + """ + + url: str + local_path: str + commit: str | None = None + branch: str | None = None + tag: str | None = None + subpath: str = Field(default="/") + blob: bool = Field(default=False) + include_submodules: bool = Field(default=False) diff --git a/src/gitingest/schemas/filesystem.py b/src/gitingest/schemas/filesystem.py index b5669f18..cc66e7b1 100644 --- a/src/gitingest/schemas/filesystem.py +++ b/src/gitingest/schemas/filesystem.py @@ -1,4 +1,4 @@ -"""Define the schema for the filesystem representation.""" +"""Schema for the filesystem representation.""" from __future__ import annotations diff --git a/src/gitingest/schemas/ingestion.py b/src/gitingest/schemas/ingestion.py index c40e11d6..97e98804 100644 --- a/src/gitingest/schemas/ingestion.py +++ b/src/gitingest/schemas/ingestion.py @@ -2,50 +2,12 @@ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path # noqa: TC003 (typing-only-standard-library-import) needed for type checking (pydantic) from pydantic import BaseModel, Field from gitingest.config import MAX_FILE_SIZE - - -@dataclass -class CloneConfig: # pylint: disable=too-many-instance-attributes - """Configuration for cloning a Git repository. - - This class holds the necessary parameters for cloning a repository to a local path, including - the repository's URL, the target local path, and optional parameters for a specific commit or branch. - - Attributes - ---------- - url : str - The URL of the Git repository to clone. - local_path : str - The local directory where the repository will be cloned. - commit : str | None - The specific commit hash to check out after cloning. - branch : str | None - The branch to clone. - tag: str | None - The tag to clone. - subpath : str - The subpath to clone from the repository (default: ``"/"``). - blob: bool - Whether the repository is a blob (default: ``False``). - include_submodules: bool - Whether to clone submodules (default: ``False``). 
- - """ - - url: str - local_path: str - commit: str | None = None - branch: str | None = None - tag: str | None = None - subpath: str = "/" - blob: bool = False - include_submodules: bool = False +from gitingest.schemas.cloning import CloneConfig class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes @@ -53,6 +15,8 @@ class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes Attributes ---------- + host : str | None + The host of the repository. user_name : str | None The username or owner of the repository. repo_name : str | None @@ -73,7 +37,7 @@ class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes The branch of the repository. commit : str | None The commit of the repository. - tag: str | None + tag : str | None The tag of the repository. max_file_size : int The maximum file size to ingest (default: 10 MB). @@ -86,21 +50,22 @@ class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes """ + host: str | None = None user_name: str | None = None repo_name: str | None = None local_path: Path url: str | None = None slug: str id: str - subpath: str = "/" + subpath: str = Field(default="/") type: str | None = None branch: str | None = None commit: str | None = None tag: str | None = None max_file_size: int = Field(default=MAX_FILE_SIZE) - ignore_patterns: set[str] = set() # TODO: ignore_patterns and include_patterns have the same type + ignore_patterns: set[str] = Field(default_factory=set) # TODO: same type for ignore_* and include_* patterns include_patterns: set[str] | None = None - include_submodules: bool = False + include_submodules: bool = Field(default=False) def extract_clone_config(self) -> CloneConfig: """Extract the relevant fields for the CloneConfig object. 
@@ -130,16 +95,3 @@ def extract_clone_config(self) -> CloneConfig: blob=self.type == "blob", include_submodules=self.include_submodules, ) - - def ensure_url(self) -> None: - """Raise if the parsed query has no URL (invalid user input). - - Raises - ------ - ValueError - If the parsed query has no URL (invalid user input). - - """ - if not self.url: - msg = "The 'url' parameter is required." - raise ValueError(msg) diff --git a/src/gitingest/utils/compat_typing.py b/src/gitingest/utils/compat_typing.py index a21f71ee..059db0a1 100644 --- a/src/gitingest/utils/compat_typing.py +++ b/src/gitingest/utils/compat_typing.py @@ -1,13 +1,18 @@ """Compatibility layer for typing.""" +try: + from enum import StrEnum # type: ignore[attr-defined] # Py ≥ 3.11 +except ImportError: + from strenum import StrEnum # type: ignore[import-untyped] # Py ≤ 3.10 + try: from typing import ParamSpec, TypeAlias # type: ignore[attr-defined] # Py ≥ 3.10 except ImportError: - from typing_extensions import ParamSpec, TypeAlias # type: ignore[attr-defined] # Py 3.8 / 3.9 + from typing_extensions import ParamSpec, TypeAlias # type: ignore[attr-defined] # Py ≤ 3.9 try: from typing import Annotated # type: ignore[attr-defined] # Py ≥ 3.9 except ImportError: - from typing_extensions import Annotated # type: ignore[attr-defined] # Py 3.8 + from typing_extensions import Annotated # type: ignore[attr-defined] # Py ≤ 3.8 -__all__ = ["Annotated", "ParamSpec", "TypeAlias"] +__all__ = ["Annotated", "ParamSpec", "StrEnum", "TypeAlias"] diff --git a/src/gitingest/utils/exceptions.py b/src/gitingest/utils/exceptions.py index c96cfd64..b7d23e35 100644 --- a/src/gitingest/utils/exceptions.py +++ b/src/gitingest/utils/exceptions.py @@ -1,28 +1,6 @@ """Custom exceptions for the Gitingest package.""" -class InvalidPatternError(ValueError): - """Exception raised when a pattern contains invalid characters. 
- - This exception is used to signal that a pattern provided for some operation - contains characters that are not allowed. The valid characters for the pattern - include alphanumeric characters, dash (-), underscore (_), dot (.), forward slash (/), - plus (+), and asterisk (*). - - Parameters - ---------- - pattern : str - The invalid pattern that caused the error. - - """ - - def __init__(self, pattern: str) -> None: - super().__init__( - f"Pattern '{pattern}' contains invalid characters. Only alphanumeric characters, dash (-), " - "underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*) are allowed.", - ) - - class AsyncTimeoutError(Exception): """Exception raised when an async operation exceeds its timeout limit. diff --git a/src/gitingest/utils/git_utils.py b/src/gitingest/utils/git_utils.py index f4215ca4..a094e944 100644 --- a/src/gitingest/utils/git_utils.py +++ b/src/gitingest/utils/git_utils.py @@ -6,7 +6,8 @@ import base64 import re import sys -from typing import Final +from pathlib import Path +from typing import TYPE_CHECKING, Final, Iterable from urllib.parse import urlparse import httpx @@ -16,6 +17,9 @@ from gitingest.utils.exceptions import InvalidGitHubTokenError from server.server_utils import Colors +if TYPE_CHECKING: + from gitingest.schemas import CloneConfig + # GitHub Personal-Access tokens (classic + fine-grained). 
# - ghp_ / gho_ / ghu_ / ghs_ / ghr_ → 36 alphanumerics # - github_pat_ → 22 alphanumerics + "_" + 59 alphanumerics @@ -237,7 +241,6 @@ async def fetch_remote_branches_or_tags(url: str, *, ref_type: str, token: str | await ensure_git_installed() stdout, _ = await run_command(*cmd) - # For each line in the output: # - Skip empty lines and lines that don't contain "refs/{to_fetch}/" # - Extract the branch or tag name after "refs/{to_fetch}/" @@ -321,3 +324,126 @@ def validate_github_token(token: str) -> None: """ if not re.fullmatch(_GITHUB_PAT_PATTERN, token): raise InvalidGitHubTokenError + + +async def checkout_partial_clone(config: CloneConfig, token: str | None) -> None: + """Configure sparse-checkout for a partially cloned repository. + + Parameters + ---------- + config : CloneConfig + The configuration for cloning the repository, including subpath and blob flag. + token : str | None + GitHub personal access token (PAT) for accessing private repositories. + + """ + subpath = config.subpath.lstrip("/") + if config.blob: + # Remove the file name from the subpath when ingesting from a file url (e.g. blob/branch/path/file.txt) + subpath = str(Path(subpath).parent.as_posix()) + checkout_cmd = create_git_command(["git"], config.local_path, config.url, token) + await run_command(*checkout_cmd, "sparse-checkout", "set", subpath) + + +async def resolve_commit(config: CloneConfig, token: str | None) -> str: + """Resolve the commit to use for the clone. + + Parameters + ---------- + config : CloneConfig + The configuration for cloning the repository. + token : str | None + GitHub personal access token (PAT) for accessing private repositories. + + Returns + ------- + str + The commit SHA. 
+ + """ + if config.commit: + commit = config.commit + elif config.tag: + commit = await _resolve_ref_to_sha(config.url, pattern=f"refs/tags/{config.tag}*", token=token) + elif config.branch: + commit = await _resolve_ref_to_sha(config.url, pattern=f"refs/heads/{config.branch}", token=token) + else: + commit = await _resolve_ref_to_sha(config.url, pattern="HEAD", token=token) + return commit + + +async def _resolve_ref_to_sha(url: str, pattern: str, token: str | None = None) -> str: + """Return the commit SHA that ``pattern`` points to in ``url``. + + * Branch → first line from ``git ls-remote``. + * Tag → if annotated, prefer the peeled ``^{}`` line (commit). + + Parameters + ---------- + url : str + The URL of the remote repository. + pattern : str + The pattern to use to resolve the commit SHA. + token : str | None + GitHub personal access token (PAT) for accessing private repositories. + + Returns + ------- + str + The commit SHA. + + Raises + ------ + ValueError + If the ref does not exist in the remote repository. + + """ + # Build: git [-c http.HOST/.extraheader=Auth...] ls-remote URL PATTERN + cmd: list[str] = ["git"] + if token and is_github_host(url): + cmd += ["-c", create_git_auth_header(token, url=url)] + + cmd += ["ls-remote", url, pattern] + stdout, _ = await run_command(*cmd) + lines = stdout.decode().splitlines() + sha = _pick_commit_sha(lines) + if not sha: + msg = f"{pattern!r} not found in {url}" + raise ValueError(msg) + + return sha + + +def _pick_commit_sha(lines: Iterable[str]) -> str | None: + """Return a commit SHA from ``git ls-remote`` output. + + • Annotated tag → prefer the peeled line (SHA refs/tags/x^{}) + • Branch / lightweight tag → first non-peeled line + + + Parameters + ---------- + lines : Iterable[str] + The lines of a ``git ls-remote`` output. + + Returns + ------- + str | None + The commit SHA, or ``None`` if no commit SHA is found.
+ + """ + first_non_peeled: str | None = None + + for ln in lines: + if not ln.strip(): + continue + + sha, ref = ln.split(maxsplit=1) + + if ref.endswith("^{}"): # peeled commit of annotated tag + return sha # ← best match, done + + if first_non_peeled is None: # remember the first ordinary line + first_non_peeled = sha + + return first_non_peeled # branch or lightweight tag (or None) diff --git a/src/gitingest/utils/ingestion_utils.py b/src/gitingest/utils/ingestion_utils.py index 21a03f22..8795b66c 100644 --- a/src/gitingest/utils/ingestion_utils.py +++ b/src/gitingest/utils/ingestion_utils.py @@ -59,7 +59,7 @@ def _should_exclude(path: Path, base_path: Path, ignore_patterns: set[str]) -> b """ rel_path = _relative_or_none(path, base_path) - if rel_path is None: # outside repo → already “excluded” + if rel_path is None: # outside repo → already "excluded" return True spec = PathSpec.from_lines("gitwildmatch", ignore_patterns) diff --git a/src/gitingest/utils/os_utils.py b/src/gitingest/utils/os_utils.py index d90dddd2..e9c3b3e4 100644 --- a/src/gitingest/utils/os_utils.py +++ b/src/gitingest/utils/os_utils.py @@ -3,7 +3,7 @@ from pathlib import Path -async def ensure_directory(path: Path) -> None: +async def ensure_directory_exists_or_create(path: Path) -> None: """Ensure the directory exists, creating it if necessary. Parameters diff --git a/src/gitingest/utils/path_utils.py b/src/gitingest/utils/path_utils.py deleted file mode 100644 index 55ed3d48..00000000 --- a/src/gitingest/utils/path_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Utility functions for working with file paths.""" - -import platform -from pathlib import Path - - -def _is_safe_symlink(symlink_path: Path, base_path: Path) -> bool: - """Return ``True`` if ``symlink_path`` resolves inside ``base_path``. - - Parameters - ---------- - symlink_path : Path - Symlink whose target should be validated. - base_path : Path - Directory that the symlink target must remain within. 
- - Returns - ------- - bool - Whether the symlink is “safe” (i.e., does not escape ``base_path``). - - """ - # On Windows a non-symlink is immediately unsafe - if platform.system() == "Windows" and not symlink_path.is_symlink(): - return False - - try: - target_path = symlink_path.resolve() - base_resolved = base_path.resolve() - except (OSError, ValueError): - # Any resolution error → treat as unsafe - return False - - return base_resolved in target_path.parents or target_path == base_resolved diff --git a/src/gitingest/utils/pattern_utils.py b/src/gitingest/utils/pattern_utils.py new file mode 100644 index 00000000..0fdd2679 --- /dev/null +++ b/src/gitingest/utils/pattern_utils.py @@ -0,0 +1,73 @@ +"""Pattern utilities for the Gitingest package.""" + +from __future__ import annotations + +import re +from typing import Iterable + +from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS + +_PATTERN_SPLIT_RE = re.compile(r"[,\s]+") + + +def process_patterns( + exclude_patterns: str | set[str] | None = None, + include_patterns: str | set[str] | None = None, +) -> tuple[set[str], set[str] | None]: + """Process include and exclude patterns. + + Parameters + ---------- + exclude_patterns : str | set[str] | None + Exclude patterns to process. + include_patterns : str | set[str] | None + Include patterns to process. + + Returns + ------- + tuple[set[str], set[str] | None] + A tuple containing the processed ignore patterns and include patterns. 
+ + """ + # Combine default ignore patterns + custom patterns + ignore_patterns_set = DEFAULT_IGNORE_PATTERNS.copy() + if exclude_patterns: + ignore_patterns_set.update(_parse_patterns(exclude_patterns)) + + # Process include patterns and override ignore patterns accordingly + if include_patterns: + parsed_include = _parse_patterns(include_patterns) + # Override ignore patterns with include patterns + ignore_patterns_set = set(ignore_patterns_set) - set(parsed_include) + else: + parsed_include = None + + return ignore_patterns_set, parsed_include + + +def _parse_patterns(patterns: str | Iterable[str]) -> set[str]: + """Normalize a collection of file or directory patterns. + + Parameters + ---------- + patterns : str | Iterable[str] + One pattern string or an iterable of pattern strings. Each pattern may contain multiple comma- or + whitespace-separated sub-patterns, e.g. "src/*, tests *.md". + + Returns + ------- + set[str] + Normalized patterns with Windows back-slashes converted to forward-slashes and duplicates removed. 
+ + """ + # Treat a lone string as the iterable [string] + if isinstance(patterns, str): + patterns = [patterns] + + # Flatten, split on commas/whitespace, strip empties, normalise slashes + return { + part.replace("\\", "/") + for pat in patterns + for part in _PATTERN_SPLIT_RE.split(pat.strip()) + if part # discard empty tokens + } diff --git a/src/gitingest/utils/query_parser_utils.py b/src/gitingest/utils/query_parser_utils.py index 4bde02cc..41dc7ada 100644 --- a/src/gitingest/utils/query_parser_utils.py +++ b/src/gitingest/utils/query_parser_utils.py @@ -3,10 +3,19 @@ from __future__ import annotations import string +import warnings +from typing import TYPE_CHECKING, cast +from urllib.parse import ParseResult, unquote, urlparse -HEX_DIGITS: set[str] = set(string.hexdigits) +from gitingest.utils.compat_typing import StrEnum +from gitingest.utils.git_utils import _resolve_ref_to_sha, check_repo_exists + +if TYPE_CHECKING: + from gitingest.schemas import IngestionQuery +HEX_DIGITS: set[str] = set(string.hexdigits) + KNOWN_GIT_HOSTS: list[str] = [ "github.com", "gitlab.com", @@ -17,46 +26,127 @@ ] -def _is_valid_git_commit_hash(commit: str) -> bool: - """Validate if the provided string is a valid Git commit hash. +class PathKind(StrEnum): + """Path kind enum.""" + + TREE = "tree" + BLOB = "blob" + ISSUES = "issues" + PULL = "pull" - This function checks if the commit hash is a 40-character string consisting only - of hexadecimal digits, which is the standard format for Git commit hashes. + +async def _fallback_to_root(query: IngestionQuery, token: str | None, warn_msg: str | None = None) -> IngestionQuery: + """Fallback to the root of the repository if no extra path parts are provided. Parameters ---------- - commit : str - The string to validate as a Git commit hash. + query : IngestionQuery + The query to fallback to the root of the repository. + token : str | None + The token to use to access the repository. + warn_msg : str | None + The message to warn. 
Returns ------- - bool - ``True`` if the string is a valid 40-character Git commit hash, otherwise ``False``. + IngestionQuery + The query with the fallback to the root of the repository. """ - sha_hex_length = 40 - return len(commit) == sha_hex_length and all(c in HEX_DIGITS for c in commit) + url = cast("str", query.url) + query.commit = await _resolve_ref_to_sha(url, pattern="HEAD", token=token) + if warn_msg: + warnings.warn(warn_msg, RuntimeWarning, stacklevel=3) + return query + + +async def _normalise_source(raw: str, token: str | None) -> ParseResult: + """Return a fully-qualified ParseResult or raise. + + Parameters + ---------- + raw : str + The raw URL to parse. + token : str | None + The token to use to access the repository. + + Returns + ------- + ParseResult + The parsed URL. + + """ + raw = unquote(raw) + parsed = urlparse(raw) + + if parsed.scheme: + _validate_url_scheme(parsed.scheme) + _validate_host(parsed.netloc) + return parsed + # no scheme ('host/user/repo' or 'user/repo') + host = raw.split("/", 1)[0].lower() + if "." in host: + _validate_host(host) + return urlparse(f"https://{raw}") -def _is_valid_pattern(pattern: str) -> bool: - """Validate if the given pattern contains only valid characters. + # "user/repo" slug + host = await _try_domains_for_user_and_repo(*_get_user_and_repo_from_path(raw), token=token) - This function checks if the pattern contains only alphanumeric characters or one - of the following allowed characters: dash ('-'), underscore ('_'), dot ('.'), - forward slash ('/'), plus ('+'), asterisk ('*'), or the at sign ('@'). + return urlparse(f"https://{host}/{raw}") + + +async def _try_domains_for_user_and_repo(user_name: str, repo_name: str, token: str | None = None) -> str: + """Attempt to find a valid repository host for the given ``user_name`` and ``repo_name``. Parameters ---------- - pattern : str - The pattern to validate. + user_name : str + The username or owner of the repository. 
+ repo_name : str + The name of the repository. + token : str | None + GitHub personal access token (PAT) for accessing private repositories. + + Returns + ------- + str + The domain of the valid repository host. + + Raises + ------ + ValueError + If no valid repository host is found for the given ``user_name`` and ``repo_name``. + + """ + for domain in KNOWN_GIT_HOSTS: + candidate = f"https://{domain}/{user_name}/{repo_name}" + if await check_repo_exists(candidate, token=token if domain.startswith("github.") else None): + return domain + + msg = f"Could not find a valid repository host for '{user_name}/{repo_name}'." + raise ValueError(msg) + + +def _is_valid_git_commit_hash(commit: str) -> bool: + """Validate if the provided string is a valid Git commit hash. + + This function checks if the commit hash is a 40-character string consisting only + of hexadecimal digits, which is the standard format for Git commit hashes. + + Parameters + ---------- + commit : str + The string to validate as a Git commit hash. Returns ------- bool - ``True`` if the pattern is valid, otherwise ``False``. + ``True`` if the string is a valid 40-character Git commit hash, otherwise ``False``. 
""" - return all(c.isalnum() or c in "-_./+*@" for c in pattern) + sha_hex_length = 40 + return len(commit) == sha_hex_length and all(c in HEX_DIGITS for c in commit) def _validate_host(host: str) -> None: diff --git a/src/server/main.py b/src/server/main.py index 24cc6b7e..2a07773a 100644 --- a/src/server/main.py +++ b/src/server/main.py @@ -31,7 +31,8 @@ # Configure Sentry options from environment variables traces_sample_rate = float(os.getenv("GITINGEST_SENTRY_TRACES_SAMPLE_RATE", "1.0")) profile_session_sample_rate = float(os.getenv("GITINGEST_SENTRY_PROFILE_SESSION_SAMPLE_RATE", "1.0")) - profile_lifecycle = os.getenv("GITINGEST_SENTRY_PROFILE_LIFECYCLE", "trace") + profile_lifecycle_raw = os.getenv("GITINGEST_SENTRY_PROFILE_LIFECYCLE", "trace") + profile_lifecycle = profile_lifecycle_raw if profile_lifecycle_raw in ("manual", "trace") else "trace" send_default_pii = os.getenv("GITINGEST_SENTRY_SEND_DEFAULT_PII", "true").lower() == "true" sentry_environment = os.getenv("GITINGEST_SENTRY_ENVIRONMENT", "") diff --git a/src/server/models.py b/src/server/models.py index a6e71edc..1ed95710 100644 --- a/src/server/models.py +++ b/src/server/models.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, Field, field_validator +from gitingest.utils.compat_func import removesuffix + # needed for type checking (pydantic) from server.form_types import IntForm, OptStrForm, StrForm # noqa: TC001 (typing-only-first-party-import) @@ -45,16 +47,16 @@ class IngestRequest(BaseModel): @field_validator("input_text") @classmethod def validate_input_text(cls, v: str) -> str: - """Validate that input_text is not empty.""" + """Validate that ``input_text`` is not empty.""" if not v.strip(): err = "input_text cannot be empty" raise ValueError(err) - return v.strip() + return removesuffix(v.strip(), ".git") @field_validator("pattern") @classmethod def validate_pattern(cls, v: str) -> str: - """Validate pattern field.""" + """Validate ``pattern`` field.""" return v.strip() diff --git 
a/src/server/query_processor.py b/src/server/query_processor.py index 8513426b..a7b60f61 100644 --- a/src/server/query_processor.py +++ b/src/server/query_processor.py @@ -7,9 +7,10 @@ from gitingest.clone import clone_repo from gitingest.ingestion import ingest_query -from gitingest.query_parser import IngestionQuery, parse_query +from gitingest.query_parser import parse_remote_repo from gitingest.utils.git_utils import validate_github_token -from server.models import IngestErrorResponse, IngestResponse, IngestSuccessResponse +from gitingest.utils.pattern_utils import process_patterns +from server.models import IngestErrorResponse, IngestResponse, IngestSuccessResponse, PatternType from server.server_config import MAX_DISPLAY_SIZE from server.server_utils import Colors, log_slider_to_size @@ -17,8 +18,8 @@ async def process_query( input_text: str, slider_position: int, - pattern_type: str = "exclude", - pattern: str = "", + pattern_type: PatternType, + pattern: str, token: str | None = None, ) -> IngestResponse: """Process a query by parsing input, cloning a repository, and generating a summary. @@ -32,8 +33,8 @@ async def process_query( Input text provided by the user, typically a Git repository URL or slug. slider_position : int Position of the slider, representing the maximum file size in the query. - pattern_type : str - Type of pattern to use (either "include" or "exclude") (default: ``"exclude"``). + pattern_type : PatternType + Type of pattern to use (either "include" or "exclude") pattern : str Pattern to include or exclude in the query, depending on the pattern type. token : str | None @@ -44,61 +45,42 @@ async def process_query( IngestResponse A union type, corresponding to IngestErrorResponse or IngestSuccessResponse - Raises - ------ - ValueError - If an invalid pattern type is provided. 
- """ - if pattern_type == "include": - include_patterns = pattern - exclude_patterns = None - elif pattern_type == "exclude": - exclude_patterns = pattern - include_patterns = None - else: - msg = f"Invalid pattern type: {pattern_type}" - raise ValueError(msg) - if token: validate_github_token(token) max_file_size = log_slider_to_size(slider_position) - query: IngestionQuery | None = None - short_repo_url = "" - try: - query = await parse_query( - source=input_text, - max_file_size=max_file_size, - from_web=True, - include_patterns=include_patterns, - ignore_patterns=exclude_patterns, - token=token, - ) - query.ensure_url() + query = await parse_remote_repo(input_text, token=token) + except Exception as exc: + print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") + print(f"{Colors.RED}{exc}{Colors.END}") + return IngestErrorResponse(error=str(exc)) - # Sets the "/" for the page title - short_repo_url = f"{query.user_name}/{query.repo_name}" + query.url = cast("str", query.url) + query.host = cast("str", query.host) + query.max_file_size = max_file_size + query.ignore_patterns, query.include_patterns = process_patterns( + exclude_patterns=pattern if pattern_type == PatternType.EXCLUDE else None, + include_patterns=pattern if pattern_type == PatternType.INCLUDE else None, + ) + + clone_config = query.extract_clone_config() + await clone_repo(clone_config, token=token) - clone_config = query.extract_clone_config() - await clone_repo(clone_config, token=token) + short_repo_url = f"{query.user_name}/{query.repo_name}" # Sets the "user/repo" slug for the page title + try: summary, tree, content = ingest_query(query) + # TODO: why are we writing the tree and content to a file here?
local_txt_file = Path(clone_config.local_path).with_suffix(".txt") - with local_txt_file.open("w", encoding="utf-8") as f: f.write(tree + "\n" + content) except Exception as exc: - if query and query.url: - _print_error(query.url, exc, max_file_size, pattern_type, pattern) - else: - print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") - print(f"{Colors.RED}{exc}{Colors.END}") - + _print_error(query.url, exc, max_file_size, pattern_type, pattern) return IngestErrorResponse(error=str(exc)) if len(content) > MAX_DISPLAY_SIZE: @@ -107,9 +89,6 @@ async def process_query( "download full ingest to see more)\n" + content[:MAX_DISPLAY_SIZE] ) - query.ensure_url() - query.url = cast("str", query.url) - _print_success( url=query.url, max_file_size=max_file_size, diff --git a/src/server/routers_utils.py b/src/server/routers_utils.py index 358596fb..83242e26 100644 --- a/src/server/routers_utils.py +++ b/src/server/routers_utils.py @@ -7,7 +7,7 @@ from fastapi import status from fastapi.responses import JSONResponse -from server.models import IngestErrorResponse, IngestSuccessResponse +from server.models import IngestErrorResponse, IngestSuccessResponse, PatternType from server.query_processor import process_query COMMON_INGEST_RESPONSES: dict[int | str, dict[str, Any]] = { @@ -29,6 +29,8 @@ async def _perform_ingestion( Consolidates error handling shared by the ``POST`` and ``GET`` ingest endpoints. 
""" try: + pattern_type = PatternType(pattern_type) + result = await process_query( input_text=input_text, slider_position=max_file_size, diff --git a/src/static/llms.txt b/src/static/llms.txt index 476f77fc..0109c1bf 100644 --- a/src/static/llms.txt +++ b/src/static/llms.txt @@ -30,8 +30,11 @@ pip install gitingest echo "gitingest" >> requirements.txt pip install -r requirements.txt +# For self-hosting: Install with server dependencies +pip install gitingest[server] + # For development: Install with dev dependencies -pip install gitingest[dev] +pip install gitingest[dev,server] ``` ### 1.3 Installation Verification diff --git a/tests/conftest.py b/tests/conftest.py index 15e1d2ad..0e279726 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from __future__ import annotations import json +import sys from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict from unittest.mock import AsyncMock @@ -22,6 +23,26 @@ DEMO_URL = "https://github.com/user/repo" LOCAL_REPO_PATH = "/tmp/repo" +DEMO_COMMIT = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeef" + + +def get_ensure_git_installed_call_count() -> int: + """Get the number of calls made by ensure_git_installed based on platform. + + On Windows, ensure_git_installed makes 2 calls: + 1. git --version + 2. git config core.longpaths + + On other platforms, it makes 1 call: + 1. git --version + + Returns + ------- + int + The number of calls made by ensure_git_installed + + """ + return 2 if sys.platform == "win32" else 1 @pytest.fixture @@ -136,15 +157,38 @@ def _write_notebook(name: str, content: dict[str, Any]) -> Path: return _write_notebook +@pytest.fixture +def stub_resolve_sha(mocker: MockerFixture) -> dict[str, AsyncMock]: + """Patch *both* async helpers that hit the network. + + Include this fixture *only* in tests that should stay offline. 
+ """ + head_mock = mocker.patch( + "gitingest.utils.query_parser_utils._resolve_ref_to_sha", + new_callable=mocker.AsyncMock, + return_value=DEMO_COMMIT, + ) + ref_mock = mocker.patch( + "gitingest.utils.git_utils._resolve_ref_to_sha", + new_callable=mocker.AsyncMock, + return_value=DEMO_COMMIT, + ) + # return whichever you want to assert on; here we return the dict + return {"head": head_mock, "ref": ref_mock} + + @pytest.fixture def stub_branches(mocker: MockerFixture) -> Callable[[list[str]], None]: """Return a function that stubs git branch discovery to *branches*.""" def _factory(branches: list[str]) -> None: + stdout = ( + "\n".join(f"{DEMO_COMMIT[:12]}{i:02d}\trefs/heads/{b}" for i, b in enumerate(branches)).encode() + b"\n" + ) mocker.patch( "gitingest.utils.git_utils.run_command", new_callable=AsyncMock, - return_value=("\n".join(f"refs/heads/{b}" for b in branches).encode() + b"\n", b""), + return_value=(stdout, b""), ) mocker.patch( "gitingest.utils.git_utils.fetch_remote_branches_or_tags", @@ -168,11 +212,14 @@ def run_command_mock(mocker: MockerFixture) -> AsyncMock: The mocked function returns a dummy process whose ``communicate`` method yields generic ``stdout`` / ``stderr`` bytes. Tests can still access / tweak the mock via the fixture argument. """ - mock_exec = mocker.patch("gitingest.clone.run_command", new_callable=AsyncMock) + mock = AsyncMock(side_effect=_fake_run_command) + mocker.patch("gitingest.utils.git_utils.run_command", mock) + mocker.patch("gitingest.clone.run_command", mock) + return mock - # Provide a default dummy process so most tests don't have to create one. 
- dummy_process = AsyncMock() - dummy_process.communicate.return_value = (b"output", b"error") - mock_exec.return_value = dummy_process - return mock_exec +async def _fake_run_command(*args: str) -> tuple[bytes, bytes]: + if "ls-remote" in args: + # single match: refs/heads/main + return (f"{DEMO_COMMIT}\trefs/heads/main\n".encode(), b"") + return (b"output", b"error") diff --git a/tests/query_parser/test_git_host_agnostic.py b/tests/query_parser/test_git_host_agnostic.py index d3d2542a..342d9882 100644 --- a/tests/query_parser/test_git_host_agnostic.py +++ b/tests/query_parser/test_git_host_agnostic.py @@ -8,8 +8,9 @@ import pytest -from gitingest.query_parser import parse_query -from gitingest.utils.query_parser_utils import KNOWN_GIT_HOSTS +from gitingest.config import MAX_FILE_SIZE +from gitingest.query_parser import parse_remote_repo +from gitingest.utils.query_parser_utils import KNOWN_GIT_HOSTS, _is_valid_git_commit_hash # Repository matrix: (host, user, repo) _REPOS: list[tuple[str, str, str]] = [ @@ -33,7 +34,7 @@ async def test_parse_query_without_host( repo: str, variant: str, ) -> None: - """Verify that ``parse_query`` handles URLs, host-omitted URLs and raw slugs.""" + """Verify that ``parse_remote_repo`` handles URLs, host-omitted URLs and raw slugs.""" # Build the input URL based on the selected variant if variant == "full": url = f"https://{host}/{user}/{repo}" @@ -48,15 +49,20 @@ async def test_parse_query_without_host( # because the parser cannot guess which domain to use. if variant == "slug" and host not in KNOWN_GIT_HOSTS: with pytest.raises(ValueError, match="Could not find a valid repository host"): - await parse_query(url, max_file_size=50, from_web=True) + await parse_remote_repo(url) return - query = await parse_query(url, max_file_size=50, from_web=True) + query = await parse_remote_repo(url) # Compare against the canonical dict while ignoring unpredictable fields. 
actual = query.model_dump(exclude={"id", "local_path", "ignore_patterns"}) + assert "commit" in actual + assert _is_valid_git_commit_hash(actual["commit"]) + del actual["commit"] + expected = { + "host": host, "user_name": user, "repo_name": repo, "url": expected_url, @@ -65,8 +71,7 @@ async def test_parse_query_without_host( "type": None, "branch": None, "tag": None, - "commit": None, - "max_file_size": 50, + "max_file_size": MAX_FILE_SIZE, "include_patterns": None, "include_submodules": False, } diff --git a/tests/query_parser/test_query_parser.py b/tests/query_parser/test_query_parser.py index f6033352..65eb3764 100644 --- a/tests/query_parser/test_query_parser.py +++ b/tests/query_parser/test_query_parser.py @@ -4,6 +4,7 @@ paths. """ +# pylint: disable=too-many-arguments, too-many-positional-arguments from __future__ import annotations from pathlib import Path @@ -11,12 +12,14 @@ import pytest -from gitingest.query_parser import _parse_patterns, _parse_remote_repo, parse_query -from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS +from gitingest.query_parser import parse_local_dir_path, parse_remote_repo +from gitingest.utils.query_parser_utils import _is_valid_git_commit_hash from tests.conftest import DEMO_URL if TYPE_CHECKING: - from gitingest.schemas.ingestion import IngestionQuery + from unittest.mock import AsyncMock + + from gitingest.schemas import IngestionQuery URLS_HTTPS: list[str] = [ @@ -36,106 +39,85 @@ @pytest.mark.parametrize("url", URLS_HTTPS, ids=lambda u: u) @pytest.mark.asyncio -async def test_parse_url_valid_https(url: str) -> None: +async def test_parse_url_valid_https(url: str, stub_resolve_sha: dict[str, AsyncMock]) -> None: """Valid HTTPS URLs parse correctly and ``query.url`` equals the input.""" - query = await _assert_basic_repo_fields(url) + query = await _assert_basic_repo_fields(url, stub_resolve_sha["head"]) assert query.url == url # HTTPS: canonical URL should equal input @pytest.mark.parametrize("url", 
URLS_HTTP, ids=lambda u: u) @pytest.mark.asyncio -async def test_parse_url_valid_http(url: str) -> None: +async def test_parse_url_valid_http(url: str, stub_resolve_sha: dict[str, AsyncMock]) -> None: """Valid HTTP URLs parse correctly (slug check only).""" - await _assert_basic_repo_fields(url) + await _assert_basic_repo_fields(url, stub_resolve_sha["head"]) @pytest.mark.asyncio -async def test_parse_url_invalid() -> None: - """Test ``_parse_remote_repo`` with an invalid URL. +async def test_parse_url_invalid(stub_resolve_sha: dict[str, AsyncMock]) -> None: + """Test ``parse_remote_repo`` with an invalid URL. Given an HTTPS URL lacking a repository structure (e.g., "https://github.com"), - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "https://github.com" with pytest.raises(ValueError, match="Invalid repository URL"): - await _parse_remote_repo(url) + await parse_remote_repo(url) + + stub_resolve_sha["head"].assert_not_awaited() @pytest.mark.asyncio @pytest.mark.parametrize("url", [DEMO_URL, "https://gitlab.com/user/repo"]) -async def test_parse_query_basic(url: str) -> None: - """Test ``parse_query`` with a basic valid repository URL. +async def test_parse_query_basic(url: str, stub_resolve_sha: dict[str, AsyncMock]) -> None: + """Test ``parse_remote_repo`` with a basic valid repository URL. - Given an HTTPS URL and ignore_patterns="*.txt": - When ``parse_query`` is called, - Then user/repo, URL, and ignore patterns should be parsed correctly. + Given an HTTPS URL: + When ``parse_remote_repo`` is called, + Then user/repo, URL should be parsed correctly. 
""" - query = await parse_query(source=url, max_file_size=50, from_web=True, ignore_patterns="*.txt") + query = await parse_remote_repo(url) + stub_resolve_sha["head"].assert_awaited_once() assert query.user_name == "user" assert query.repo_name == "repo" assert query.url == url - assert query.ignore_patterns - assert "*.txt" in query.ignore_patterns @pytest.mark.asyncio -async def test_parse_query_mixed_case() -> None: - """Test ``parse_query`` with mixed-case URLs. +async def test_parse_query_mixed_case(stub_resolve_sha: dict[str, AsyncMock]) -> None: + """Test ``parse_remote_repo`` with mixed-case URLs. Given a URL with mixed-case parts (e.g. "Https://GitHub.COM/UsEr/rEpO"): - When ``parse_query`` is called, + When ``parse_remote_repo`` is called, Then the user and repo names should be normalized to lowercase. """ url = "Https://GitHub.COM/UsEr/rEpO" - query = await parse_query(url, max_file_size=50, from_web=True) + query = await parse_remote_repo(url) + stub_resolve_sha["head"].assert_awaited_once() assert query.user_name == "user" assert query.repo_name == "repo" @pytest.mark.asyncio -async def test_parse_query_include_pattern() -> None: - """Test ``parse_query`` with a specified include pattern. - - Given a URL and include_patterns="*.py": - When ``parse_query`` is called, - Then the include pattern should be set, and default ignore patterns remain applied. - """ - query = await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="*.py") - - assert query.include_patterns == {"*.py"} - assert query.ignore_patterns == DEFAULT_IGNORE_PATTERNS - - -@pytest.mark.asyncio -async def test_parse_query_invalid_pattern() -> None: - """Test ``parse_query`` with an invalid pattern. - - Given an include pattern containing special characters (e.g., "*.py;rm -rf"): - When ``parse_query`` is called, - Then a ValueError should be raised indicating invalid characters. 
- """ - with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): - await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="*.py;rm -rf") - - -@pytest.mark.asyncio -async def test_parse_url_with_subpaths(stub_branches: Callable[[list[str]], None]) -> None: - """Test ``_parse_remote_repo`` with a URL containing branch and subpath. +async def test_parse_url_with_subpaths( + stub_branches: Callable[[list[str]], None], + stub_resolve_sha: dict[str, AsyncMock], +) -> None: + """Test ``parse_remote_repo`` with a URL containing branch and subpath. Given a URL referencing a branch ("main") and a subdir ("subdir/file"): - When ``_parse_remote_repo`` is called with remote branch fetching, + When ``parse_remote_repo`` is called with remote branch fetching, Then user, repo, branch, and subpath should be identified correctly. """ url = DEMO_URL + "/tree/main/subdir/file" stub_branches(["main", "dev", "feature-branch"]) - query = await _assert_basic_repo_fields(url) + query = await _assert_basic_repo_fields(url, stub_resolve_sha["ref"]) assert query.user_name == "user" assert query.repo_name == "repo" @@ -144,105 +126,30 @@ async def test_parse_url_with_subpaths(stub_branches: Callable[[list[str]], None @pytest.mark.asyncio -async def test_parse_url_invalid_repo_structure() -> None: - """Test ``_parse_remote_repo`` with a URL missing a repository name. +async def test_parse_url_invalid_repo_structure(stub_resolve_sha: dict[str, AsyncMock]) -> None: + """Test ``parse_remote_repo`` with a URL missing a repository name. Given a URL like "https://github.com/user": - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. 
""" url = "https://github.com/user" with pytest.raises(ValueError, match="Invalid repository URL"): - await _parse_remote_repo(url) - - -def test_parse_patterns_valid() -> None: - """Test ``_parse_patterns`` with valid comma-separated patterns. - - Given patterns like "*.py, *.md, docs/*": - When ``_parse_patterns`` is called, - Then it should return a set of parsed strings. - """ - patterns = "*.py, *.md, docs/*" - parsed_patterns = _parse_patterns(patterns) - - assert parsed_patterns == {"*.py", "*.md", "docs/*"} - - -def test_parse_patterns_invalid_characters() -> None: - """Test ``_parse_patterns`` with invalid characters. - - Given a pattern string containing special characters (e.g. "*.py;rm -rf"): - When ``_parse_patterns`` is called, - Then a ValueError should be raised indicating invalid pattern syntax. - """ - patterns = "*.py;rm -rf" - - with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): - _parse_patterns(patterns) + await parse_remote_repo(url) + stub_resolve_sha["head"].assert_not_awaited() -@pytest.mark.asyncio -async def test_parse_query_with_large_file_size() -> None: - """Test ``parse_query`` with a very large file size limit. - - Given a URL and max_file_size=10**9: - When ``parse_query`` is called, - Then ``max_file_size`` should be set correctly and default ignore patterns remain unchanged. - """ - query = await parse_query(DEMO_URL, max_file_size=10**9, from_web=True) - - assert query.max_file_size == 10**9 - assert query.ignore_patterns == DEFAULT_IGNORE_PATTERNS +async def test_parse_local_dir_path_local_path() -> None: + """Test ``parse_local_dir_path``. -@pytest.mark.asyncio -async def test_parse_query_empty_patterns() -> None: - """Test ``parse_query`` with empty patterns. - - Given empty include_patterns and ignore_patterns: - When ``parse_query`` is called, - Then ``include_patterns`` becomes ``None`` and default ignore patterns apply. 
- """ - query = await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="", ignore_patterns="") - - assert query.include_patterns is None - assert query.ignore_patterns == DEFAULT_IGNORE_PATTERNS - - -@pytest.mark.asyncio -async def test_parse_query_include_and_ignore_overlap() -> None: - """Test ``parse_query`` with overlapping patterns. - - Given include="*.py" and ignore={"*.py", "*.txt"}: - When ``parse_query`` is called, - Then "*.py" should be removed from ignore patterns. - """ - query = await parse_query( - DEMO_URL, - max_file_size=50, - from_web=True, - include_patterns="*.py", - ignore_patterns={"*.py", "*.txt"}, - ) - - assert query.include_patterns == {"*.py"} - assert query.ignore_patterns is not None - assert "*.py" not in query.ignore_patterns - assert "*.txt" in query.ignore_patterns - - -@pytest.mark.asyncio -async def test_parse_query_local_path() -> None: - """Test ``parse_query`` with a local file path. - - Given "/home/user/project" and from_web=False: - When ``parse_query`` is called, + Given "/home/user/project": + When ``parse_local_dir_path`` is called, Then the local path should be set, id generated, and slug formed accordingly. """ path = "/home/user/project" - query = await parse_query(path, max_file_size=100, from_web=False) + query = parse_local_dir_path(path) tail = Path("home/user/project") assert query.local_path.parts[-len(tail.parts) :] == tail.parts @@ -250,16 +157,15 @@ async def test_parse_query_local_path() -> None: assert query.slug == "home/user/project" -@pytest.mark.asyncio -async def test_parse_query_relative_path() -> None: - """Test ``parse_query`` with a relative path. +async def test_parse_local_dir_path_relative_path() -> None: + """Test ``parse_local_dir_path`` with a relative path. - Given "./project" and from_web=False: - When ``parse_query`` is called, + Given "./project": + When ``parse_local_dir_path`` is called, Then ``local_path`` resolves relatively, and ``slug`` ends with "project". 
""" path = "./project" - query = await parse_query(path, max_file_size=100, from_web=False) + query = parse_local_dir_path(path) tail = Path("project") assert query.local_path.parts[-len(tail.parts) :] == tail.parts @@ -267,103 +173,109 @@ async def test_parse_query_relative_path() -> None: @pytest.mark.asyncio -async def test_parse_query_empty_source() -> None: - """Test ``parse_query`` with an empty string. +async def test_parse_remote_repo_empty_source(stub_resolve_sha: dict[str, AsyncMock]) -> None: + """Test ``parse_remote_repo`` with an empty string. Given an empty source string: - When ``parse_query`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "" with pytest.raises(ValueError, match="Invalid repository URL"): - await parse_query(url, max_file_size=100, from_web=True) + await parse_remote_repo(url) + + stub_resolve_sha["head"].assert_not_awaited() @pytest.mark.asyncio @pytest.mark.parametrize( - ("path", "expected_branch", "expected_commit"), + ("path", "expected_branch", "mock_name"), [ - ("/tree/main", "main", None), - ("/tree/abcd1234abcd1234abcd1234abcd1234abcd1234", None, "abcd1234abcd1234abcd1234abcd1234abcd1234"), + ("/tree/main", "main", "ref"), + ("/tree/abcd1234abcd1234abcd1234abcd1234abcd1234", None, "ref"), ], ) async def test_parse_url_branch_and_commit_distinction( path: str, expected_branch: str, - expected_commit: str, stub_branches: Callable[[list[str]], None], + stub_resolve_sha: dict[str, AsyncMock], + mock_name: str, ) -> None: - """Test ``_parse_remote_repo`` distinguishing branch vs. commit hash. + """Test ``parse_remote_repo`` distinguishing branch vs. commit hash. 
Given either a branch URL (e.g., ".../tree/main") or a 40-character commit URL: - When ``_parse_remote_repo`` is called with branch fetching, + When ``parse_remote_repo`` is called with branch fetching, Then the function should correctly set ``branch`` or ``commit`` based on the URL content. """ stub_branches(["main", "dev", "feature-branch"]) url = DEMO_URL + path - query = await _assert_basic_repo_fields(url) + query = await _assert_basic_repo_fields(url, stub_resolve_sha[mock_name]) assert query.branch == expected_branch - assert query.commit == expected_commit + assert query.commit is not None + assert _is_valid_git_commit_hash(query.commit) -@pytest.mark.asyncio -async def test_parse_query_uuid_uniqueness() -> None: - """Test ``parse_query`` for unique UUID generation. +async def test_parse_local_dir_path_uuid_uniqueness() -> None: + """Test ``parse_local_dir_path`` for unique UUID generation. Given the same path twice: - When ``parse_query`` is called repeatedly, + When ``parse_local_dir_path`` is called repeatedly, Then each call should produce a different query id. """ path = "/home/user/project" - query_1 = await parse_query(path, max_file_size=100, from_web=False) - query_2 = await parse_query(path, max_file_size=100, from_web=False) + query_1 = parse_local_dir_path(path) + query_2 = parse_local_dir_path(path) assert query_1.id != query_2.id @pytest.mark.asyncio -async def test_parse_url_with_query_and_fragment() -> None: - """Test ``_parse_remote_repo`` with query parameters and a fragment. +async def test_parse_url_with_query_and_fragment(stub_resolve_sha: dict[str, AsyncMock]) -> None: + """Test ``parse_remote_repo`` with query parameters and a fragment. Given a URL like "https://github.com/user/repo?arg=value#fragment": - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then those parts should be stripped, leaving a clean user/repo URL. 
""" url = DEMO_URL + "?arg=value#fragment" - query = await _parse_remote_repo(url) + query = await parse_remote_repo(url) + stub_resolve_sha["head"].assert_awaited_once() assert query.user_name == "user" assert query.repo_name == "repo" assert query.url == DEMO_URL # URL should be cleaned @pytest.mark.asyncio -async def test_parse_url_unsupported_host() -> None: - """Test ``_parse_remote_repo`` with an unsupported host. +async def test_parse_url_unsupported_host(stub_resolve_sha: dict[str, AsyncMock]) -> None: + """Test ``parse_remote_repo`` with an unsupported host. Given "https://only-domain.com": - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then a ValueError should be raised for the unknown domain. """ url = "https://only-domain.com" with pytest.raises(ValueError, match="Unknown domain 'only-domain.com' in URL"): - await _parse_remote_repo(url) + await parse_remote_repo(url) + + stub_resolve_sha["head"].assert_not_awaited() @pytest.mark.asyncio async def test_parse_query_with_branch() -> None: - """Test ``parse_query`` when a branch is specified in a blob path. + """Test ``parse_remote_repo`` when a branch is specified in a blob path. Given "https://github.com/pandas-dev/pandas/blob/2.2.x/...": - When ``parse_query`` is called, + When ``parse_remote_repo`` is called, Then the branch should be identified, subpath set, and commit remain None. 
""" url = "https://github.com/pandas-dev/pandas/blob/2.2.x/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" - query = await parse_query(url, max_file_size=10**9, from_web=True) + query = await parse_remote_repo(url) assert query.user_name == "pandas-dev" assert query.repo_name == "pandas" @@ -372,20 +284,21 @@ async def test_parse_query_with_branch() -> None: assert query.id is not None assert query.subpath == "/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" assert query.branch == "2.2.x" - assert query.commit is None + assert query.commit is not None + assert _is_valid_git_commit_hash(query.commit) assert query.type == "blob" @pytest.mark.asyncio @pytest.mark.parametrize( - ("path", "expected_branch", "expected_subpath"), + ("path", "expected_branch", "expected_subpath", "mock_name"), [ - ("/tree/feature/fix1/src", "feature/fix1", "/src"), - ("/tree/main/src", "main", "/src"), - ("", None, "/"), - ("/tree/nonexistent-branch/src", None, "/"), - ("/tree/fix", "fix", "/"), - ("/blob/fix/page.html", "fix", "/page.html"), + ("/tree/feature/fix1/src", "feature/fix1", "/src", "ref"), + ("/tree/main/src", "main", "/src", "ref"), + ("", None, "/", "head"), + ("/tree/nonexistent-branch/src", None, "/", "ref"), + ("/tree/fix", "fix", "/", "ref"), + ("/blob/fix/page.html", "fix", "/page.html", "ref"), ], ) async def test_parse_repo_source_with_various_url_patterns( @@ -393,11 +306,13 @@ async def test_parse_repo_source_with_various_url_patterns( expected_branch: str | None, expected_subpath: str, stub_branches: Callable[[list[str]], None], + stub_resolve_sha: dict[str, AsyncMock], + mock_name: str, ) -> None: - """Test ``_parse_remote_repo`` with various GitHub-style URL permutations. + """Test ``parse_remote_repo`` with various GitHub-style URL permutations. Given various GitHub-style URL permutations: - When ``_parse_remote_repo`` is called, + When ``parse_remote_repo`` is called, Then it should detect (or reject) a branch and resolve the sub-path. 
Branch discovery is stubbed so that only names passed to ``stub_branches`` are considered "remote". @@ -405,15 +320,24 @@ async def test_parse_repo_source_with_various_url_patterns( stub_branches(["feature/fix1", "main", "feature-branch", "fix"]) url = DEMO_URL + path - query = await _assert_basic_repo_fields(url) + query = await _assert_basic_repo_fields(url, stub_resolve_sha[mock_name]) assert query.branch == expected_branch assert query.subpath == expected_subpath -async def _assert_basic_repo_fields(url: str) -> IngestionQuery: - """Run ``_parse_remote_repo`` and assert user, repo and slug are parsed.""" - query = await _parse_remote_repo(url) +@pytest.mark.asyncio +async def _assert_basic_repo_fields(url: str, sha_mock: AsyncMock) -> IngestionQuery: + """Run ``parse_remote_repo`` and assert user, repo and slug are parsed.""" + query = await parse_remote_repo(url) + + assert query.commit is not None + assert _is_valid_git_commit_hash(query.commit) + + if query.commit in url: + sha_mock.assert_not_awaited() + else: + sha_mock.assert_awaited_once() assert query.user_name == "user" assert query.repo_name == "repo" diff --git a/tests/server/__init__.py b/tests/server/__init__.py new file mode 100644 index 00000000..967a4a7b --- /dev/null +++ b/tests/server/__init__.py @@ -0,0 +1 @@ +"""Tests for the server.""" diff --git a/tests/test_flow_integration.py b/tests/server/test_flow_integration.py similarity index 100% rename from tests/test_flow_integration.py rename to tests/server/test_flow_integration.py diff --git a/tests/test_clone.py b/tests/test_clone.py index 42ca1994..1d89c212 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -4,26 +4,35 @@ and handling edge cases such as nonexistent URLs, timeouts, redirects, and specific commits or branches. 
""" +from __future__ import annotations + import asyncio -import subprocess -from pathlib import Path +import sys +from typing import TYPE_CHECKING from unittest.mock import AsyncMock import httpx import pytest -from pytest_mock import MockerFixture from starlette.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND from gitingest.clone import clone_repo from gitingest.schemas import CloneConfig from gitingest.utils.exceptions import AsyncTimeoutError from gitingest.utils.git_utils import check_repo_exists -from tests.conftest import DEMO_URL, LOCAL_REPO_PATH +from tests.conftest import DEMO_COMMIT, DEMO_URL, LOCAL_REPO_PATH + +if TYPE_CHECKING: + from pathlib import Path + + from pytest_mock import MockerFixture + # All cloning-related tests assume (unless explicitly overridden) that the repository exists. # Apply the check-repo patch automatically so individual tests don't need to repeat it. pytestmark = pytest.mark.usefixtures("repo_exists_true") +GIT_INSTALLED_CALLS = 2 if sys.platform == "win32" else 1 + @pytest.mark.asyncio async def test_clone_with_commit(repo_exists_true: AsyncMock, run_command_mock: AsyncMock) -> None: @@ -33,18 +42,20 @@ async def test_clone_with_commit(repo_exists_true: AsyncMock, run_command_mock: When ``clone_repo`` is called, Then the repository should be cloned and checked out at that commit. 
""" - expected_call_count = 2 + expected_call_count = GIT_INSTALLED_CALLS + 3 # ensure_git_installed + clone + fetch + checkout + commit_hash = "a" * 40 # Simulating a valid commit hash clone_config = CloneConfig( url=DEMO_URL, local_path=LOCAL_REPO_PATH, - commit="a" * 40, # Simulating a valid commit hash + commit=commit_hash, branch="main", ) await clone_repo(clone_config) - repo_exists_true.assert_called_once_with(clone_config.url, token=None) - assert run_command_mock.call_count == expected_call_count # Clone and checkout calls + repo_exists_true.assert_any_call(clone_config.url, token=None) + assert_standard_calls(run_command_mock, clone_config, commit=commit_hash) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -55,13 +66,14 @@ async def test_clone_without_commit(repo_exists_true: AsyncMock, run_command_moc When ``clone_repo`` is called, Then only the clone_repo operation should be performed (no checkout). """ - expected_call_count = 1 + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit=None, branch="main") await clone_repo(clone_config) - repo_exists_true.assert_called_once_with(clone_config.url, token=None) - assert run_command_mock.call_count == expected_call_count # Only clone call + repo_exists_true.assert_any_call(clone_config.url, token=None) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -84,7 +96,7 @@ async def test_clone_nonexistent_repository(repo_exists_true: AsyncMock) -> None with pytest.raises(ValueError, match="Repository not found"): await clone_repo(clone_config) - repo_exists_true.assert_called_once_with(clone_config.url, token=None) + repo_exists_true.assert_any_call(clone_config.url, token=None) @pytest.mark.asyncio @@ -117,20 +129,13 @@ async def 
test_clone_with_custom_branch(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned shallowly to that branch. """ + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, branch="feature-branch") await clone_repo(clone_config) - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - "--branch", - "feature-branch", - clone_config.url, - clone_config.local_path, - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -143,9 +148,9 @@ async def test_git_command_failure(run_command_mock: AsyncMock) -> None: """ clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH) - run_command_mock.side_effect = RuntimeError("Git command failed") + run_command_mock.side_effect = RuntimeError("Git is not installed or not accessible. Please install Git first.") - with pytest.raises(RuntimeError, match="Git command failed"): + with pytest.raises(RuntimeError, match="Git is not installed or not accessible"): await clone_repo(clone_config) @@ -157,18 +162,13 @@ async def test_clone_default_shallow_clone(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned with ``--depth=1`` and ``--single-branch``. 
""" + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH) await clone_repo(clone_config) - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - clone_config.url, - clone_config.local_path, - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -179,15 +179,14 @@ async def test_clone_commit(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned and checked out at that commit. """ - expected_call_count = 2 - # Simulating a valid commit hash - clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit="a" * 40) + expected_call_count = GIT_INSTALLED_CALLS + 3 # ensure_git_installed + clone + fetch + checkout + commit_hash = "a" * 40 # Simulating a valid commit hash + clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit=commit_hash) await clone_repo(clone_config) - assert run_command_mock.call_count == expected_call_count # Clone and checkout calls - run_command_mock.assert_any_call("git", "clone", "--single-branch", clone_config.url, clone_config.local_path) - run_command_mock.assert_any_call("git", "-C", clone_config.local_path, "checkout", clone_config.commit) + assert_standard_calls(run_command_mock, clone_config, commit=commit_hash) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -225,40 +224,6 @@ async def test_clone_with_timeout(run_command_mock: AsyncMock) -> None: await clone_repo(clone_config) -@pytest.mark.asyncio -async def test_clone_specific_branch(tmp_path: Path) -> None: - """Test cloning a specific branch of a repository. 
- - Given a valid repository URL and a branch name: - When ``clone_repo`` is called, - Then the repository should be cloned and checked out at that branch. - """ - repo_url = "https://github.com/coderamp-labs/gitingest.git" - branch_name = "main" - local_path = tmp_path / "gitingest" - clone_config = CloneConfig(url=repo_url, local_path=str(local_path), branch=branch_name) - - await clone_repo(clone_config) - - assert local_path.exists(), "The repository was not cloned successfully." - assert local_path.is_dir(), "The cloned repository path is not a directory." - - loop = asyncio.get_running_loop() - current_branch = ( - ( - await loop.run_in_executor( - None, - subprocess.check_output, - ["git", "-C", str(local_path), "branch", "--show-current"], - ) - ) - .decode() - .strip() - ) - - assert current_branch == branch_name, f"Expected branch '{branch_name}', got '{current_branch}'." - - @pytest.mark.asyncio async def test_clone_branch_with_slashes(tmp_path: Path, run_command_mock: AsyncMock) -> None: """Test cloning a branch with slashes in the name. 
@@ -269,20 +234,13 @@ async def test_clone_branch_with_slashes(tmp_path: Path, run_command_mock: Async """ branch_name = "fix/in-operator" local_path = tmp_path / "gitingest" + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout clone_config = CloneConfig(url=DEMO_URL, local_path=str(local_path), branch=branch_name) await clone_repo(clone_config) - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - "--branch", - "fix/in-operator", - clone_config.url, - clone_config.local_path, - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -293,20 +251,16 @@ async def test_clone_creates_parent_directory(tmp_path: Path, run_command_mock: When ``clone_repo`` is called, Then it should create the parent directories before attempting to clone. """ + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + resolve_commit + clone + fetch + checkout nested_path = tmp_path / "deep" / "nested" / "path" / "repo" + clone_config = CloneConfig(url=DEMO_URL, local_path=str(nested_path)) await clone_repo(clone_config) assert nested_path.parent.exists() - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--depth=1", - clone_config.url, - str(nested_path), - ) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio @@ -317,26 +271,15 @@ async def test_clone_with_specific_subpath(run_command_mock: AsyncMock) -> None: When ``clone_repo`` is called, Then the repository should be cloned with sparse checkout enabled and the specified subpath. 
""" - expected_call_count = 2 - clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, subpath="src/docs") + # ensure_git_installed + resolve_commit + clone + sparse-checkout + fetch + checkout + subpath = "src/docs" + expected_call_count = GIT_INSTALLED_CALLS + 5 + clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, subpath=subpath) await clone_repo(clone_config) # Verify the clone command includes sparse checkout flags - run_command_mock.assert_any_call( - "git", - "clone", - "--single-branch", - "--filter=blob:none", - "--sparse", - "--depth=1", - clone_config.url, - clone_config.local_path, - ) - - # Verify the sparse-checkout command sets the correct path - run_command_mock.assert_any_call("git", "-C", clone_config.local_path, "sparse-checkout", "set", "src/docs") - + assert_partial_clone_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) assert run_command_mock.call_count == expected_call_count @@ -349,42 +292,14 @@ async def test_clone_with_commit_and_subpath(run_command_mock: AsyncMock) -> Non Then the repository should be cloned with sparse checkout enabled, checked out at the specific commit, and only include the specified subpath. 
""" - expected_call_count = 3 - # Simulating a valid commit hash - clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit="a" * 40, subpath="src/docs") + subpath = "src/docs" + expected_call_count = GIT_INSTALLED_CALLS + 4 # ensure_git_installed + clone + sparse-checkout + fetch + checkout + commit_hash = "a" * 40 # Simulating a valid commit hash + clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit=commit_hash, subpath=subpath) await clone_repo(clone_config) - # Verify the clone command includes sparse checkout flags - run_command_mock.assert_any_call( - "git", - "clone", - "--single-branch", - "--filter=blob:none", - "--sparse", - clone_config.url, - clone_config.local_path, - ) - - # Verify sparse-checkout set - run_command_mock.assert_any_call( - "git", - "-C", - clone_config.local_path, - "sparse-checkout", - "set", - "src/docs", - ) - - # Verify checkout commit - run_command_mock.assert_any_call( - "git", - "-C", - clone_config.local_path, - "checkout", - clone_config.commit, - ) - + assert_partial_clone_calls(run_command_mock, clone_config, commit=commit_hash) assert run_command_mock.call_count == expected_call_count @@ -396,18 +311,39 @@ async def test_clone_with_include_submodules(run_command_mock: AsyncMock) -> Non When ``clone_repo`` is called, Then the repository should be cloned with ``--recurse-submodules`` in the git command. 
""" - expected_call_count = 1 # No commit and no partial clone + # ensure_git_installed + resolve_commit + clone + fetch + checkout + checkout submodules + expected_call_count = GIT_INSTALLED_CALLS + 5 clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, branch="main", include_submodules=True) await clone_repo(clone_config) + assert_standard_calls(run_command_mock, clone_config, commit=DEMO_COMMIT) + assert_submodule_calls(run_command_mock, clone_config) assert run_command_mock.call_count == expected_call_count - run_command_mock.assert_called_once_with( - "git", - "clone", - "--single-branch", - "--recurse-submodules", - "--depth=1", - clone_config.url, - clone_config.local_path, - ) + + +def assert_standard_calls(mock: AsyncMock, cfg: CloneConfig, commit: str, *, partial_clone: bool = False) -> None: + """Assert that the standard clone sequence of git commands was called.""" + mock.assert_any_call("git", "--version") + if sys.platform == "win32": + mock.assert_any_call("git", "config", "core.longpaths") + + # Clone + clone_cmd = ["git", "clone", "--single-branch", "--no-checkout", "--depth=1"] + if partial_clone: + clone_cmd += ["--filter=blob:none", "--sparse"] + mock.assert_any_call(*clone_cmd, cfg.url, cfg.local_path) + + mock.assert_any_call("git", "-C", cfg.local_path, "fetch", "--depth=1", "origin", commit) + mock.assert_any_call("git", "-C", cfg.local_path, "checkout", commit) + + +def assert_partial_clone_calls(mock: AsyncMock, cfg: CloneConfig, commit: str) -> None: + """Assert that the partial clone sequence of git commands was called.""" + assert_standard_calls(mock, cfg, commit=commit, partial_clone=True) + mock.assert_any_call("git", "-C", cfg.local_path, "sparse-checkout", "set", cfg.subpath) + + +def assert_submodule_calls(mock: AsyncMock, cfg: CloneConfig) -> None: + """Assert that submodule update commands were called.""" + mock.assert_any_call("git", "-C", cfg.local_path, "submodule", "update", "--init", "--recursive", 
"--depth=1") diff --git a/tests/test_pattern_utils.py b/tests/test_pattern_utils.py new file mode 100644 index 00000000..7a392fb0 --- /dev/null +++ b/tests/test_pattern_utils.py @@ -0,0 +1,45 @@ +"""Test pattern utilities.""" + +from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS +from gitingest.utils.pattern_utils import _parse_patterns, process_patterns + + +def test_process_patterns_empty_patterns() -> None: + """Test ``process_patterns`` with empty patterns. + + Given empty ``include_patterns`` and ``exclude_patterns``: + When ``process_patterns`` is called, + Then ``include_patterns`` becomes ``None`` and ``DEFAULT_IGNORE_PATTERNS`` apply. + """ + exclude_patterns, include_patterns = process_patterns(exclude_patterns="", include_patterns="") + + assert include_patterns is None + assert exclude_patterns == DEFAULT_IGNORE_PATTERNS + + +def test_parse_patterns_valid() -> None: + """Test ``_parse_patterns`` with valid comma-separated patterns. + + Given patterns like "*.py, *.md, docs/*": + When ``_parse_patterns`` is called, + Then it should return a set of parsed strings. + """ + patterns = "*.py, *.md, docs/*" + parsed_patterns = _parse_patterns(patterns) + + assert parsed_patterns == {"*.py", "*.md", "docs/*"} + + +def test_process_patterns_include_and_ignore_overlap() -> None: + """Test ``process_patterns`` with overlapping patterns. + + Given include="*.py" and ignore={"*.py", "*.txt"}: + When ``process_patterns`` is called, + Then "*.py" should be removed from ignore patterns. 
+ """ + exclude_patterns, include_patterns = process_patterns(exclude_patterns={"*.py", "*.txt"}, include_patterns="*.py") + + assert include_patterns == {"*.py"} + assert exclude_patterns is not None + assert "*.py" not in exclude_patterns + assert "*.txt" in exclude_patterns diff --git a/tests/test_summary.py b/tests/test_summary.py new file mode 100644 index 00000000..ac32394a --- /dev/null +++ b/tests/test_summary.py @@ -0,0 +1,111 @@ +"""Test that ``gitingest.ingest()`` emits a concise, 6-or-7-line summary.""" + +import re +from pathlib import Path + +import pytest + +from gitingest import ingest + +REPO = "pallets/flask" + +PATH_CASES = [ + ("tree", "/examples/celery"), + ("blob", "/examples/celery/make_celery.py"), + ("blob", "/.gitignore"), +] + +REF_CASES = [ + ("Branch", "main"), + ("Branch", "stable"), + ("Tag", "3.0.3"), + ("Commit", "e9741288637e0d9abe95311247b4842a017f7d5c"), +] + + +@pytest.mark.parametrize(("path_type", "path"), PATH_CASES) +@pytest.mark.parametrize(("ref_type", "ref"), REF_CASES) +def test_ingest_summary(path_type: str, path: str, ref_type: str, ref: str) -> None: + """Assert that ``gitingest.ingest()`` emits a concise, 6-or-7-line summary. + + - Non-'main' Branch/Tag refs → 6 key/value pairs + blank line (7 total). + - 'main' branch or Commit ref → one ref line omitted (6 total). + - Required keys: + - Repository + - ``ref_type`` (absent on 'main') + - File│Subpath (chosen by ``path_type``) + - Lines│Files analyzed (chosen by ``path_type``) + - Estimated tokens (positive integer) + + Any missing key, wrong value, or incorrect line count should fail. + + Parameters + ---------- + path_type : {"tree", "blob"} + GitHub object type under test. + path : str + The repository sub-path or file path to feed into the URL. + ref_type : {"Branch", "Tag", "Commit"} + Label expected on line 2 of the summary (absent if `ref` is "main"). + ref : str + Actual branch name, tag, or commit hash. 
+ + """ + is_main_branch = ref == "main" + is_blob = path_type == "blob" + expected_lines = _calculate_expected_lines(ref_type, is_main_branch=is_main_branch) + expected_non_empty_lines = expected_lines - 1 + + summary, _, _ = ingest(f"https://github.com/{REPO}/{path_type}/{ref}{path}") + lines = summary.splitlines() + parsed_lines = dict(line.split(": ", 1) for line in lines if ": " in line) + + assert parsed_lines["Repository"] == REPO + + if is_main_branch: + # We omit the 'Branch' line for 'main' branches. + assert ref_type not in parsed_lines + else: + assert parsed_lines[ref_type] == ref + + if is_blob: + assert parsed_lines["File"] == Path(path).name + assert "Lines" in parsed_lines + else: # 'tree' + assert parsed_lines["Subpath"] == path + assert "Files analyzed" in parsed_lines + + token_match = re.search(r"\d+", parsed_lines["Estimated tokens"]) + assert token_match, "'Estimated tokens' should contain a number" + assert int(token_match.group()) > 0 + + assert len(lines) == expected_lines + assert len(parsed_lines) == expected_non_empty_lines + + +def _calculate_expected_lines(ref_type: str, *, is_main_branch: bool) -> int: + """Calculate the expected number of lines in the summary. + + The total number of lines depends on the following: + - Commit type does not include the 'Branch'/'Tag' line, reducing the count by 1. + - The "main" branch omits the 'Branch' line, reducing the count by 1. + + Parameters + ---------- + ref_type : str + The type of reference, e.g., "Branch", "Tag", or "Commit". + is_main_branch : bool + True if the reference is the "main" branch, False otherwise. + + Returns + ------- + int + The expected number of lines in the summary. + + """ + base_lines = 7 + if is_main_branch: + base_lines -= 1 + if ref_type == "Commit": + base_lines -= 1 + return base_lines