Skip to content

Commit 35d68f8

Browse files
authored
feat(floware): image url base search and tag based document fetch (#221)
1 parent 444fff1 commit 35d68f8

3 files changed

Lines changed: 207 additions & 28 deletions

File tree

wavefront/server/modules/knowledge_base_module/knowledge_base_module/controllers/knowledge_base_document_controller.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
from flo_cloud.message_queue import MessageQueueManager
2828
from flo_cloud.cloud_storage import CloudStorageManager
2929
from pydantic import BaseModel
30-
from sqlalchemy import Result
31-
from sqlalchemy import select
30+
from knowledge_base_module.queries.generate_query import QueryGenerator
3231

3332
kb_document_router = APIRouter()
3433

@@ -182,11 +181,23 @@ async def upload_document(
182181
os.unlink(temp_file_path)
183182

184183

184+
def _document_row_to_dict(row: dict) -> dict:
185+
"""Convert a raw document row to the same format as KnowledgeBaseDocuments.to_dict()."""
186+
result = dict(row)
187+
for key, value in result.items():
188+
if isinstance(value, uuid.UUID):
189+
result[key] = str(value)
190+
elif isinstance(value, datetime):
191+
result[key] = value.isoformat()
192+
return result
193+
194+
185195
@kb_document_router.get('/v1/knowledge-bases/{kb_id}/documents')
186196
@inject
187197
async def get_documents(
188198
kb_id: uuid.UUID,
189199
file_type: Optional[str] = Query(None, description='Type of file to filter by'),
200+
query_filter: Optional[str] = Query(None, alias='$filter'),
190201
offset: int = Query(0, ge=0, description='The number of items to skip'),
191202
limit: int = Query(
192203
10, ge=1, le=100, description='The maximum number of items to return'
@@ -199,35 +210,28 @@ async def get_documents(
199210
] = Depends(Provide[KnowledgeBaseContainer.knowledge_base_documents_repository]),
200211
) -> JSONResponse:
201212
"""Get documents for a knowledge base with optional filtering and pagination."""
202-
# Validate knowledge base exists
203-
existing_document = await knowledge_base_documents_repository.find_one(
204-
knowledge_base_id=kb_id
205-
)
206-
if not existing_document:
207-
return JSONResponse(
208-
status_code=status.HTTP_200_OK,
209-
content=response_formatter.buildSuccessResponse(data={'resources': []}),
213+
try:
214+
query_generator = QueryGenerator()
215+
sql_query, query_params = query_generator.get_documents_list_query(
216+
kb_id=str(kb_id),
217+
file_type=file_type,
218+
filter=query_filter,
219+
offset=offset,
220+
limit=limit,
210221
)
211-
212-
# Fetch documents
213-
async with knowledge_base_documents_repository.session() as session:
214-
query = select(KnowledgeBaseDocuments).where(
215-
KnowledgeBaseDocuments.knowledge_base_id == kb_id
222+
rows = await knowledge_base_documents_repository.execute_query(
223+
sql_query, query_params
216224
)
217-
218-
if file_type:
219-
query = query.where(KnowledgeBaseDocuments.file_type == file_type)
220-
221-
query = query.slice(offset, limit)
222-
223-
results: Result = await session.execute(query)
224-
resources = results.scalars().all()
225-
data = [res.to_dict() for res in resources]
226-
225+
data = [_document_row_to_dict(row) for row in rows]
227226
return JSONResponse(
228227
status_code=status.HTTP_200_OK,
229228
content=response_formatter.buildSuccessResponse(data={'resources': data}),
230229
)
230+
except ValueError as e:
231+
return JSONResponse(
232+
status_code=status.HTTP_400_BAD_REQUEST,
233+
content=response_formatter.buildErrorResponse(str(e)),
234+
)
231235

232236

233237
@kb_document_router.delete('/v1/knowledge-bases/{kb_id}/documents/{document_id}')

wavefront/server/modules/knowledge_base_module/knowledge_base_module/controllers/rag_retreival_controller.py

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import List, Optional
1+
import base64
2+
from typing import List, Optional, Tuple
23
import uuid
34

45
from common_module.common_container import CommonContainer
@@ -21,6 +22,7 @@
2122
)
2223
from knowledge_base_module.services.kb_rag_retrieve import KBRagResponse
2324
from knowledge_base_module.services.image_rag_retrieve import ImageRagRetrieve
25+
from flo_cloud.cloud_storage import CloudStorageManager
2426
from pydantic import BaseModel, Field
2527
from datetime import datetime
2628
from sqlalchemy import Result
@@ -71,9 +73,10 @@ class DocWiseEmbeddingSchema(BaseModel):
7173

7274

7375
class ImagePayload(BaseModel):
74-
"""Payload for Image embedding."""
76+
"""Payload for Image embedding. Use image_data (base64) or image_url (gs:// or s3://); image_url has priority if both are set."""
7577

7678
image_data: Optional[str] = None
79+
image_url: Optional[str] = None
7780

7881

7982
class DocumentPayload(BaseModel):
@@ -96,6 +99,112 @@ def convert_uuids_to_str(data):
9699
return data
97100

98101

102+
def _parse_cloud_image_url(url: str) -> Tuple[str, str, str]:
103+
"""
104+
Parse gs:// or s3:// URL into (scheme, bucket, key).
105+
Returns (scheme, bucket, key) or raises ValueError.
106+
"""
107+
url = (url or '').strip()
108+
if url.startswith('gs://'):
109+
rest = url[5:]
110+
if '/' not in rest:
111+
raise ValueError('Invalid gs:// URL: missing path after bucket')
112+
bucket, _, key = rest.partition('/')
113+
return ('gs', bucket, key)
114+
if url.startswith('s3://'):
115+
rest = url[5:]
116+
if '/' not in rest:
117+
raise ValueError('Invalid s3:// URL: missing path after bucket')
118+
bucket, _, key = rest.partition('/')
119+
return ('s3', bucket, key)
120+
raise ValueError('image_url must be in gs:// or s3:// format')
121+
122+
123+
async def _resolve_image_data(
124+
payload: ImagePayload,
125+
cloud_storage: CloudStorageManager,
126+
config: dict,
127+
response_formatter: ResponseFormatter,
128+
) -> Tuple[Optional[str], Optional[JSONResponse]]:
129+
"""
130+
Resolve image payload to a single image_data string (base64) for the inference API.
131+
When both are provided, image_url has priority; otherwise uses image_data or fetches from image_url (gs:// or s3://).
132+
Returns (image_data, None) on success, or (None, error_json_response) on validation/fetch error.
133+
"""
134+
if payload.image_url:
135+
pass
136+
elif payload.image_data:
137+
return (payload.image_data, None)
138+
else:
139+
return (
140+
None,
141+
JSONResponse(
142+
status_code=status.HTTP_400_BAD_REQUEST,
143+
content=response_formatter.buildErrorResponse(
144+
'Query or Image data should not be empty'
145+
),
146+
),
147+
)
148+
try:
149+
scheme, bucket, key = _parse_cloud_image_url(payload.image_url)
150+
except ValueError as e:
151+
return (
152+
None,
153+
JSONResponse(
154+
status_code=status.HTTP_400_BAD_REQUEST,
155+
content=response_formatter.buildErrorResponse(str(e)),
156+
),
157+
)
158+
cloud_provider = (
159+
(config.get('cloud_config') or {}).get('cloud_provider', '').lower()
160+
)
161+
if scheme == 'gs' and cloud_provider != 'gcp':
162+
return (
163+
None,
164+
JSONResponse(
165+
status_code=status.HTTP_400_BAD_REQUEST,
166+
content=response_formatter.buildErrorResponse(
167+
'image_url gs:// is only supported when cloud provider is GCP'
168+
),
169+
),
170+
)
171+
if scheme == 's3' and cloud_provider != 'aws':
172+
return (
173+
None,
174+
JSONResponse(
175+
status_code=status.HTTP_400_BAD_REQUEST,
176+
content=response_formatter.buildErrorResponse(
177+
'image_url s3:// is only supported when cloud provider is AWS'
178+
),
179+
),
180+
)
181+
try:
182+
content = cloud_storage.read_file(bucket, key)
183+
except Exception as e:
184+
return (
185+
None,
186+
JSONResponse(
187+
status_code=status.HTTP_400_BAD_REQUEST,
188+
content=response_formatter.buildErrorResponse(
189+
f'Failed to fetch image from storage: {e!s}'
190+
),
191+
),
192+
)
193+
image_bytes = content.read() if hasattr(content, 'read') else content
194+
if not image_bytes:
195+
return (
196+
None,
197+
JSONResponse(
198+
status_code=status.HTTP_400_BAD_REQUEST,
199+
content=response_formatter.buildErrorResponse(
200+
'Image from URL is empty'
201+
),
202+
),
203+
)
204+
image_data_b64 = base64.b64encode(image_bytes).decode('utf-8')
205+
return (image_data_b64, None)
206+
207+
99208
@rag_retrieval_router.post('/v1/knowledge-base/{kb_id}/retrieve')
100209
@inject
101210
async def retrieve_query(
@@ -128,6 +237,9 @@ async def retrieve_query(
128237
Provide[KnowledgeBaseContainer.image_knowledge_base_retrieve]
129238
),
130239
config: dict = Depends(Provide[KnowledgeBaseContainer.config]),
240+
cloud_storage: CloudStorageManager = Depends(
241+
Provide[KnowledgeBaseContainer.cloud_storage]
242+
),
131243
):
132244
if not query and not payload:
133245
return JSONResponse(
@@ -156,9 +268,14 @@ async def retrieve_query(
156268
limit,
157269
)
158270
else:
271+
image_data, error_response = await _resolve_image_data(
272+
payload, cloud_storage, config, response_formatter
273+
)
274+
if error_response is not None:
275+
return error_response
159276
inference_url = config['model']['inference_service_url']
160277
retrieved_docs = await image_rag_retrieval.retrieve_images(
161-
payload.image_data,
278+
image_data,
162279
inference_url,
163280
kb_id,
164281
threshold,

wavefront/server/modules/knowledge_base_module/knowledge_base_module/queries/generate_query.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,64 @@ def get_image_embedding_dino(
280280

281281
return sql_query, params
282282

283+
def get_documents_list_query(
284+
self,
285+
kb_id: str,
286+
file_type: Optional[str] = None,
287+
filter: Optional[str] = None,
288+
offset: int = 0,
289+
limit: int = 10,
290+
) -> Tuple[str, Dict[str, Any]]:
291+
"""
292+
Generate SQL query to list knowledge base documents with optional
293+
metadata filter (OData-style $filter) and file_type.
294+
295+
Returns:
296+
Tuple of (SQL query string, query parameters)
297+
"""
298+
params: Dict[str, Any] = {
299+
'kb_id': kb_id,
300+
'offset': offset,
301+
'limit': limit,
302+
}
303+
conditions = ['knowledge_base_id = :kb_id']
304+
if file_type:
305+
params['file_type'] = file_type
306+
conditions.append('file_type = :file_type')
307+
308+
metadata_filter_clause = ''
309+
if filter:
310+
where_clause, filter_params = self.odata_parser.prepare_odata_filter(filter)
311+
if where_clause and filter_params:
312+
metadata_filter_clause = self.build_metadata_clause(
313+
where_clause,
314+
filter_params,
315+
lambda field: f"(metadata_value ->> '{field}')",
316+
)
317+
params.update(filter_params)
318+
conditions.append(f'({metadata_filter_clause})')
319+
320+
where_sql = ' AND '.join(conditions)
321+
sql_query = f"""
322+
SELECT
323+
id,
324+
knowledge_base_id,
325+
file_path,
326+
file_name,
327+
file_type,
328+
file_size,
329+
created_at,
330+
updated_at,
331+
metadata_value
332+
FROM
333+
{KnowledgeBaseDocuments.__tablename__}
334+
WHERE
335+
{where_sql}
336+
ORDER BY created_at DESC
337+
LIMIT :limit OFFSET :offset
338+
"""
339+
return sql_query, params
340+
283341
@staticmethod
284342
def get_update_tokens_query() -> str:
285343
"""

0 commit comments

Comments
 (0)