From b70b027fb1fdc145f7b1365b110a00c164689345 Mon Sep 17 00:00:00 2001 From: emailck Date: Tue, 24 Mar 2026 23:22:59 +0800 Subject: [PATCH] fix(image): restore app-chat image generation and parse generated image cards --- app/services/grok/services/image.py | 28 ++++++++-- app/services/grok/utils/process.py | 66 ++++++++++++++++++++--- tests/test_image_generation_app_chat.py | 69 +++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 11 deletions(-) create mode 100644 tests/test_image_generation_app_chat.py diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 28ea21fa6..879707e0d 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -5,6 +5,7 @@ import asyncio import base64 import math +import re import time from dataclasses import dataclass from pathlib import Path @@ -42,6 +43,11 @@ class ImageGenerationResult: class ImageGenerationService: """Image generation orchestration service.""" + _APP_CHAT_GENERATE_PREFIX_RE = re.compile( + r"^\s*(generate an image|create an image|draw an image|make an image)\s*:", + re.IGNORECASE, + ) + @staticmethod def _app_chat_request_overrides( count: int, @@ -49,11 +55,21 @@ def _app_chat_request_overrides( ) -> Dict[str, Any]: overrides: Dict[str, Any] = { "imageGenerationCount": max(1, int(count or 1)), + "disableSearch": True, } if enable_nsfw is not None: overrides["enableNsfw"] = bool(enable_nsfw) return overrides + @classmethod + def _build_app_chat_message(cls, prompt: str) -> str: + text = (prompt or "").strip() + if not text: + return prompt + if cls._APP_CHAT_GENERATE_PREFIX_RE.match(text): + return text + return f"Generate an image: {text}" + async def generate( self, *, @@ -196,9 +212,11 @@ async def _stream_retry() -> AsyncGenerator[str, None]: except UpstreamException as app_chat_error: if rate_limited(app_chat_error): raise + error_details = getattr(app_chat_error, "details", None) logger.warning( - "App-chat image collect failed, falling back to ws_imagine: %s", - app_chat_error, + "App-chat image collect failed, falling back to ws_imagine: " + f"{type(app_chat_error).__name__}: {app_chat_error}; " + f"details={error_details}" ) return await self._collect_ws( token_mgr=token_mgr, @@ -285,9 +303,10 @@ async def _stream_app_chat( enable_nsfw: Optional[bool] = None, chat_format: bool = False, ) -> ImageGenerationResult: + message = self._build_app_chat_message(prompt) response = await GrokChatService().chat( token=token, - message=prompt, + message=message, model=model_info.grok_model, mode=model_info.model_mode, stream=True, @@ -322,11 +341,12 @@ async def _collect_app_chat( ) -> ImageGenerationResult: per_call = min(max(1, n), 2) calls_needed = max(1, int(math.ceil(n / per_call))) + message = self._build_app_chat_message(prompt) async def _call_generate(call_target: int) -> List[str]: response = await GrokChatService().chat( token=token, - message=prompt, + message=message, model=model_info.grok_model, mode=model_info.model_mode, stream=True, diff --git a/app/services/grok/utils/process.py b/app/services/grok/utils/process.py index 69353c651..8198242fe 100644 --- a/app/services/grok/utils/process.py +++ b/app/services/grok/utils/process.py @@ -3,7 +3,9 @@ """ import asyncio +import re import time +import orjson from typing import Any, AsyncGenerator, Optional, AsyncIterable, List, TypeVar from app.core.config import get_config @@ -13,6 +15,35 @@ T = TypeVar("T") +_ASSET_URL_RE = re.compile(r"https://assets\.grok\.com[^\s\"'<>)]*") +_ASSET_PATH_RE = re.compile(r"(?P/?users/[^\s\"'<>)]*\.(?:png|jpe?g|webp|gif|bmp)(?:\?[^\s\"'<>)]*)?)") +_URLISH_KEYS = { + "url", + "uri", + "path", + "imageUrl", + "imageURI", + "imageUri", + "assetUrl", + "assetURI", + "assetUri", + "downloadUrl", + "downloadURI", + "downloadUri", + "fileUrl", + "fileURI", + "fileUri", + "contentUrl", + "contentURI", + "contentUri", +} +_IMAGE_COLLECTION_KEYS = { + "generatedImageUrls", + "imageUrls", + "imageURLs", + "fileUris", + "imageEditUris", +} def _is_http2_error(e: Exception) -> bool: @@ -51,15 +82,36 @@ def add(url: str): urls.append(url) def walk(value: Any): + if isinstance(value, str): + text = value.strip() + if text[:1] in {"{", "["}: + try: + parsed = orjson.loads(text) + except orjson.JSONDecodeError: + parsed = None + if parsed is not None: + walk(parsed) + return + for match in _ASSET_URL_RE.findall(text): + add(match) + for match in _ASSET_PATH_RE.findall(text): + add(match) + return + if isinstance(value, dict): + image_url = value.get("imageUrl") + progress = value.get("progress") + if isinstance(image_url, str) and image_url: + if progress is None or float(progress) >= 100: + add(image_url) for key, item in value.items(): - if key in {"generatedImageUrls", "imageUrls", "imageURLs"}: - if isinstance(item, list): - for url in item: - if isinstance(url, str): - add(url) - elif isinstance(item, str): - add(item) + if key in _IMAGE_COLLECTION_KEYS: + walk(item) + continue + if key == "imageUrl" and "progress" in value: + continue + if key in _URLISH_KEYS and isinstance(item, str): + add(item) continue walk(item) elif isinstance(value, list): diff --git a/tests/test_image_generation_app_chat.py b/tests/test_image_generation_app_chat.py new file mode 100644 index 000000000..4e7c5bae1 --- /dev/null +++ b/tests/test_image_generation_app_chat.py @@ -0,0 +1,69 @@ +import json + +from app.services.grok.services.image import ImageGenerationService +from app.services.grok.utils.process import _collect_images + + +def test_build_app_chat_message_prefixes_plain_prompt(): + assert ( + ImageGenerationService._build_app_chat_message("a red apple on a white table") + == "Generate an image: a red apple on a white table" + ) + + +def test_build_app_chat_message_keeps_existing_generate_prefix(): + assert ( + ImageGenerationService._build_app_chat_message( + "Generate an image: a red apple on a white table" + ) + == "Generate an image: a red apple on a white table" + ) + + +def test_collect_images_reads_final_generated_image_card_path(): + partial = { + "id": "abc", + "type": "render_generated_image", + "cardType": "generated_image_card", + "image_chunk": { + "imageUuid": "uuid-1", + "imageUrl": "users/example/generated/uuid-1-part-0/image.jpg", + "seq": 0, + "progress": 50, + }, + } + final = { + "id": "abc", + "type": "render_generated_image", + "cardType": "generated_image_card", + "image_chunk": { + "imageUuid": "uuid-1", + "imageUrl": "users/example/generated/uuid-1/image.jpg", + "seq": 1, + "progress": 100, + }, + } + + urls = _collect_images( + { + "generatedImageUrls": [], + "cardAttachmentsJson": [json.dumps(partial), json.dumps(final)], + } + ) + + assert urls == ["users/example/generated/uuid-1/image.jpg"] + + +def test_collect_images_ignores_search_result_cards(): + searched = { + "id": "xyz", + "type": "render_searched_image", + "cardType": "image_card", + "image": { + "thumbnail": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQ" + }, + } + + urls = _collect_images({"cardAttachmentsJson": [json.dumps(searched)]}) + + assert urls == []