Skip to content

Commit 16e37b1

Browse files
authored
Merge pull request chenyme#374 from JinchengGao-Infty/fix-image-gen-app-chat
fix: use app-chat REST as primary image generation, fallback to ws_imagine
2 parents a79cd22 + 6779da7 commit 16e37b1

5 files changed

Lines changed: 310 additions & 107 deletions

File tree

app/api/v1/chat.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):
7272

7373
VALID_ROLES = {"developer", "system", "user", "assistant", "tool"}
7474
USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"}
75+
MARKDOWN_IMAGE_RE = r"!\[[^\]]*\]\(([^)\s]+)(?:\s+\"[^\"]*\")?\)"
7576
ALLOWED_IMAGE_SIZES = {
7677
"1280x720",
7778
"720x1280",
@@ -118,11 +119,20 @@ def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]
118119
last_text = ""
119120
image_urls: List[str] = []
120121

122+
def _collect_markdown_images(text: str):
123+
if not isinstance(text, str) or not text:
124+
return
125+
for match in __import__("re").finditer(MARKDOWN_IMAGE_RE, text):
126+
url = (match.group(1) or "").strip()
127+
if url:
128+
image_urls.append(url)
129+
121130
for msg in messages:
122131
role = msg.role or "user"
123132
content = msg.content
124133
if isinstance(content, str):
125134
text = content.strip()
135+
_collect_markdown_images(text)
126136
if text:
127137
last_text = text
128138
continue
@@ -137,14 +147,22 @@ def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]
137147
if block_type == "text":
138148
text = block.get("text", "")
139149
if isinstance(text, str) and text.strip():
150+
_collect_markdown_images(text)
140151
last_text = text.strip()
141152
elif block_type == "image_url" and role == "user":
142153
image = block.get("image_url") or {}
143154
url = image.get("url", "")
144155
if isinstance(url, str) and url.strip():
145156
image_urls.append(url.strip())
146157

147-
return last_text, image_urls
158+
deduped_urls: List[str] = []
159+
seen = set()
160+
for url in image_urls:
161+
if url not in seen:
162+
seen.add(url)
163+
deduped_urls.append(url)
164+
165+
return last_text, deduped_urls
148166

149167

150168
def _resolve_image_format(value: Optional[str]) -> str:

app/services/grok/services/chat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ async def chat(
263263
file_attachments: List[str] = None,
264264
tool_overrides: Dict[str, Any] = None,
265265
model_config_override: Dict[str, Any] = None,
266+
request_overrides: Dict[str, Any] = None,
266267
):
267268
"""发送聊天请求"""
268269
if stream is None:
@@ -286,6 +287,7 @@ async def chat(
286287
file_attachments=file_attachments,
287288
tool_overrides=tool_overrides,
288289
model_config_override=model_config_override,
290+
request_overrides=request_overrides,
289291
)
290292
logger.info(f"Chat connected: model={model}, stream={stream}")
291293
except Exception:

app/services/grok/services/image.py

Lines changed: 194 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
from app.services.grok.utils.retry import pick_token, rate_limited
2121
from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content
2222
from app.services.grok.utils.stream import wrap_stream_with_usage
23+
from app.services.grok.services.chat import GrokChatService
24+
from app.services.grok.services.image_edit import (
25+
ImageStreamProcessor as AppChatImageStreamProcessor,
26+
ImageCollectProcessor as AppChatImageCollectProcessor,
27+
)
2328
from app.services.token import EffortType
2429
from app.services.reverse.ws_imagine import ImagineWebSocketReverse
2530

@@ -37,6 +42,18 @@ class ImageGenerationResult:
3742
class ImageGenerationService:
3843
"""Image generation orchestration service."""
3944

