Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions fastcrud/core/config/crud_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Any, Callable, Sequence, Optional, Annotated
from pydantic import BaseModel, Field
from pydantic.functional_validators import field_validator
from fastapi import Depends, Query

from ...types import ModelType

Expand Down Expand Up @@ -336,21 +335,6 @@ def __init__(self, **kwargs: Any) -> None:
filters.update(kwargs)
super().__init__(filters=filters)

def get_params(self) -> dict[str, Any]:
"""
Get FastAPI parameter definitions for the configured filters.

Returns:
Dictionary mapping parameter names to FastAPI parameter objects.
"""
params = {}
for key, value in self.filters.items():
if callable(value):
params[key] = Depends(value)
else:
params[key] = Query(value)
return params

def is_joined_filter(self, filter_key: str) -> bool:
"""
Check if a filter key represents a joined model filter (contains dot notation).
Expand Down
11 changes: 10 additions & 1 deletion fastcrud/core/filtering/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

FilterCallable = Callable[[Column[Any]], Callable[..., ColumnElement[bool]]]

COLLECTION_OPERATORS = {"in", "not_in", "between"}

SUPPORTED_FILTERS: dict[str, FilterCallable] = {
"eq": lambda column: column.__eq__,
"gt": lambda column: column.__gt__,
Expand Down Expand Up @@ -64,7 +66,7 @@ def get_sqlalchemy_filter(
>>> # This will raise ValueError
>>> get_sqlalchemy_filter('in', 'invalid') # Should be list/tuple/set
"""
if operator in {"in", "not_in", "between"}:
if operator in COLLECTION_OPERATORS:
if not isinstance(value, (tuple, list, set)):
raise ValueError(f"<{operator}> filter must be tuple, list or set")

