2020from app .services .grok .utils .retry import pick_token , rate_limited
2121from app .services .grok .utils .response import make_response_id , make_chat_chunk , wrap_image_content
2222from app .services .grok .utils .stream import wrap_stream_with_usage
23+ from app .services .grok .services .chat import GrokChatService
24+ from app .services .grok .services .image_edit import (
25+ ImageStreamProcessor as AppChatImageStreamProcessor ,
26+ ImageCollectProcessor as AppChatImageCollectProcessor ,
27+ )
2328from app .services .token import EffortType
2429from app .services .reverse .ws_imagine import ImagineWebSocketReverse
2530
@@ -37,6 +42,18 @@ class ImageGenerationResult:
3742class ImageGenerationService :
3843 """Image generation orchestration service."""
3944
45+ @staticmethod
46+ def _app_chat_request_overrides (
47+ count : int ,
48+ enable_nsfw : Optional [bool ],
49+ ) -> Dict [str , Any ]:
50+ overrides : Dict [str , Any ] = {
51+ "imageGenerationCount" : max (1 , int (count or 1 )),
52+ }
53+ if enable_nsfw is not None :
54+ overrides ["enableNsfw" ] = bool (enable_nsfw )
55+ return overrides
56+
4057 async def generate (
4158 self ,
4259 * ,
@@ -87,18 +104,36 @@ async def _stream_retry() -> AsyncGenerator[str, None]:
87104 tried_tokens .add (current_token )
88105 yielded = False
89106 try :
90- result = await self ._stream_ws (
91- token_mgr = token_mgr ,
92- token = current_token ,
93- model_info = model_info ,
94- prompt = prompt ,
95- n = n ,
96- response_format = response_format ,
97- size = size ,
98- aspect_ratio = aspect_ratio ,
99- enable_nsfw = enable_nsfw ,
100- chat_format = chat_format ,
101- )
107+ try :
108+ result = await self ._stream_app_chat (
109+ token_mgr = token_mgr ,
110+ token = current_token ,
111+ model_info = model_info ,
112+ prompt = prompt ,
113+ n = n ,
114+ response_format = response_format ,
115+ enable_nsfw = enable_nsfw ,
116+ chat_format = chat_format ,
117+ )
118+ except UpstreamException as app_chat_error :
119+ if rate_limited (app_chat_error ):
120+ raise
121+ logger .warning (
122+ "App-chat image stream failed, falling back to ws_imagine: %s" ,
123+ app_chat_error ,
124+ )
125+ result = await self ._stream_ws (
126+ token_mgr = token_mgr ,
127+ token = current_token ,
128+ model_info = model_info ,
129+ prompt = prompt ,
130+ n = n ,
131+ response_format = response_format ,
132+ size = size ,
133+ aspect_ratio = aspect_ratio ,
134+ enable_nsfw = enable_nsfw ,
135+ chat_format = chat_format ,
136+ )
102137 async for chunk in result .data :
103138 yielded = True
104139 yield chunk
@@ -148,17 +183,34 @@ async def _stream_retry() -> AsyncGenerator[str, None]:
148183
149184 tried_tokens .add (current_token )
150185 try :
151- return await self ._collect_ws (
152- token_mgr = token_mgr ,
153- token = current_token ,
154- model_info = model_info ,
155- tried_tokens = tried_tokens ,
156- prompt = prompt ,
157- n = n ,
158- response_format = response_format ,
159- aspect_ratio = aspect_ratio ,
160- enable_nsfw = enable_nsfw ,
161- )
186+ try :
187+ return await self ._collect_app_chat (
188+ token_mgr = token_mgr ,
189+ token = current_token ,
190+ model_info = model_info ,
191+ prompt = prompt ,
192+ n = n ,
193+ response_format = response_format ,
194+ enable_nsfw = enable_nsfw ,
195+ )
196+ except UpstreamException as app_chat_error :
197+ if rate_limited (app_chat_error ):
198+ raise
199+ logger .warning (
200+ "App-chat image collect failed, falling back to ws_imagine: %s" ,
201+ app_chat_error ,
202+ )
203+ return await self ._collect_ws (
204+ token_mgr = token_mgr ,
205+ token = current_token ,
206+ model_info = model_info ,
207+ tried_tokens = tried_tokens ,
208+ prompt = prompt ,
209+ n = n ,
210+ response_format = response_format ,
211+ aspect_ratio = aspect_ratio ,
212+ enable_nsfw = enable_nsfw ,
213+ )
162214 except UpstreamException as e :
163215 last_error = e
164216 if rate_limited (e ):
@@ -221,6 +273,125 @@ async def _stream_ws(
221273 )
222274 return ImageGenerationResult (stream = True , data = stream )
223275
276+ async def _stream_app_chat (
277+ self ,
278+ * ,
279+ token_mgr : Any ,
280+ token : str ,
281+ model_info : Any ,
282+ prompt : str ,
283+ n : int ,
284+ response_format : str ,
285+ enable_nsfw : Optional [bool ] = None ,
286+ chat_format : bool = False ,
287+ ) -> ImageGenerationResult :
288+ response = await GrokChatService ().chat (
289+ token = token ,
290+ message = prompt ,
291+ model = model_info .grok_model ,
292+ mode = model_info .model_mode ,
293+ stream = True ,
294+ tool_overrides = {"imageGen" : True },
295+ request_overrides = self ._app_chat_request_overrides (n , enable_nsfw ),
296+ )
297+ processor = AppChatImageStreamProcessor (
298+ model_info .model_id ,
299+ token ,
300+ n = n ,
301+ response_format = response_format ,
302+ chat_format = chat_format ,
303+ )
304+ stream = wrap_stream_with_usage (
305+ processor .process (response ),
306+ token_mgr ,
307+ token ,
308+ model_info .model_id ,
309+ )
310+ return ImageGenerationResult (stream = True , data = stream )
311+
    async def _collect_app_chat(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        n: int,
        response_format: str,
        enable_nsfw: Optional[bool] = None,
    ) -> ImageGenerationResult:
        """Collect ``n`` images via the app-chat endpoint (non-streaming).

        Splits the request into concurrent calls of at most 2 images each,
        merges the results (deduplicated), and raises UpstreamException when
        nothing comes back. Token consumption is best-effort and never
        fails the request.

        Raises:
            UpstreamException: when no images were produced, or re-raises a
                rate-limit / last upstream error when every call failed.
        """
        # Upstream appears to cap a single app-chat call at 2 images, so
        # fan out ceil(n / per_call) concurrent calls. TODO confirm cap.
        per_call = min(max(1, n), 2)
        calls_needed = max(1, int(math.ceil(n / per_call)))

        async def _call_generate(call_target: int) -> List[str]:
            # One app-chat round trip producing up to `call_target` images.
            response = await GrokChatService().chat(
                token=token,
                message=prompt,
                model=model_info.grok_model,
                mode=model_info.model_mode,
                stream=True,
                tool_overrides={"imageGen": True},
                request_overrides=self._app_chat_request_overrides(
                    call_target, enable_nsfw
                ),
            )
            processor = AppChatImageCollectProcessor(
                model_info.model_id,
                token,
                response_format=response_format,
            )
            return await processor.process(response)

        if calls_needed == 1:
            all_images = await _call_generate(n)
        else:
            tasks = []
            for i in range(calls_needed):
                # Last call may request fewer images than `per_call`.
                remaining = n - (i * per_call)
                tasks.append(_call_generate(min(per_call, remaining)))
            # return_exceptions=True: keep partial successes even when some
            # concurrent calls fail.
            results = await asyncio.gather(*tasks, return_exceptions=True)
            all_images: List[str] = []
            last_error: Optional[Exception] = None
            rate_limit_error: Optional[Exception] = None
            for result in results:
                if isinstance(result, Exception):
                    logger.warning(f"Concurrent app-chat image call failed: {result}")
                    last_error = result
                    if rate_limited(result):
                        # Remember rate limits separately — they take
                        # precedence when nothing succeeded, so the caller's
                        # retry logic can rotate tokens.
                        rate_limit_error = result
                    continue
                for image in result:
                    # Dedup across calls; O(n^2) but n is tiny here.
                    if image not in all_images:
                        all_images.append(image)

            # Only escalate errors when every call came back empty/failed;
            # partial success is returned as-is below.
            if not all_images:
                if rate_limit_error:
                    raise rate_limit_error
                if last_error:
                    raise last_error

        if not all_images:
            raise UpstreamException(
                "Image generation returned no results",
                details={"error": "empty_result", "path": "app_chat"},
            )

        # Best-effort quota accounting — a consume failure must not fail
        # an otherwise successful generation.
        try:
            await token_mgr.consume(token, self._get_effort(model_info))
        except Exception as e:
            logger.warning(f"Failed to consume token: {e}")

        selected = self._select_images(all_images, n)
        # Usage is zeroed: image calls are not billed as chat tokens here.
        usage_override = {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        }
        return ImageGenerationResult(
            stream=False, data=selected, usage_override=usage_override
        )
394+
224395 async def _collect_ws (
225396 self ,
226397 * ,
0 commit comments