45+
@staticmethod
46+
def _app_chat_request_overrides(
47+
count: int,
48+
enable_nsfw: Optional[bool],
49+
) -> Dict[str, Any]:
50+
overrides: Dict[str, Any] = {
51+
"imageGenerationCount": max(1, int(count or 1)),
52+
}
53+
if enable_nsfw is not None:
54+
overrides["enableNsfw"] = bool(enable_nsfw)
55+
return overrides
56+
4057
async def generate(
4158
self,
4259
*,
@@ -87,18 +104,36 @@ async def _stream_retry() -> AsyncGenerator[str, None]:
87104
tried_tokens.add(current_token)
88105
yielded = False
89106
try:
90-
result = await self._stream_ws(
91-
token_mgr=token_mgr,
92-
token=current_token,
93-
model_info=model_info,
94-
prompt=prompt,
95-
n=n,
96-
response_format=response_format,
97-
size=size,
98-
aspect_ratio=aspect_ratio,
99-
enable_nsfw=enable_nsfw,
100-
chat_format=chat_format,
101-
)
107+
try:
108+
result = await self._stream_app_chat(
109+
token_mgr=token_mgr,
110+
token=current_token,
111+
model_info=model_info,
112+
prompt=prompt,
113+
n=n,
114+
response_format=response_format,
115+
enable_nsfw=enable_nsfw,
116+
chat_format=chat_format,
117+
)
118+
except UpstreamException as app_chat_error:
119+
if rate_limited(app_chat_error):
120+
raise
121+
logger.warning(
122+
"App-chat image stream failed, falling back to ws_imagine: %s",
123+
app_chat_error,
124+
)
125+
result = await self._stream_ws(
126+
token_mgr=token_mgr,
127+
token=current_token,
128+
model_info=model_info,
129+
prompt=prompt,
130+
n=n,
131+
response_format=response_format,
132+
size=size,
133+
aspect_ratio=aspect_ratio,
134+
enable_nsfw=enable_nsfw,
135+
chat_format=chat_format,
136+
)
102137
async for chunk in result.data:
103138
yielded = True
104139
yield chunk
@@ -148,17 +183,34 @@ async def _stream_retry() -> AsyncGenerator[str, None]:
148183

149184
tried_tokens.add(current_token)
150185
try:
151-
return await self._collect_ws(
152-
token_mgr=token_mgr,
153-
token=current_token,
154-
model_info=model_info,
155-
tried_tokens=tried_tokens,
156-
prompt=prompt,
157-
n=n,
158-
response_format=response_format,
159-
aspect_ratio=aspect_ratio,
160-
enable_nsfw=enable_nsfw,
161-
)
186+
try:
187+
return await self._collect_app_chat(
188+
token_mgr=token_mgr,
189+
token=current_token,
190+
model_info=model_info,
191+
prompt=prompt,
192+
n=n,
193+
response_format=response_format,
194+
enable_nsfw=enable_nsfw,
195+
)
196+
except UpstreamException as app_chat_error:
197+
if rate_limited(app_chat_error):
198+
raise
199+
logger.warning(
200+
"App-chat image collect failed, falling back to ws_imagine: %s",
201+
app_chat_error,
202+
)
203+
return await self._collect_ws(
204+
token_mgr=token_mgr,
205+
token=current_token,
206+
model_info=model_info,
207+
tried_tokens=tried_tokens,
208+
prompt=prompt,
209+
n=n,
210+
response_format=response_format,
211+
aspect_ratio=aspect_ratio,
212+
enable_nsfw=enable_nsfw,
213+
)
162214
except UpstreamException as e:
163215
last_error = e
164216
if rate_limited(e):
@@ -221,6 +273,125 @@ async def _stream_ws(
221273
)
222274
return ImageGenerationResult(stream=True, data=stream)
223275

276+
async def _stream_app_chat(
    self,
    *,
    token_mgr: Any,
    token: str,
    model_info: Any,
    prompt: str,
    n: int,
    response_format: str,
    enable_nsfw: Optional[bool] = None,
    chat_format: bool = False,
) -> ImageGenerationResult:
    """Stream image generation through the app-chat REST endpoint.

    Issues a streaming chat request with the imageGen tool enabled, pipes the
    response through the app-chat image stream processor, and wraps the result
    so token usage is accounted for when the stream finishes.

    Returns:
        ImageGenerationResult with ``stream=True`` and an async chunk iterator.
    """
    # Drive image generation through the regular chat pipeline with the
    # imageGen tool forced on; count/NSFW preferences travel as overrides.
    chat_response = await GrokChatService().chat(
        token=token,
        message=prompt,
        model=model_info.grok_model,
        mode=model_info.model_mode,
        stream=True,
        tool_overrides={"imageGen": True},
        request_overrides=self._app_chat_request_overrides(n, enable_nsfw),
    )
    stream_processor = AppChatImageStreamProcessor(
        model_info.model_id,
        token,
        n=n,
        response_format=response_format,
        chat_format=chat_format,
    )
    # Wrap so usage is reported against this token/model when the stream ends.
    usage_stream = wrap_stream_with_usage(
        stream_processor.process(chat_response),
        token_mgr,
        token,
        model_info.model_id,
    )
    return ImageGenerationResult(stream=True, data=usage_stream)
311+
312+
async def _collect_app_chat(
    self,
    *,
    token_mgr: Any,
    token: str,
    model_info: Any,
    prompt: str,
    n: int,
    response_format: str,
    enable_nsfw: Optional[bool] = None,
) -> ImageGenerationResult:
    """Collect ``n`` generated images via the app-chat REST endpoint.

    The upstream endpoint yields at most two images per call, so requests for
    more are fanned out concurrently and the results merged (order-preserving,
    de-duplicated). Rate-limit errors take precedence over other failures so
    the caller's retry logic can rotate tokens.

    Raises:
        UpstreamException: When no images could be produced.
    """
    # Upstream caps each call at 2 images; split the request accordingly.
    per_call = min(max(1, n), 2)
    calls_needed = max(1, int(math.ceil(n / per_call)))

    async def _call_generate(call_target: int) -> List[str]:
        # One streaming chat call with the imageGen tool, collected to a list.
        response = await GrokChatService().chat(
            token=token,
            message=prompt,
            model=model_info.grok_model,
            mode=model_info.model_mode,
            stream=True,
            tool_overrides={"imageGen": True},
            request_overrides=self._app_chat_request_overrides(
                call_target, enable_nsfw
            ),
        )
        collector = AppChatImageCollectProcessor(
            model_info.model_id,
            token,
            response_format=response_format,
        )
        return await collector.process(response)

    if calls_needed == 1:
        all_images = await _call_generate(n)
    else:
        # Fan out: each call asks for the remaining count, capped at per_call.
        pending = [
            _call_generate(min(per_call, n - idx * per_call))
            for idx in range(calls_needed)
        ]
        results = await asyncio.gather(*pending, return_exceptions=True)

        all_images: List[str] = []
        seen: set = set()
        last_error: Optional[Exception] = None
        rate_limit_error: Optional[Exception] = None
        for result in results:
            if isinstance(result, Exception):
                logger.warning(f"Concurrent app-chat image call failed: {result}")
                last_error = result
                if rate_limited(result):
                    rate_limit_error = result
                continue
            # Merge, keeping first-seen order and dropping duplicates.
            for image in result:
                if image not in seen:
                    seen.add(image)
                    all_images.append(image)

        # Only surface errors when every call failed to produce anything;
        # rate-limit errors win so token rotation can kick in upstream.
        if not all_images:
            if rate_limit_error:
                raise rate_limit_error
            if last_error:
                raise last_error

    if not all_images:
        raise UpstreamException(
            "Image generation returned no results",
            details={"error": "empty_result", "path": "app_chat"},
        )

    # Best-effort usage accounting; a bookkeeping failure must not lose images.
    try:
        await token_mgr.consume(token, self._get_effort(model_info))
    except Exception as e:
        logger.warning(f"Failed to consume token: {e}")

    selected = self._select_images(all_images, n)
    # Upstream reports no token usage for image generation; zero it explicitly.
    usage_override = {
        "total_tokens": 0,
        "input_tokens": 0,
        "output_tokens": 0,
        "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
    }
    return ImageGenerationResult(
        stream=False, data=selected, usage_override=usage_override
    )
394+
224395
async def _collect_ws(
225396
self,
226397
*,

0 commit comments

Comments
 (0)