Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions todo/constants/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class ValidationErrors:
MISSING_EMAIL = "Email is required"
MISSING_NAME = "Name is required"
MISSING_PICTURE = "Picture is required"
TEAM_ID_REQUIRED_FOR_ASSIGNEE_FILTER = "teamId is required when filtering by assigneeId."
SEARCH_QUERY_EMPTY = "Search query cannot be empty"
TASK_ID_STRING_REQUIRED = "Task ID must be a string."
INVALID_IS_ACTIVE_VALUE = "Invalid value for is_active"
Expand Down
118 changes: 105 additions & 13 deletions todo/repositories/task_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,51 @@ def _get_team_task_ids(cls, team_id: str) -> List[ObjectId]:
team_task_ids = [ObjectId(task["task_id"]) for task in team_tasks]
return list(set(team_task_ids))

@classmethod
def _get_task_ids_for_assignees(cls, assignee_ids: List[str], team_id: str | None = None) -> List[ObjectId]:
"""
Resolve active task IDs for the provided assignee IDs, optionally scoped to a team.
"""
if not assignee_ids:
return []

assignee_lookup_values = set()
for assignee_id in assignee_ids:
assignee_lookup_values.add(assignee_id)
if ObjectId.is_valid(assignee_id):
assignee_lookup_values.add(ObjectId(assignee_id))

if not assignee_lookup_values:
return []

assignment_collection = TaskAssignmentRepository.get_collection()
assignment_filter: dict = {
"assignee_id": {"$in": list(assignee_lookup_values)},
"user_type": "user",
"is_active": True,
}

if team_id:
team_candidates = {team_id}
if ObjectId.is_valid(team_id):
team_candidates.add(ObjectId(team_id))
assignment_filter["team_id"] = {"$in": list(team_candidates)}

assignments = assignment_collection.find(
assignment_filter,
{"task_id": 1},
)

task_ids: set[ObjectId] = set()
for assignment in assignments:
task_identifier = assignment.get("task_id")
if isinstance(task_identifier, ObjectId):
task_ids.add(task_identifier)
elif isinstance(task_identifier, str) and ObjectId.is_valid(task_identifier):
task_ids.add(ObjectId(task_identifier))

return list(task_ids)

@classmethod
def _build_status_filter(cls, status_filter: str = None) -> dict:
now = datetime.now(timezone.utc)
Expand Down Expand Up @@ -72,19 +117,41 @@ def list(
user_id: str = None,
team_id: str = None,
status_filter: str = None,
assignee_ids: List[str] | None = None,
) -> List[TaskModel]:
tasks_collection = cls.get_collection()

base_filter = cls._build_status_filter(status_filter)

if team_id:
filters = [base_filter]

team_scope_applied = False

if assignee_ids:
assignee_task_ids = cls._get_task_ids_for_assignees(assignee_ids, team_id=team_id)
if not assignee_task_ids:
return []
filters.append({"_id": {"$in": assignee_task_ids}})
if team_id:
team_scope_applied = True
elif team_id:
all_team_task_ids = cls._get_team_task_ids(team_id)
query_filter = {"$and": [base_filter, {"_id": {"$in": all_team_task_ids}}]}
elif user_id:
if not all_team_task_ids:
return []
filters.append({"_id": {"$in": all_team_task_ids}})
team_scope_applied = True

if user_id and not team_scope_applied:
assigned_task_ids = cls._get_assigned_task_ids_for_user(user_id)
query_filter = {"$and": [base_filter, {"_id": {"$in": assigned_task_ids}}]}
user_filters = [{"createdBy": user_id}]
if assigned_task_ids:
user_filters.append({"_id": {"$in": assigned_task_ids}})
filters.append({"$or": user_filters})

if len(filters) == 1:
query_filter = filters[0]
else:
query_filter = base_filter
query_filter = {"$and": filters}

if sort_by == SORT_FIELD_UPDATED_AT:
sort_direction = -1 if order == SORT_ORDER_DESC else 1
Expand Down Expand Up @@ -149,22 +216,47 @@ def _get_assigned_task_ids_for_user(cls, user_id: str) -> List[ObjectId]:
return direct_task_ids + team_task_ids

