diff --git a/alembic/versions/2025_05_13_0704-864107b703ae_create_url_checked_for_duplicate_table.py b/alembic/versions/2025_05_13_0704-864107b703ae_create_url_checked_for_duplicate_table.py
new file mode 100644
index 00000000..2719d33c
--- /dev/null
+++ b/alembic/versions/2025_05_13_0704-864107b703ae_create_url_checked_for_duplicate_table.py
@@ -0,0 +1,78 @@
+"""Create url_checked_for_duplicate table
+
+Revision ID: 864107b703ae
+Revises: 9d4002437ebe
+Create Date: 2025-05-13 07:04:22.592396
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+from util.alembic_helpers import switch_enum_type
+
+# revision identifiers, used by Alembic.
+revision: str = '864107b703ae'
+down_revision: Union[str, None] = '9d4002437ebe'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        'url_checked_for_duplicate',
+        sa.Column(
+            'id',
+            sa.Integer(),
+            primary_key=True
+        ),
+        sa.Column(
+            'url_id',
+            sa.Integer(),
+            sa.ForeignKey(
+                'urls.id',
+                ondelete='CASCADE'
+            ),
+            nullable=False
+        ),
+        sa.Column(
+            'created_at',
+            sa.DateTime(),
+            nullable=False,
+            server_default=sa.text('now()')
+        ),
+    )
+
+    switch_enum_type(
+        table_name='tasks',
+        column_name='task_type',
+        enum_name='task_type',
+        new_enum_values=[
+            "HTML",
+            "Relevancy",
+            "Record Type",
+            "Agency Identification",
+            "Misc Metadata",
+            "Submit Approved URLs",
+            "Duplicate Detection"
+        ]
+    )
+
+
+def downgrade() -> None:
+    op.drop_table('url_checked_for_duplicate')
+
+    switch_enum_type(
+        table_name='tasks',
+        column_name='task_type',
+        enum_name='task_type',
+        new_enum_values=[
+            "HTML",
+            "Relevancy",
+            "Record Type",
+            "Agency Identification",
+            "Misc Metadata",
+            "Submit Approved URLs",
+        ]
+    )
diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py
index 5d28f70f..03c652c9 100644
--- a/collector_db/AsyncDatabaseClient.py
+++ b/collector_db/AsyncDatabaseClient.py
@@ -29,7 +29,7 @@
     RootURL, Task, TaskError, LinkTaskURL, Batch, Agency, AutomatedUrlAgencySuggestion, \
     UserUrlAgencySuggestion, AutoRelevantSuggestion, AutoRecordTypeSuggestion, UserRelevantSuggestion, \
     UserRecordTypeSuggestion, ReviewingUserURL, URLOptionalDataSourceMetadata, ConfirmedURLAgency, Duplicate, Log, \
-    BacklogSnapshot, URLDataSource
+    BacklogSnapshot, URLDataSource, URLCheckedForDuplicate
 from collector_manager.enums import URLStatus, CollectorType
 from core.DTOs.AllAnnotationPostInfo import AllAnnotationPostInfo
 from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo
@@ -60,6 +60,7 @@
 from core.DTOs.URLAgencySuggestionInfo import URLAgencySuggestionInfo
 from core.DTOs.task_data_objects.AgencyIdentificationTDO import AgencyIdentificationTDO
 from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO, SubmittedURLInfo
+from core.DTOs.task_data_objects.URLDuplicateTDO import URLDuplicateTDO
 from core.DTOs.task_data_objects.URLMiscellaneousMetadataTDO import URLMiscellaneousMetadataTDO, URLHTMLMetadataInfo
 from core.EnvVarManager import EnvVarManager
 from core.enums import BatchStatus, SuggestionType, RecordType
@@ -2224,4 +2225,48 @@ async def populate_backlog_snapshot(
 
         session.add(snapshot)
 
+    @session_manager
+    async def has_pending_urls_not_checked_for_duplicates(self, session: AsyncSession) -> bool:
+        query = (select(
+            URL.id
+        ).outerjoin(
+            URLCheckedForDuplicate,
+            URL.id == URLCheckedForDuplicate.url_id
+        ).where(
+            URL.outcome == URLStatus.PENDING.value,
+            URLCheckedForDuplicate.id.is_(None)
+        ).limit(1)
+        )
+        raw_result = await session.execute(query)
+        result = raw_result.one_or_none()
+        return result is not None
+
+    @session_manager
+    async def get_pending_urls_not_checked_for_duplicates(self, session: AsyncSession) -> List[URLDuplicateTDO]:
+        query = (select(
+            URL
+        ).outerjoin(
+            URLCheckedForDuplicate,
+            URL.id == URLCheckedForDuplicate.url_id
+        ).where(
+            URL.outcome == URLStatus.PENDING.value,
+            URLCheckedForDuplicate.id.is_(None)
+        ).limit(100)
+        )
+
+        raw_result = await session.execute(query)
+        urls = raw_result.scalars().all()
+        return [URLDuplicateTDO(url=url.url, url_id=url.id) for url in urls]
+
+
+    @session_manager
+    async def mark_all_as_duplicates(self, session: AsyncSession, url_ids: List[int]):
+        query = update(URL).where(URL.id.in_(url_ids)).values(outcome=URLStatus.DUPLICATE.value)
+        await session.execute(query)
+
+    @session_manager
+    async def mark_as_checked_for_duplicates(self, session: AsyncSession, url_ids: list[int]):
+        for url_id in url_ids:
+            url_checked_for_duplicate = URLCheckedForDuplicate(url_id=url_id)
+            session.add(url_checked_for_duplicate)
diff --git a/collector_db/enums.py b/collector_db/enums.py
index b28b6091..d6b3ec0f 100644
--- a/collector_db/enums.py
+++ b/collector_db/enums.py
@@ -38,6 +38,7 @@ class TaskType(PyEnum):
     AGENCY_IDENTIFICATION = "Agency Identification"
     MISC_METADATA = "Misc Metadata"
     SUBMIT_APPROVED = "Submit Approved URLs"
+    DUPLICATE_DETECTION = "Duplicate Detection"
     IDLE = "Idle"
 
 class PGEnum(TypeDecorator):
diff --git a/collector_db/models.py b/collector_db/models.py
index b38243dd..b2a86e9c 100644
--- a/collector_db/models.py
+++ b/collector_db/models.py
@@ -141,7 +141,21 @@ class URL(Base):
         back_populates="url",
         uselist=False
     )
+    checked_for_duplicate = relationship(
+        "URLCheckedForDuplicate",
+        uselist=False,
+        back_populates="url"
+    )
+
+class URLCheckedForDuplicate(Base):
+    __tablename__ = 'url_checked_for_duplicate'
+    id = Column(Integer, primary_key=True)
+    url_id = Column(Integer, ForeignKey('urls.id', ondelete='CASCADE'), nullable=False)
+    created_at = get_created_at_column()
+
+    # Relationships
+    url = relationship("URL", uselist=False, back_populates="checked_for_duplicate")
 
 class URLOptionalDataSourceMetadata(Base):
     __tablename__ = 'url_optional_data_source_metadata'
diff --git a/core/DTOs/task_data_objects/URLDuplicateTDO.py b/core/DTOs/task_data_objects/URLDuplicateTDO.py
new file mode 100644
index 00000000..af00ce38
--- /dev/null
+++ b/core/DTOs/task_data_objects/URLDuplicateTDO.py
@@ -0,0 +1,9 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class URLDuplicateTDO(BaseModel):
+    url_id: int
+    url: str
+    is_duplicate: Optional[bool] = None
diff --git a/core/TaskManager.py b/core/TaskManager.py
index 052bdbc8..1dcc9bb5 100644
--- a/core/TaskManager.py
+++ b/core/TaskManager.py
@@ -1,5 +1,6 @@
 import logging
 
+from core.classes.task_operators.URLDuplicateTaskOperator import URLDuplicateTaskOperator
 from source_collectors.muckrock.MuckrockAPIInterface import MuckrockAPIInterface
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DTOs.TaskInfo import TaskInfo
@@ -96,9 +97,17 @@ async def get_url_miscellaneous_metadata_task_operator(self):
         )
         return operator
 
+    async def get_url_duplicate_task_operator(self):
+        operator = URLDuplicateTaskOperator(
+            adb_client=self.adb_client,
+            pdap_client=self.pdap_client
+        )
+        return operator
+
     async def get_task_operators(self) -> list[TaskOperatorBase]:
         return [
             await self.get_url_html_task_operator(),
+            await self.get_url_duplicate_task_operator(),
             # await self.get_url_relevance_huggingface_task_operator(),
             await self.get_url_record_type_task_operator(),
             await self.get_agency_identification_task_operator(),
diff --git a/core/classes/task_operators/URLDuplicateTaskOperator.py b/core/classes/task_operators/URLDuplicateTaskOperator.py
new file mode 100644
index 00000000..32cea432
--- /dev/null
+++ b/core/classes/task_operators/URLDuplicateTaskOperator.py
@@ -0,0 +1,33 @@
+from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from collector_db.enums import TaskType
+from core.DTOs.task_data_objects.URLDuplicateTDO import URLDuplicateTDO
+from core.classes.task_operators.TaskOperatorBase import TaskOperatorBase
+from pdap_api_client.PDAPClient import PDAPClient
+
+
+class URLDuplicateTaskOperator(TaskOperatorBase):
+
+    def __init__(
+        self,
+        adb_client: AsyncDatabaseClient,
+        pdap_client: PDAPClient
+    ):
+        super().__init__(adb_client)
+        self.pdap_client = pdap_client
+
+    @property
+    def task_type(self):
+        return TaskType.DUPLICATE_DETECTION
+
+    async def meets_task_prerequisites(self):
+        return await self.adb_client.has_pending_urls_not_checked_for_duplicates()
+
+    async def inner_task_logic(self):
+        tdos: list[URLDuplicateTDO] = await self.adb_client.get_pending_urls_not_checked_for_duplicates()
+        url_ids = [tdo.url_id for tdo in tdos]
+        await self.link_urls_to_task(url_ids=url_ids)
+        for tdo in tdos:
+            tdo.is_duplicate = await self.pdap_client.is_url_duplicate(tdo.url)
+        duplicate_url_ids = [tdo.url_id for tdo in tdos if tdo.is_duplicate]
+        await self.adb_client.mark_all_as_duplicates(duplicate_url_ids)
+        await self.adb_client.mark_as_checked_for_duplicates(url_ids)
diff --git a/pdap_api_client/DTOs.py b/pdap_api_client/DTOs.py
index 93f67839..342ad948 100644
--- a/pdap_api_client/DTOs.py
+++ b/pdap_api_client/DTOs.py
@@ -25,7 +25,7 @@ class ApprovalStatus(Enum):
 class UniqueURLDuplicateInfo(BaseModel):
     original_url: str
     approval_status: ApprovalStatus
-    rejection_note: str
+    rejection_note: Optional[str] = None
 
 class UniqueURLResponseInfo(BaseModel):
     is_unique: bool
diff --git a/pdap_api_client/PDAPClient.py b/pdap_api_client/PDAPClient.py
index 24b9d98c..ad3c74ea 100644
--- a/pdap_api_client/PDAPClient.py
+++ b/pdap_api_client/PDAPClient.py
@@ -59,10 +59,10 @@ async def match_agency(
         )
 
-    async def is_url_unique(
+    async def is_url_duplicate(
        self,
         url_to_check: str
-    ) -> UniqueURLResponseInfo:
+    ) -> bool:
         """
-        Check if a URL is unique.
-        Returns duplicate info otherwise
+        Check whether a URL is a duplicate of one already in the database.
+        Returns True if at least one duplicate is found.
         """
@@ -79,11 +79,8 @@ async def is_url_unique(
         )
         response_info = await self.access_manager.make_request(request_info)
         duplicates = [UniqueURLDuplicateInfo(**entry) for entry in response_info.data["duplicates"]]
-        is_unique = (len(duplicates) == 0)
-        return UniqueURLResponseInfo(
-            is_unique=is_unique,
-            duplicates=duplicates
-        )
+        is_duplicate = (len(duplicates) != 0)
+        return is_duplicate
 
     async def submit_urls(
         self,
diff --git a/tests/test_automated/integration/tasks/conftest.py b/tests/test_automated/integration/tasks/conftest.py
new file mode 100644
index 00000000..6a925cc5
--- /dev/null
+++ b/tests/test_automated/integration/tasks/conftest.py
@@ -0,0 +1,20 @@
+from unittest.mock import MagicMock, AsyncMock
+
+import pytest
+
+from pdap_api_client.AccessManager import AccessManager
+from pdap_api_client.PDAPClient import PDAPClient
+
+
+@pytest.fixture
+def mock_pdap_client() -> PDAPClient:
+    mock_access_manager = MagicMock(
+        spec=AccessManager
+    )
+    mock_access_manager.jwt_header = AsyncMock(
+        return_value={"Authorization": "Bearer token"}
+    )
+    pdap_client = PDAPClient(
+        access_manager=mock_access_manager
+    )
+    return pdap_client
\ No newline at end of file
diff --git a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
index 32dc765c..c8aa86eb 100644
--- a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
+++ b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
@@ -46,18 +46,7 @@ def mock_make_request(pdap_client: PDAPClient, urls: list[str]):
         )
     )
 
-@pytest.fixture
-def mock_pdap_client() -> PDAPClient:
-    mock_access_manager = MagicMock(
-        spec=AccessManager
-    )
-    mock_access_manager.jwt_header = AsyncMock(
-        return_value={"Authorization": "Bearer token"}
-    )
-    pdap_client = PDAPClient(
-        access_manager=mock_access_manager
-    )
-    return pdap_client
+
 
 async def setup_validated_urls(db_data_creator: DBDataCreator) -> list[str]:
     creation_info: BatchURLCreationInfo = await db_data_creator.batch_and_urls(
diff --git a/tests/test_automated/integration/tasks/test_url_duplicate_task.py b/tests/test_automated/integration/tasks/test_url_duplicate_task.py
new file mode 100644
index 00000000..1b3e77d8
--- /dev/null
+++ b/tests/test_automated/integration/tasks/test_url_duplicate_task.py
@@ -0,0 +1,98 @@
+from http import HTTPStatus
+from unittest.mock import MagicMock
+
+import pytest
+
+from collector_db.DTOs.URLMapping import URLMapping
+from collector_db.models import URL, URLCheckedForDuplicate
+from collector_manager.enums import CollectorType, URLStatus
+from core.DTOs.TaskOperatorRunInfo import TaskOperatorOutcome
+from core.classes.task_operators.URLDuplicateTaskOperator import URLDuplicateTaskOperator
+from tests.helpers.DBDataCreator import DBDataCreator
+from tests.helpers.test_batch_creation_parameters import TestBatchCreationParameters, TestURLCreationParameters
+from pdap_api_client.DTOs import ResponseInfo
+from pdap_api_client.PDAPClient import PDAPClient
+
+
+@pytest.mark.asyncio
+async def test_url_duplicate_task(
+    db_data_creator: DBDataCreator,
+    mock_pdap_client: PDAPClient
+):
+
+    operator = URLDuplicateTaskOperator(
+        adb_client=db_data_creator.adb_client,
+        pdap_client=mock_pdap_client
+    )
+
+    assert not await operator.meets_task_prerequisites()
+    make_request_mock: MagicMock = mock_pdap_client.access_manager.make_request
+
+    make_request_mock.assert_not_called()
+
+    # Add three URLs to the database: one in an error state, the other two pending
+    creation_info = await db_data_creator.batch_v2(
+        parameters=TestBatchCreationParameters(
+            urls=[
+                TestURLCreationParameters(
+                    count=1,
+                    status=URLStatus.ERROR
+                ),
+                TestURLCreationParameters(
+                    count=2,
+                    status=URLStatus.PENDING
+                ),
+            ]
+        )
+    )
+    pending_urls: list[URLMapping] = creation_info.url_creation_infos[URLStatus.PENDING].url_mappings
+    duplicate_url = pending_urls[0]
+    non_duplicate_url = pending_urls[1]
+    assert await operator.meets_task_prerequisites()
+    make_request_mock.assert_not_called()
+
+    make_request_mock.side_effect = [
+        ResponseInfo(
+            data={
+                "duplicates": [
+                    {
+                        "original_url": duplicate_url.url,
+                        "approval_status": "approved"
+                    }
+                ],
+            },
+            status_code=HTTPStatus.OK
+        ),
+        ResponseInfo(
+            data={
+                "duplicates": [],
+            },
+            status_code=HTTPStatus.OK
+        ),
+    ]
+    run_info = await operator.run_task(1)
+    assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message
+    assert make_request_mock.call_count == 2
+
+    adb_client = db_data_creator.adb_client
+    urls: list[URL] = await adb_client.get_all(URL)
+    assert len(urls) == 3
+    url_ids = [url.id for url in urls]
+    assert duplicate_url.url_id in url_ids
+    for url in urls:
+        if url.id == duplicate_url.url_id:
+            assert url.outcome == URLStatus.DUPLICATE.value
+
+    checked_for_duplicates: list[URLCheckedForDuplicate] = await adb_client.get_all(URLCheckedForDuplicate)
+    assert len(checked_for_duplicates) == 2
+    checked_for_duplicate_url_ids = [url.url_id for url in checked_for_duplicates]
+    assert duplicate_url.url_id in checked_for_duplicate_url_ids
+    assert non_duplicate_url.url_id in checked_for_duplicate_url_ids
+
+    assert not await operator.meets_task_prerequisites()
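
A minimal driver sketch of how the pieces introduced in this diff fit together at runtime. It assumes an already-constructed AsyncDatabaseClient and PDAPClient (their setup is outside this diff), and it mirrors the run_task(1) call used in the test above; treat it as an illustration, not part of the change.

    from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
    from core.classes.task_operators.URLDuplicateTaskOperator import URLDuplicateTaskOperator
    from pdap_api_client.PDAPClient import PDAPClient


    async def run_duplicate_detection(
        adb_client: AsyncDatabaseClient,
        pdap_client: PDAPClient
    ) -> None:
        # Same wiring TaskManager.get_url_duplicate_task_operator() performs
        operator = URLDuplicateTaskOperator(
            adb_client=adb_client,
            pdap_client=pdap_client
        )
        # Cheap existence probe: a LIMIT 1 anti-join against url_checked_for_duplicate
        while await operator.meets_task_prerequisites():
            # Each run claims up to 100 pending URLs, asks the PDAP API about each,
            # marks confirmed duplicates, and records every URL as checked so it is
            # never re-queried; looping drains the pending backlog in batches.
            run_info = await operator.run_task(1)
            print(run_info.outcome, run_info.message)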