11from typing import Sequence
22
3- from sqlalchemy import func , select , RowMapping
3+ from sqlalchemy import func , select , RowMapping , or_ , and_
44from sqlalchemy .ext .asyncio import AsyncSession
55
6- from src .api .endpoints .annotate .all .get .models .agency import AgencyAnnotationAutoSuggestion , \
7- AgencyAnnotationUserSuggestion
8- from src .api .endpoints .annotate .all .get .queries .agency .suggestions_with_highest_confidence import \
9- SuggestionsWithHighestConfidenceCTE
6+ from src .api .endpoints .annotate .all .get .models .agency import AgencyAnnotationSuggestion
7+ from src .db .helpers .query import exists_url
108from src .db .helpers .session import session_helper as sh
119from src .db .models .impl .agency .sqlalchemy import Agency
12- from src .db .models .impl .link .agency_location .sqlalchemy import LinkAgencyLocation
1310from src .db .models .impl .link .user_suggestion_not_found .agency .sqlalchemy import LinkUserSuggestionAgencyNotFound
11+ from src .db .models .impl .url .suggestion .agency .subtask .sqlalchemy import URLAutoAgencyIDSubtask
12+ from src .db .models .impl .url .suggestion .agency .suggestion .sqlalchemy import AgencyIDSubtaskSuggestion
1413from src .db .models .impl .url .suggestion .agency .user import UserURLAgencySuggestion
1514from src .db .templates .requester import RequesterBase
1615
@@ -27,102 +26,97 @@ def __init__(
2726 self .url_id = url_id
2827 self .location_id = location_id
2928
30- async def get_user_agency_suggestions (self ) -> list [AgencyAnnotationUserSuggestion ]:
31- query = (
29+ async def get_agency_suggestions (self ) -> list [AgencyAnnotationSuggestion ]:
30+ # All agencies with either a user or robo annotation
31+ valid_agencies_cte = (
3232 select (
33- UserURLAgencySuggestion .agency_id ,
34- func .count (UserURLAgencySuggestion .user_id ).label ("count" ),
35- Agency .name .label ("agency_name" ),
36- )
37- .join (
38- Agency ,
39- Agency .id == UserURLAgencySuggestion .agency_id
33+ Agency .id ,
4034 )
41-
42- )
43-
44- if self .location_id is not None :
45- query = (
46- query .join (
47- LinkAgencyLocation ,
48- LinkAgencyLocation .agency_id == UserURLAgencySuggestion .agency_id
49- )
50- .where (
51- LinkAgencyLocation .location_id == self .location_id
35+ .where (
36+ or_ (
37+ exists_url (
38+ UserURLAgencySuggestion
39+ ),
40+ exists_url (
41+ URLAutoAgencyIDSubtask
42+ )
5243 )
5344 )
45+ .cte ("valid_agencies" )
46+ )
5447
55- query = (
56- query .where (
57- UserURLAgencySuggestion .url_id == self .url_id
48+ # Number of users who suggested each agency
49+ user_suggestions_cte = (
50+ select (
51+ UserURLAgencySuggestion .url_id ,
52+ UserURLAgencySuggestion .agency_id ,
53+ func .count (UserURLAgencySuggestion .user_id ).label ('user_count' )
5854 )
5955 .group_by (
6056 UserURLAgencySuggestion .agency_id ,
61- Agency . name
57+ UserURLAgencySuggestion . url_id ,
6258 )
63- .order_by (
64- func .count (UserURLAgencySuggestion .user_id ).desc ()
65- )
66- .limit (3 )
59+ .cte ("user_suggestions" )
6760 )
6861
69- results : Sequence [RowMapping ] = await sh .mappings (self .session , query = query )
70-
71- return [
72- AgencyAnnotationUserSuggestion (
73- agency_id = autosuggestion ["agency_id" ],
74- user_count = autosuggestion ["count" ],
75- agency_name = autosuggestion ["agency_name" ],
62+ # Maximum confidence of robo annotation, if any
63+ robo_suggestions_cte = (
64+ select (
65+ URLAutoAgencyIDSubtask .url_id ,
66+ Agency .id .label ("agency_id" ),
67+ func .max (AgencyIDSubtaskSuggestion .confidence ).label ('robo_confidence' )
7668 )
77- for autosuggestion in results
78- ]
79-
80-
81- async def get_auto_agency_suggestions (self ) -> list [AgencyAnnotationAutoSuggestion ]:
82- cte = SuggestionsWithHighestConfidenceCTE ()
83- query = (
69+ .join (
70+ AgencyIDSubtaskSuggestion ,
71+ AgencyIDSubtaskSuggestion .subtask_id == URLAutoAgencyIDSubtask .id
72+ )
73+ .join (
74+ Agency ,
75+ Agency .id == AgencyIDSubtaskSuggestion .agency_id
76+ )
77+ .group_by (
78+ URLAutoAgencyIDSubtask .url_id ,
79+ Agency .id
80+ )
81+ .cte ("robo_suggestions" )
82+ )
83+ # Join user and robo suggestions
84+ joined_suggestions_query = (
8485 select (
85- cte .agency_id ,
86- cte .confidence ,
86+ valid_agencies_cte .c .id .label ("agency_id" ),
8787 Agency .name .label ("agency_name" ),
88+ func .coalesce (user_suggestions_cte .c .user_count , 0 ).label ('user_count' ),
89+ func .coalesce (robo_suggestions_cte .c .robo_confidence , 0 ).label ('robo_confidence' ),
8890 )
8991 .join (
9092 Agency ,
91- Agency .id == cte . agency_id
93+ Agency .id == valid_agencies_cte . c . id
9294 )
93- )
94-
95- if self .location_id is not None :
96- query = (
97- query .join (
98- LinkAgencyLocation ,
99- LinkAgencyLocation .agency_id == cte .agency_id
100- )
101- .where (
102- LinkAgencyLocation .location_id == self .location_id
95+ .outerjoin (
96+ user_suggestions_cte ,
97+ and_ (
98+ user_suggestions_cte .c .url_id == self .url_id ,
99+ user_suggestions_cte .c .agency_id == Agency .id
103100 )
104101 )
105-
106- query = (
107- query .where (
108- cte .url_id == self .url_id
109- )
110- .order_by (
111- cte .confidence .desc ()
102+ .outerjoin (
103+ robo_suggestions_cte ,
104+ and_ (
105+ robo_suggestions_cte .c .url_id == self .url_id ,
106+ robo_suggestions_cte .c .agency_id == Agency .id
107+ )
112108 )
113- .limit (3 )
114109 )
115110
116- results : Sequence [RowMapping ] = await sh .mappings (self .session , query = query )
117-
118- return [
119- AgencyAnnotationAutoSuggestion (
120- agency_id = autosuggestion ["agency_id" ],
121- confidence = autosuggestion ["confidence" ],
122- agency_name = autosuggestion ["agency_name" ],
111+ # Return suggestions
112+ mappings : Sequence [RowMapping ] = await self .mappings (joined_suggestions_query )
113+ suggestions : list [AgencyAnnotationSuggestion ] = [
114+ AgencyAnnotationSuggestion (
115+ ** mapping
123116 )
124- for autosuggestion in results
117+ for mapping in mappings
125118 ]
119+ return suggestions
126120
127121 async def get_not_found_count (self ) -> int :
128122 query = (
0 commit comments