1- from typing import List , Optional
1+ import base64
2+ from typing import List , Optional , Tuple
23import uuid
34
45from common_module .common_container import CommonContainer
2122)
2223from knowledge_base_module .services .kb_rag_retrieve import KBRagResponse
2324from knowledge_base_module .services .image_rag_retrieve import ImageRagRetrieve
25+ from flo_cloud .cloud_storage import CloudStorageManager
2426from pydantic import BaseModel , Field
2527from datetime import datetime
2628from sqlalchemy import Result
@@ -71,9 +73,10 @@ class DocWiseEmbeddingSchema(BaseModel):
7173
7274
7375class ImagePayload (BaseModel ):
74- """Payload for Image embedding."""
76+ """Payload for Image embedding. Use image_data (base64) or image_url (gs:// or s3://); image_url has priority if both are set. """
7577
7678 image_data : Optional [str ] = None
79+ image_url : Optional [str ] = None
7780
7881
7982class DocumentPayload (BaseModel ):
@@ -96,6 +99,112 @@ def convert_uuids_to_str(data):
9699 return data
97100
98101
102+ def _parse_cloud_image_url (url : str ) -> Tuple [str , str , str ]:
103+ """
104+ Parse gs:// or s3:// URL into (scheme, bucket, key).
105+ Returns (scheme, bucket, key) or raises ValueError.
106+ """
107+ url = (url or '' ).strip ()
108+ if url .startswith ('gs://' ):
109+ rest = url [5 :]
110+ if '/' not in rest :
111+ raise ValueError ('Invalid gs:// URL: missing path after bucket' )
112+ bucket , _ , key = rest .partition ('/' )
113+ return ('gs' , bucket , key )
114+ if url .startswith ('s3://' ):
115+ rest = url [5 :]
116+ if '/' not in rest :
117+ raise ValueError ('Invalid s3:// URL: missing path after bucket' )
118+ bucket , _ , key = rest .partition ('/' )
119+ return ('s3' , bucket , key )
120+ raise ValueError ('image_url must be in gs:// or s3:// format' )
121+
122+
123+ async def _resolve_image_data (
124+ payload : ImagePayload ,
125+ cloud_storage : CloudStorageManager ,
126+ config : dict ,
127+ response_formatter : ResponseFormatter ,
128+ ) -> Tuple [Optional [str ], Optional [JSONResponse ]]:
129+ """
130+ Resolve image payload to a single image_data string (base64) for the inference API.
131+ When both are provided, image_url has priority; otherwise uses image_data or fetches from image_url (gs:// or s3://).
132+ Returns (image_data, None) on success, or (None, error_json_response) on validation/fetch error.
133+ """
134+ if payload .image_url :
135+ pass
136+ elif payload .image_data :
137+ return (payload .image_data , None )
138+ else :
139+ return (
140+ None ,
141+ JSONResponse (
142+ status_code = status .HTTP_400_BAD_REQUEST ,
143+ content = response_formatter .buildErrorResponse (
144+ 'Query or Image data should not be empty'
145+ ),
146+ ),
147+ )
148+ try :
149+ scheme , bucket , key = _parse_cloud_image_url (payload .image_url )
150+ except ValueError as e :
151+ return (
152+ None ,
153+ JSONResponse (
154+ status_code = status .HTTP_400_BAD_REQUEST ,
155+ content = response_formatter .buildErrorResponse (str (e )),
156+ ),
157+ )
158+ cloud_provider = (
159+ (config .get ('cloud_config' ) or {}).get ('cloud_provider' , '' ).lower ()
160+ )
161+ if scheme == 'gs' and cloud_provider != 'gcp' :
162+ return (
163+ None ,
164+ JSONResponse (
165+ status_code = status .HTTP_400_BAD_REQUEST ,
166+ content = response_formatter .buildErrorResponse (
167+ 'image_url gs:// is only supported when cloud provider is GCP'
168+ ),
169+ ),
170+ )
171+ if scheme == 's3' and cloud_provider != 'aws' :
172+ return (
173+ None ,
174+ JSONResponse (
175+ status_code = status .HTTP_400_BAD_REQUEST ,
176+ content = response_formatter .buildErrorResponse (
177+ 'image_url s3:// is only supported when cloud provider is AWS'
178+ ),
179+ ),
180+ )
181+ try :
182+ content = cloud_storage .read_file (bucket , key )
183+ except Exception as e :
184+ return (
185+ None ,
186+ JSONResponse (
187+ status_code = status .HTTP_400_BAD_REQUEST ,
188+ content = response_formatter .buildErrorResponse (
189+ f'Failed to fetch image from storage: { e !s} '
190+ ),
191+ ),
192+ )
193+ image_bytes = content .read () if hasattr (content , 'read' ) else content
194+ if not image_bytes :
195+ return (
196+ None ,
197+ JSONResponse (
198+ status_code = status .HTTP_400_BAD_REQUEST ,
199+ content = response_formatter .buildErrorResponse (
200+ 'Image from URL is empty'
201+ ),
202+ ),
203+ )
204+ image_data_b64 = base64 .b64encode (image_bytes ).decode ('utf-8' )
205+ return (image_data_b64 , None )
206+
207+
99208@rag_retrieval_router .post ('/v1/knowledge-base/{kb_id}/retrieve' )
100209@inject
101210async def retrieve_query (
@@ -128,6 +237,9 @@ async def retrieve_query(
128237 Provide [KnowledgeBaseContainer .image_knowledge_base_retrieve ]
129238 ),
130239 config : dict = Depends (Provide [KnowledgeBaseContainer .config ]),
240+ cloud_storage : CloudStorageManager = Depends (
241+ Provide [KnowledgeBaseContainer .cloud_storage ]
242+ ),
131243):
132244 if not query and not payload :
133245 return JSONResponse (
@@ -156,9 +268,14 @@ async def retrieve_query(
156268 limit ,
157269 )
158270 else :
271+ image_data , error_response = await _resolve_image_data (
272+ payload , cloud_storage , config , response_formatter
273+ )
274+ if error_response is not None :
275+ return error_response
159276 inference_url = config ['model' ]['inference_service_url' ]
160277 retrieved_docs = await image_rag_retrieval .retrieve_images (
161- payload . image_data ,
278+ image_data ,
162279 inference_url ,
163280 kb_id ,
164281 threshold ,
0 commit comments