@classmethod
def count(cls, user_id: str = None, team_id: str = None, status_filter: str = None) -> int:
def count(
cls,
user_id: str = None,
team_id: str = None,
status_filter: str = None,
assignee_ids: List[str] | None = None,
) -> int:
tasks_collection = cls.get_collection()

base_filter = cls._build_status_filter(status_filter)

if team_id:
filters = [base_filter]

team_scope_applied = False

if assignee_ids:
assignee_task_ids = cls._get_task_ids_for_assignees(assignee_ids, team_id=team_id)
if not assignee_task_ids:
return 0
filters.append({"_id": {"$in": assignee_task_ids}})
if team_id:
team_scope_applied = True
elif team_id:
all_team_task_ids = cls._get_team_task_ids(team_id)
query_filter = {"$and": [base_filter, {"_id": {"$in": all_team_task_ids}}]}
if not all_team_task_ids:
return 0
filters.append({"_id": {"$in": all_team_task_ids}})
team_scope_applied = True

elif user_id:
if user_id and not team_scope_applied:
assigned_task_ids = cls._get_assigned_task_ids_for_user(user_id)
query_filter = {
"$and": [base_filter, {"$or": [{"createdBy": user_id}, {"_id": {"$in": assigned_task_ids}}]}]
}
user_filters = [{"createdBy": user_id}]
if assigned_task_ids:
user_filters.append({"_id": {"$in": assigned_task_ids}})
filters.append({"$or": user_filters})

if len(filters) == 1:
query_filter = filters[0]
else:
query_filter = base_filter
query_filter = {"$and": filters}

return tasks_collection.count_documents(query_filter)

@classmethod
Expand Down
33 changes: 32 additions & 1 deletion todo/serializers/get_tasks_serializer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from rest_framework import serializers
from django.conf import settings
from rest_framework import serializers
from bson import ObjectId

from todo.constants.task import SORT_FIELDS, SORT_ORDERS, SORT_FIELD_UPDATED_AT, SORT_FIELD_DEFAULT_ORDERS, TaskStatus
from todo.constants.messages import ValidationErrors


class CaseInsensitiveChoiceField(serializers.ChoiceField):
Expand All @@ -11,6 +13,19 @@ def to_internal_value(self, data):
return super().to_internal_value(data)


class QueryParameterListField(serializers.ListField):
"""
DRF list field that understands QueryDict inputs with repeated parameters.
"""

def get_value(self, dictionary):
if hasattr(dictionary, "getlist") and self.field_name in dictionary:
values = dictionary.getlist(self.field_name)
if values:
return values
return super().get_value(dictionary)


