Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@
# Note the order is intentional to avoid multiple passes of the hooks
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.4
rev: v0.14.6
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.3
rev: v4.0.0-alpha.8
hooks:
- id: prettier
exclude: front/package-lock\.json
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: debug-statements
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
rev: v1.18.2
hooks:
- id: mypy
additional_dependencies:
Expand Down
4 changes: 2 additions & 2 deletions src/graphql_sqlalchemy/graphql_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _get_array_item_type(column_type: TypeEngine[Any]) -> TypeEngine[Any] | None
and hasattr(column_type, "item_type")
):
return None
return cast(TypeEngine[Any], column_type.item_type)
return cast("TypeEngine[Any]", column_type.item_type)


def get_graphql_type_from_column(
Expand Down Expand Up @@ -139,7 +139,7 @@ def get_graphql_type_from_column(


def get_base_comparison_fields(
graphql_type: GraphQLScalarType | GraphQLEnumType | GraphQLList[Any]
graphql_type: GraphQLScalarType | GraphQLEnumType | GraphQLList[Any],
) -> dict[str, GraphQLInputField]:
return {
"_eq": GraphQLInputField(graphql_type),
Expand Down
14 changes: 5 additions & 9 deletions src/graphql_sqlalchemy/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def all_scalars(
selection: W,
*,
execution_options: OrmExecuteOptionsParameter = MappingProxyType({}),
) -> Sequence[DeclarativeBase]:
...
) -> Sequence[DeclarativeBase]: ...


@overload
Expand All @@ -58,8 +57,7 @@ def all_scalars(
selection: W,
*,
execution_options: OrmExecuteOptionsParameter = MappingProxyType({}),
) -> Awaitable[Sequence[DeclarativeBase]]:
...
) -> Awaitable[Sequence[DeclarativeBase]]: ...


def all_scalars(
Expand Down Expand Up @@ -242,7 +240,7 @@ def resolver(
offset: int | None = None,
) -> AwaitableOrValue[Sequence[DeclarativeBase]]:
if all(f is None for f in [where, order, limit, offset]):
return cast(Sequence[Any], getattr(root, field_name))
return cast("Sequence[Any]", getattr(root, field_name))
session = info.context["session"]
relationship: InstrumentedAttribute[Any] = getattr(root.__class__, field_name)
field_model = relationship.prop.entity.class_
Expand Down Expand Up @@ -281,15 +279,13 @@ def resolver(_root: None, info: ResolveInfo, **kwargs: dict[str, Any]) -> Awaita
@overload
def session_add_object(
obj: dict[str, Any], model: type[DeclarativeBase], session: Session, *, on_conflict: dict[str, Any] | None
) -> DeclarativeBase:
...
) -> DeclarativeBase: ...


@overload
def session_add_object(
obj: dict[str, Any], model: type[DeclarativeBase], session: AsyncSession, *, on_conflict: dict[str, Any] | None
) -> Awaitable[DeclarativeBase]:
...
) -> Awaitable[DeclarativeBase]: ...


def session_add_object(
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
],
)
def is_async(request: pytest.FixtureRequest) -> bool:
return cast(bool, request.param)
return cast("bool", request.param)


@pytest.fixture(scope="session")
Expand All @@ -52,7 +52,7 @@ def db_session_factory(db_engine: Engine | AsyncEngine) -> scoped_session[Sessio
return scoped_session(sessionmaker(bind=db_engine))


@pytest.fixture()
@pytest.fixture
async def db_session(
db_session_factory: scoped_session[Session] | async_scoped_session[AsyncSession],
) -> AsyncGenerator[Session | AsyncSession, None]:
Expand Down
9 changes: 5 additions & 4 deletions tests/test_build_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
GraphQLScalarType,
GraphQLString,
)
from graphql_sqlalchemy import build_schema
from graphql_sqlalchemy.testing import JsonArray, assert_equal_gql_type
from sqlalchemy import Column, ForeignKey, Integer, Table
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, registry, relationship