Expand All @@ -76,3 +78,10 @@ def get_sqlalchemy_filter(
raise ValueError("Between operator requires exactly 2 values")

return SUPPORTED_FILTERS.get(operator)


def get_operator_wrap_type(operator: str) -> Optional[type]:
if operator in COLLECTION_OPERATORS:
return list

return None
76 changes: 74 additions & 2 deletions fastcrud/core/filtering/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,38 @@
"""

from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, SkipValidation, computed_field
from sqlalchemy import Column, or_, not_, and_
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.elements import ColumnElement

from ..introspection import get_model_column
from ..config.crud_configs import FilterConfig
from ..introspection import get_column_types, get_model_column
from ...types import ModelType, FilterValueType
from .operators import get_sqlalchemy_filter
from .operators import get_operator_wrap_type, get_sqlalchemy_filter
from .validators import validate_joined_filter_format


class Filter(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

definition: str
param_name: str
default_value: Any
operator: Optional[str]
wrap_type: Optional[type]
joined_model: Optional[type]
column: SkipValidation[Column[Any]]
value_type: type

@computed_field
def type(self) -> type:
if self.wrap_type:
return self.wrap_type[self.value_type]

return self.value_type


class FilterProcessor:
"""
Processes filter arguments into SQLAlchemy filter conditions.
Expand Down Expand Up @@ -349,6 +371,56 @@ def _handle_joined_filter(
else:
return self._handle_standard_filter(target_column, operator, value)

def interpret_filters(self, filter_config: FilterConfig) -> list[Filter]:
filters: list[Filter] = []

for filter_definition, default_value in filter_config.filters.items():
field_and_operator = filter_definition.rsplit("__", 1)
field_definition = field_and_operator[0]

param_name = filter_definition.replace(".", "_")

operator = field_and_operator[1] if len(field_and_operator) > 1 else None
wrap_type = get_operator_wrap_type(operator) if operator else None

is_joined_model = filter_config.is_joined_filter(filter_definition)

if is_joined_model:
validate_joined_filter_format(filter_definition)

relationship_name, column_name = field_definition.split(".", 1)
relationship_column = get_model_column(self.model, relationship_name)

if not hasattr(relationship_column.property, "mapper"):
raise ValueError(
f"Invalid relationship '{relationship_name}' in model '{self.model.__name__}'"
)

joined_model = relationship_column.property.mapper.class_
model_column_types = dict(get_column_types(joined_model))
column = get_model_column(joined_model, column_name)
value_type = model_column_types[column_name]
else:
joined_model = None
column_name = field_definition
column = get_model_column(self.model, column_name)
value_type = dict(get_column_types(self.model))[column_name]

filters.append(
Filter(
definition=filter_definition,
param_name=param_name,
default_value=default_value,
operator=operator,
wrap_type=wrap_type,
joined_model=joined_model,
value_type=value_type,
column=column,
)
)

return filters

def separate_joined_filters(
self, **kwargs: Any
) -> tuple[dict[str, Any], dict[str, Any]]:
Expand Down
2 changes: 1 addition & 1 deletion fastcrud/endpoint/endpoint_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def _read_items(
The query parameters are encapsulated in PaginatedRequestQuery schema,
which can be reused in custom endpoints.
"""
dynamic_filters = create_dynamic_filters(self.filter_config, self.column_types)
dynamic_filters = create_dynamic_filters(self.filter_config, self.model)

async def endpoint(
db: AsyncSession = Depends(self.session),
Expand Down
51 changes: 22 additions & 29 deletions fastcrud/fastapi_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

from fastapi import Depends, Query, Path, params

from fastcrud.core.filtering.processor import FilterProcessor
from fastcrud.types import ModelType

if TYPE_CHECKING:
from .core.config import CreateConfig, UpdateConfig, DeleteConfig, FilterConfig

Expand Down Expand Up @@ -81,7 +84,8 @@ def auto_fields_resolver(**kwargs: Any) -> dict[str, Any]:


def create_dynamic_filters(
filter_config: Optional["FilterConfig"], column_types: dict[str, type]
filter_config: Optional["FilterConfig"],
model: ModelType,
) -> Callable[..., dict[str, Any]]:
"""
Create dynamic filter function for handling query parameters.
Expand All @@ -105,54 +109,43 @@ def create_dynamic_filters(
if filter_config is None:
return lambda: {}

param_to_filter_key = {}
for original_key in filter_config.filters.keys():
param_name = original_key.replace(".", "_")
param_to_filter_key[param_name] = original_key
filter_processor = FilterProcessor(model)
filters = filter_processor.interpret_filters(filter_config)

def filters(
def dependency_function(
**kwargs: Any,
) -> dict[str, Any]:
filtered_params = {}
for param_name, value in kwargs.items():
if value is not None:
original_key = param_to_filter_key.get(param_name, param_name)
key_without_op = original_key.rsplit("__", 1)[0]
parse_func = column_types.get(key_without_op)
if parse_func:
try:
filtered_params[original_key] = parse_func(value)
except (ValueError, TypeError):
filtered_params[original_key] = value
else:
filtered_params[original_key] = value
return filtered_params
return {
filter.definition: filter.default_value
for filter in filters
if filter.default_value
}

params = []
for key, value in filter_config.filters.items():
param_name = key.replace(".", "_")

if callable(value):
for filter in filters:
if callable(filter.default_value):
params.append(
inspect.Parameter(
param_name,
filter.param_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Depends(value),
default=Depends(filter.default_value),
)
)
else:
params.append(
inspect.Parameter(
param_name,
filter.param_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Query(value, alias=key),
annotation=filter.type,
default=Query(filter.default_value, alias=filter.definition),
)
)

sig = inspect.Signature(params)
setattr(filters, "__signature__", sig)
setattr(dependency_function, "__signature__", sig)

return filters
return dependency_function


def inject_dependencies(
Expand Down
15 changes: 15 additions & 0 deletions tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest_asyncio
from sqlalchemy import (
Column,
Float,
Integer,
String,
ForeignKey,
Expand Down Expand Up @@ -213,6 +214,15 @@ class ModelWithOrgTest(Base):
deleted_at = Column(DateTime, nullable=True, default=None)


class ModelWithTypes(Base):
__tablename__ = "test_with_types"
id = Column(Integer, primary_key=True)
str_param = Column(String(32))
int_param = Column(Integer)
float_param = Column(Float)
bool_param = Column(Boolean)


# Models for testing joined model filtering
class Company(Base):
__tablename__ = "company"
Expand Down Expand Up @@ -551,6 +561,11 @@ def test_model_with_org():
return ModelWithOrgTest


@pytest.fixture
def test_model_with_types():
return ModelWithTypes


async def test_read_dep():
pass

Expand Down
47 changes: 47 additions & 0 deletions tests/sqlalchemy/endpoint/test_filter_param_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pydantic import BaseModel
import pytest
from fastapi import FastAPI

from fastcrud import crud_router, FilterConfig


class NullSchema(BaseModel):
pass


@pytest.fixture
def app_with_filter_params(test_model_with_types, async_session):
app = FastAPI()

app.include_router(
crud_router(
session=lambda: async_session,
model=test_model_with_types,
create_schema=NullSchema,
update_schema=NullSchema,
filter_config=FilterConfig(
int_param=None, float_param=None, str_param=None, bool_param=None
),
path="/test",
tags=["test"],
)
)

return app


@pytest.mark.asyncio
async def test_dependency_filtered_endpoint(app_with_filter_params):
"""Test that filter query parameters are correctly typed in the OpenAPI schema."""

schema = app_with_filter_params.openapi()

def get_type(param_name: str):
params = schema["paths"]["/test"]["get"]["parameters"]
param = next(item for item in params if item["name"] == param_name)
return param["schema"]["type"]

assert get_type("int_param") == "integer"
assert get_type("float_param") == "number"
assert get_type("str_param") == "string"
assert get_type("bool_param") == "boolean"
Loading