diff --git a/backend/timed/reports/views.py b/backend/timed/reports/views.py index 1c01110d4..4032232b7 100644 --- a/backend/timed/reports/views.py +++ b/backend/timed/reports/views.py @@ -35,7 +35,52 @@ from timed.employment.models import User -class YearStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): +class BaseStatisticQuerysetMixin: + """Base statistics queryset mixin. + + Build and filter the statistics queryset according to the following + principles: + + 0) For every statistic view (year, month, customer, project, task, user) + we use the same basic queryset and the same filterset. + 1) Build up a full queryset with annotations and everything we need + from a *task* perspective. + 2) Filter the queryset in the sxact same way in all the viewsets. + 3) Annotate the queryset in the viewset, according to their needs. This will + also cause the GROUP BY to happen as needed. + + For this to work, each viewset defines two properties: + + * The `qs_fields` define which fields are to be selected + * The `pk_field` is an expression that will be used as a primary key in the + REST sense (not really related to the database primary key, but serves as + a row identifier) + + And because we use the report queryset as our base, we can easily reuse + the report filterset as well. + """ + + def get_queryset(self): + return ( + Report.objects.all() + .select_related("user", "task", "task__project", "task__project__customer") + .annotate(year=ExtractYear("date")) + .annotate(month=ExtractYear("date") * 100 + ExtractMonth("date")) + ) + + def filter_queryset(self, queryset): + queryset = super().filter_queryset(queryset) + if isinstance(self.qs_fields, dict): + # qs fields need to be aliased + queryset = queryset.annotate(**self.qs_fields) + + queryset = queryset.values(*list(self.qs_fields)) + queryset = queryset.annotate(duration=Sum("duration")) + queryset = queryset.annotate(pk=F(self.pk_field)) + return queryset + + +class YearStatisticViewSet(BaseStatisticQuerysetMixin, ReadOnlyModelViewSet): """Year statistics calculates total reported time per year.""" serializer_class = serializers.YearStatisticSerializer @@ -52,14 +97,11 @@ class YearStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): ), ) - def get_queryset(self): - queryset = Report.objects.all() - queryset = queryset.annotate(year=ExtractYear("date")).values("year") - queryset = queryset.annotate(duration=Sum("duration")) - return queryset.annotate(pk=F("year")) + qs_fields = ("year", "duration") + pk_field = "year" -class MonthStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): +class MonthStatisticViewSet(BaseStatisticQuerysetMixin, ReadOnlyModelViewSet): """Month statistics calculates total reported time per month.""" serializer_class = serializers.MonthStatisticSerializer @@ -80,89 +122,17 @@ class MonthStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): ), ) - def get_queryset(self): - queryset = Report.objects.all() - queryset = queryset.annotate( - year=ExtractYear("date"), month=ExtractMonth("date") - ) - queryset = queryset.values("year", "month") - queryset = queryset.annotate(duration=Sum("duration")) - return queryset.annotate(pk=F("year") * 100 + F("month")) - - -class StatisticQueryset(QuerySet): - def __init__(self, catch_prefixes, *args, base_qs=None, agg_filters=None, **kwargs): - super().__init__(*args, **kwargs) - if base_qs is None: - base_qs = self.model.objects.all() - self._base = base_qs - self._agg_filters = agg_filters - self._catch_prefixes = catch_prefixes - - def filter(self, /, **kwargs): - my_filters = { - k: v for k, v in kwargs.items() if not k.startswith(self._catch_prefixes) - } - - agg_filters = { - k: v for k, v in kwargs.items() if k.startswith(self._catch_prefixes) - } - - new_qs = self - if my_filters: - new_qs = self.filter_base(**my_filters) - if agg_filters: - new_qs = new_qs.filter_aggregate(**agg_filters) - - return new_qs - - def filter_base(self, *args, **kwargs): - filtered = ( - self.model.objects.filter(*args, **kwargs) - .values("pk") - .filter(pk=OuterRef("pk")) - ) - return StatisticQueryset( - model=self.model, - base_qs=self._base.filter(Exists(filtered)), - catch_prefixes=self._catch_prefixes, - agg_filters=self._agg_filters, - ) - - def _clone(self): - return StatisticQueryset( - model=self.model, - base_qs=self._base._clone(), # noqa: SLF001 - catch_prefixes=self._catch_prefixes, - agg_filters=self._agg_filters, - ) - - def __str__(self) -> str: - return f"StatisticQueryset({self._base!s} | {self._agg_filters!s})" - - def __repr__(self) -> str: - return f"StatisticQueryset({self._base!r} | {self._agg_filters!r})" + qs_fields = ("year", "month", "duration") + pk_field = "month" - def filter_aggregate(self, *args, **kwargs): - filter_q = Q(*args, **kwargs) - new_filters = self._agg_filters & filter_q if self._agg_filters else filter_q - - return StatisticQueryset( - model=self.model, - base_qs=self._base, - catch_prefixes=self._catch_prefixes, - agg_filters=new_filters, - ) - - -class CustomerStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): +class CustomerStatisticViewSet(BaseStatisticQuerysetMixin, ReadOnlyModelViewSet): """Customer statistics calculates total reported time per customer.""" serializer_class = serializers.CustomerStatisticSerializer - filterset_class = filters.CustomerStatisticFilterSet + filterset_class = ReportFilterSet ordering_fields = ( - "name", + "task__project__customer__name", "duration", "estimated_time", "remaining_effort", @@ -175,23 +145,27 @@ class CustomerStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): (IsInternal | IsSuperUser) & IsAuthenticated ), ) - - def get_queryset(self): - return StatisticQueryset(model=Customer, catch_prefixes="projects__") + qs_fields = { # noqa: RUF012 + "year": F("year"), + "month": F("month"), + "name": F("task__project__customer__name"), + "customer_id": F("task__project__customer_id"), + } + pk_field = "customer_id" class ProjectStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): """Project statistics calculates total reported time per project.""" serializer_class = serializers.ProjectStatisticSerializer - filterset_class = filters.ProjectStatisticFilterSet + filterset_class = ReportFilterSet ordering_fields = ( - "name", + "task__project__name", "duration", "estimated_time", "remaining_effort", ) - ordering = ("name",) + ordering = ("task__project__name",) permission_classes = ( ( # internal employees or super users may read all customer statistics @@ -199,22 +173,27 @@ class ProjectStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): ), ) - def get_queryset(self): - return StatisticQueryset(model=Project, catch_prefixes="tasks__") + qs_fields = { # noqa: RUF012 + "year": F("year"), + "month": F("month"), + "name": F("task__project__name"), + "project_id": F("task__project_id"), + } + pk_field = "project_id" -class TaskStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): +class TaskStatisticViewSet(BaseStatisticQuerysetMixin, ReadOnlyModelViewSet): """Task statistics calculates total reported time per task.""" serializer_class = serializers.TaskStatisticSerializer - filterset_class = filters.TaskStatisticFilterSet + filterset_class = ReportFilterSet ordering_fields = ( - "name", + "task__name", "duration", "estimated_time", "remaining_effort", ) - ordering = ("name",) + ordering = ("task__name",) permission_classes = ( ( # internal employees or super users may read all customer statistics @@ -222,11 +201,15 @@ class TaskStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): ), ) - def get_queryset(self): - return StatisticQueryset(model=Task, catch_prefixes="tasks__") + qs_fields = { # noqa: RUF012 + "year": F("year"), + "month": F("month"), + "name": F("task__name"), + } + pk_field = "task_id" -class UserStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): +class UserStatisticViewSet(BaseStatisticQuerysetMixin, ReadOnlyModelViewSet): """User calculates total reported time per user.""" serializer_class = serializers.UserStatisticSerializer @@ -243,11 +226,7 @@ class UserStatisticViewSet(AggregateQuerysetMixin, ReadOnlyModelViewSet): ), ) - def get_queryset(self): - queryset = Report.objects.all() - queryset = queryset.values("user") - queryset = queryset.annotate(duration=Sum("duration")) - return queryset.annotate(pk=F("user")) + pk_field = "user" class WorkReportViewSet(GenericViewSet):