class GetTaskQueryParamsSerializer(serializers.Serializer):
page = serializers.IntegerField(
required=False,
Expand Down Expand Up @@ -44,6 +59,11 @@ class GetTaskQueryParamsSerializer(serializers.Serializer):

teamId = serializers.CharField(required=False, allow_blank=False, allow_null=True)

assigneeId = QueryParameterListField(
child=serializers.CharField(allow_blank=False),
required=False,
)

status = CaseInsensitiveChoiceField(
choices=[status.value for status in TaskStatus],
required=False,
Expand All @@ -57,4 +77,15 @@ def validate(self, attrs):
sort_by = validated_data.get("sort_by", SORT_FIELD_UPDATED_AT)
validated_data["order"] = SORT_FIELD_DEFAULT_ORDERS[sort_by]

assignee_ids = validated_data.pop("assigneeId", None)
if assignee_ids is not None:
normalized_ids = list(dict.fromkeys(assignee_ids))
invalid_ids = [assignee_id for assignee_id in normalized_ids if not ObjectId.is_valid(assignee_id)]
if invalid_ids:
raise serializers.ValidationError(
{"assigneeId": [ValidationErrors.INVALID_OBJECT_ID.format(invalid_ids[0])]}
)

validated_data["assignee_ids"] = normalized_ids

return validated_data
14 changes: 12 additions & 2 deletions todo/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_tasks(
user_id: str,
team_id: str = None,
status_filter: str = None,
assignee_ids: List[str] | None = None,
) -> GetTasksResponse:
try:
cls._validate_pagination_params(page, limit)
Expand All @@ -93,9 +94,18 @@ def get_tasks(
)

tasks = TaskRepository.list(
page, limit, sort_by, order, user_id, team_id=team_id, status_filter=status_filter
page,
limit,
sort_by,
order,
user_id,
team_id=team_id,
status_filter=status_filter,
assignee_ids=assignee_ids,
)
total_count = TaskRepository.count(
user_id, team_id=team_id, status_filter=status_filter, assignee_ids=assignee_ids
)
total_count = TaskRepository.count(user_id, team_id=team_id, status_filter=status_filter)

if not tasks:
return GetTasksResponse(tasks=[], links=None)
Expand Down
54 changes: 48 additions & 6 deletions todo/tests/integration/test_task_sorting_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ def test_priority_sorting_integration(self, mock_list, mock_count):

self.assertEqual(response.status_code, status.HTTP_200_OK)
mock_list.assert_called_with(
1, 20, SORT_FIELD_PRIORITY, SORT_ORDER_DESC, str(self.user_id), team_id=None, status_filter=None
1,
20,
SORT_FIELD_PRIORITY,
SORT_ORDER_DESC,
str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

@patch("todo.repositories.task_repository.TaskRepository.count")
Expand All @@ -40,7 +47,14 @@ def test_due_at_default_order_integration(self, mock_list, mock_count):
self.assertEqual(response.status_code, status.HTTP_200_OK)

mock_list.assert_called_with(
1, 20, SORT_FIELD_DUE_AT, SORT_ORDER_ASC, str(self.user_id), team_id=None, status_filter=None
1,
20,
SORT_FIELD_DUE_AT,
SORT_ORDER_ASC,
str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

@patch("todo.repositories.task_repository.TaskRepository.count")
Expand All @@ -55,7 +69,14 @@ def test_assignee_sorting_uses_aggregation(self, mock_list, mock_count):

# Assignee sorting now falls back to createdAt sorting
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_ASSIGNEE, SORT_ORDER_ASC, str(self.user_id), team_id=None, status_filter=None
1,
20,
SORT_FIELD_ASSIGNEE,
SORT_ORDER_ASC,
str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

@patch("todo.repositories.task_repository.TaskRepository.count")
Expand All @@ -81,7 +102,14 @@ def test_field_specific_defaults_integration(self, mock_list, mock_count):

self.assertEqual(response.status_code, status.HTTP_200_OK)
mock_list.assert_called_with(
1, 20, sort_field, expected_order, str(self.user_id), team_id=None, status_filter=None
1,
20,
sort_field,
expected_order,
str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

@patch("todo.repositories.task_repository.TaskRepository.count")
Expand All @@ -95,7 +123,14 @@ def test_pagination_with_sorting_integration(self, mock_list, mock_count):
self.assertEqual(response.status_code, status.HTTP_200_OK)

mock_list.assert_called_with(
3, 5, SORT_FIELD_CREATED_AT, SORT_ORDER_ASC, str(self.user_id), team_id=None, status_filter=None
3,
5,
SORT_FIELD_CREATED_AT,
SORT_ORDER_ASC,
str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

def test_invalid_sort_parameters_integration(self):
Expand All @@ -116,7 +151,14 @@ def test_default_behavior_integration(self, mock_list, mock_count):
self.assertEqual(response.status_code, status.HTTP_200_OK)

mock_list.assert_called_with(
1, 20, SORT_FIELD_UPDATED_AT, SORT_ORDER_DESC, str(self.user_id), team_id=None, status_filter=None
1,
20,
SORT_FIELD_UPDATED_AT,
SORT_ORDER_DESC,
str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

@patch("todo.repositories.user_repository.UserRepository.get_by_id")
Expand Down
2 changes: 2 additions & 0 deletions todo/tests/integration/test_tasks_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_pagination_settings_integration(self, mock_get_tasks):
user_id=str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

mock_get_tasks.reset_mock()
Expand All @@ -43,6 +44,7 @@ def test_pagination_settings_integration(self, mock_get_tasks):
user_id=str(self.user_id),
team_id=None,
status_filter=None,
assignee_ids=None,
)

# Verify API rejects values above max limit
Expand Down
Loading