Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions app/services/grok/services/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import base64
import math
import re
import time
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -42,18 +43,33 @@ class ImageGenerationResult:
class ImageGenerationService:
"""Image generation orchestration service."""

# Matches text that already opens with one of the image-generation
# instructions ("generate/create/draw/make an image") followed by a colon.
# Leading whitespace is allowed and matching is case-insensitive.
_APP_CHAT_GENERATE_PREFIX_RE = re.compile(
r"^\s*(generate an image|create an image|draw an image|make an image)\s*:",
re.IGNORECASE,
)

@staticmethod
def _app_chat_request_overrides(
count: int,
enable_nsfw: Optional[bool],
) -> Dict[str, Any]:
overrides: Dict[str, Any] = {
"imageGenerationCount": max(1, int(count or 1)),
"disableSearch": True,
}
if enable_nsfw is not None:
overrides["enableNsfw"] = bool(enable_nsfw)
return overrides

@classmethod
def _build_app_chat_message(cls, prompt: str) -> str:
text = (prompt or "").strip()
if not text:
return prompt
if cls._APP_CHAT_GENERATE_PREFIX_RE.match(text):
return text
return f"Generate an image: {text}"

async def generate(
self,
*,
Expand Down Expand Up @@ -196,9 +212,11 @@ async def _stream_retry() -> AsyncGenerator[str, None]:
except UpstreamException as app_chat_error:
if rate_limited(app_chat_error):
raise
error_details = getattr(app_chat_error, "details", None)
logger.warning(
"App-chat image collect failed, falling back to ws_imagine: %s",
app_chat_error,
"App-chat image collect failed, falling back to ws_imagine: "
f"{type(app_chat_error).__name__}: {app_chat_error}; "
f"details={error_details}"
)
return await self._collect_ws(
token_mgr=token_mgr,
Expand Down Expand Up @@ -285,9 +303,10 @@ async def _stream_app_chat(
enable_nsfw: Optional[bool] = None,
chat_format: bool = False,
) -> ImageGenerationResult:
message = self._build_app_chat_message(prompt)
response = await GrokChatService().chat(
token=token,
message=prompt,
message=message,
model=model_info.grok_model,
mode=model_info.model_mode,
stream=True,
Expand Down Expand Up @@ -322,11 +341,12 @@ async def _collect_app_chat(
) -> ImageGenerationResult:
per_call = min(max(1, n), 2)
calls_needed = max(1, int(math.ceil(n / per_call)))
message = self._build_app_chat_message(prompt)

async def _call_generate(call_target: int) -> List[str]:
response = await GrokChatService().chat(
token=token,
message=prompt,
message=message,
model=model_info.grok_model,
mode=model_info.model_mode,
stream=True,
Expand Down
66 changes: 59 additions & 7 deletions app/services/grok/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

import asyncio
import re
import time
import orjson
from typing import Any, AsyncGenerator, Optional, AsyncIterable, List, TypeVar

from app.core.config import get_config
Expand All @@ -13,6 +15,35 @@


T = TypeVar("T")
# Absolute asset URL: "https://assets.grok.com" followed by any run of
# characters that are not whitespace, quotes, angle brackets or ")".
_ASSET_URL_RE = re.compile(r"https://assets\.grok\.com[^\s\"'<>)]*")
# Relative asset path: optional leading "/", then "users/...", ending in an
# image extension (png, jpg/jpeg, webp, gif, bmp) with an optional "?query".
_ASSET_PATH_RE = re.compile(r"(?P<path>/?users/[^\s\"'<>)]*\.(?:png|jpe?g|webp|gif|bmp)(?:\?[^\s\"'<>)]*)?)")
# Dict keys treated as holding a single URL/path; the nested walk() helper
# adds a string value found under one of these keys directly.
_URLISH_KEYS = {
"url",
"uri",
"path",
"imageUrl",
"imageURI",
"imageUri",
"assetUrl",
"assetURI",
"assetUri",
"downloadUrl",
"downloadURI",
"downloadUri",
"fileUrl",
"fileURI",
"fileUri",
"contentUrl",
"contentURI",
"contentUri",
}
# Dict keys whose values are passed back through walk() to pull out image
# URLs; presumably these hold lists of URLs/URIs -- confirm against payloads.
_IMAGE_COLLECTION_KEYS = {
"generatedImageUrls",
"imageUrls",
"imageURLs",
"fileUris",
"imageEditUris",
}


def _is_http2_error(e: Exception) -> bool:
Expand Down Expand Up @@ -51,15 +82,36 @@ def add(url: str):
urls.append(url)

def walk(value: Any):
if isinstance(value, str):
text = value.strip()
if text[:1] in {"{", "["}:
try:
parsed = orjson.loads(text)
except orjson.JSONDecodeError:
parsed = None
if parsed is not None:
walk(parsed)
return
for match in _ASSET_URL_RE.findall(text):
add(match)
for match in _ASSET_PATH_RE.findall(text):
add(match)
return

if isinstance(value, dict):
image_url = value.get("imageUrl")
progress = value.get("progress")
if isinstance(image_url, str) and image_url:
if progress is None or float(progress) >= 100:
add(image_url)
for key, item in value.items():
if key in {"generatedImageUrls", "imageUrls", "imageURLs"}:
if isinstance(item, list):
for url in item:
if isinstance(url, str):
add(url)
elif isinstance(item, str):
add(item)
if key in _IMAGE_COLLECTION_KEYS:
walk(item)
continue
if key == "imageUrl" and "progress" in value:
continue
if key in _URLISH_KEYS and isinstance(item, str):
add(item)
continue
walk(item)
elif isinstance(value, list):
Expand Down
69 changes: 69 additions & 0 deletions tests/test_image_generation_app_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import json

from app.services.grok.services.image import ImageGenerationService
from app.services.grok.utils.process import _collect_images


def test_build_app_chat_message_prefixes_plain_prompt():
    """A prompt without a generation instruction gets the default prefix."""
    built = ImageGenerationService._build_app_chat_message(
        "a red apple on a white table"
    )

    assert built == "Generate an image: a red apple on a white table"


def test_build_app_chat_message_keeps_existing_generate_prefix():
    """A prompt that already carries the prefix is returned without a second one."""
    already_prefixed = "Generate an image: a red apple on a white table"

    built = ImageGenerationService._build_app_chat_message(already_prefixed)

    assert built == "Generate an image: a red apple on a white table"


def test_collect_images_reads_final_generated_image_card_path():
    """Only the path from the progress-100 card is expected in the result."""

    def generated_card(image_url, seq, progress):
        # Both cards share everything except the chunk's URL, seq and progress.
        return {
            "id": "abc",
            "type": "render_generated_image",
            "cardType": "generated_image_card",
            "image_chunk": {
                "imageUuid": "uuid-1",
                "imageUrl": image_url,
                "seq": seq,
                "progress": progress,
            },
        }

    in_progress = generated_card(
        "users/example/generated/uuid-1-part-0/image.jpg", 0, 50
    )
    finished = generated_card("users/example/generated/uuid-1/image.jpg", 1, 100)
    payload = {
        "generatedImageUrls": [],
        "cardAttachmentsJson": [json.dumps(in_progress), json.dumps(finished)],
    }

    collected = _collect_images(payload)

    assert collected == ["users/example/generated/uuid-1/image.jpg"]


def test_collect_images_ignores_search_result_cards():
    """A searched-image card's thumbnail must not be collected as an image."""
    search_card = json.dumps(
        {
            "id": "xyz",
            "type": "render_searched_image",
            "cardType": "image_card",
            "image": {
                "thumbnail": "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQ"
            },
        }
    )

    collected = _collect_images({"cardAttachmentsJson": [search_card]})

    assert collected == []