from graphql_sqlalchemy import build_schema
from graphql_sqlalchemy.testing import JsonArray, assert_equal_gql_type

# Tested types


Expand Down Expand Up @@ -77,15 +78,15 @@ class Project(Base):
)
def test_build_schema_simple(field: str, gql_type: GraphQLScalarType) -> None:
schema = build_schema(Base)
user = cast(Union[GraphQLObjectType, None], schema.get_type("user"))
user = cast("Union[GraphQLObjectType, None]", schema.get_type("user"))
assert user
f: GraphQLField = user.fields[field]
assert_equal_gql_type(f.type, GraphQLNonNull(gql_type))


def test_build_schema_rel() -> None:
schema = build_schema(Base)
user = cast(Union[GraphQLObjectType, None], schema.get_type("user"))
user = cast("Union[GraphQLObjectType, None]", schema.get_type("user"))
assert user
f: GraphQLField = user.fields["projects"]
assert isinstance(f.type, GraphQLNonNull)
Expand Down
13 changes: 7 additions & 6 deletions tests/test_graphql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

import pytest
from graphql import ExecutionResult, GraphQLSchema, graphql, graphql_sync
from graphql_sqlalchemy.schema import build_schema
from graphql_sqlalchemy.testing import JsonArray
from sqlalchemy import Column, Engine, ForeignKey, String, Table
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, registry, relationship

from graphql_sqlalchemy.schema import build_schema
from graphql_sqlalchemy.testing import JsonArray

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable

Expand Down Expand Up @@ -96,7 +97,7 @@ def gql_schema() -> GraphQLSchema:
return build_schema(Base)


@pytest.fixture()
@pytest.fixture
async def example_session(
db_engine: Engine | AsyncEngine, db_session: Session | AsyncSession
) -> AsyncGenerator[Session | AsyncSession, None]:
Expand Down Expand Up @@ -129,7 +130,7 @@ def raise_if_errors(result: ExecutionResult) -> None:
raise result.errors[0] if len(result.errors) == 1 else ExceptionGroup("Invalid Query", result.errors)


@pytest.fixture()
@pytest.fixture
def graphql_example(
example_session: Session | AsyncSession, gql_schema: GraphQLSchema
) -> Callable[[str], dict[str, Any]]:
Expand All @@ -155,15 +156,15 @@ async def gql_async(session: AsyncSession) -> ExecutionResult:
return graphql_


@pytest.fixture()
@pytest.fixture
def query_example(graphql_example: Callable[[str], dict[str, Any]]) -> Callable[[str], dict[str, Any]]:
def query(source: str) -> dict[str, Any]:
return graphql_example(f"query {{\n{indent(source, ' ')}\n}}")

return query


@pytest.fixture()
@pytest.fixture
def mutation_example(graphql_example: Callable[[str], dict[str, Any]]) -> Callable[[str], dict[str, Any]]:
def mutation(source: str) -> dict[str, Any]:
return graphql_example(f"mutation {{\n{indent(source, ' ')}\n}}")
Expand Down
5 changes: 3 additions & 2 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
GraphQLScalarType,
GraphQLString,
)
from graphql_sqlalchemy.graphql_types import get_graphql_type_from_column, get_graphql_type_from_python
from graphql_sqlalchemy.testing import assert_equal_gql_type
from sqlalchemy import ARRAY, Boolean, Column, Float, Integer, String
from sqlalchemy import Enum as SqlaEnum

from graphql_sqlalchemy.graphql_types import get_graphql_type_from_column, get_graphql_type_from_python
from graphql_sqlalchemy.testing import assert_equal_gql_type

if sys.version_info >= (3, 10):
str_or_none = str | None
else:
Expand Down
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from graphql import GraphQLInt, GraphQLString

from graphql_sqlalchemy.testing import assert_equal_gql_type


Expand Down