Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion app/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel):

VALID_ROLES = {"developer", "system", "user", "assistant", "tool"}
USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"}
MARKDOWN_IMAGE_RE = r"!\[[^\]]*\]\(([^)\s]+)(?:\s+\"[^\"]*\")?\)"
ALLOWED_IMAGE_SIZES = {
"1280x720",
"720x1280",
Expand Down Expand Up @@ -118,11 +119,20 @@ def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]
last_text = ""
image_urls: List[str] = []

def _collect_markdown_images(text: str):
if not isinstance(text, str) or not text:
return
for match in __import__("re").finditer(MARKDOWN_IMAGE_RE, text):
url = (match.group(1) or "").strip()
if url:
image_urls.append(url)

for msg in messages:
role = msg.role or "user"
content = msg.content
if isinstance(content, str):
text = content.strip()
_collect_markdown_images(text)
if text:
last_text = text
continue
Expand All @@ -137,14 +147,22 @@ def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]]
if block_type == "text":
text = block.get("text", "")
if isinstance(text, str) and text.strip():
_collect_markdown_images(text)
last_text = text.strip()
elif block_type == "image_url" and role == "user":
image = block.get("image_url") or {}
url = image.get("url", "")
if isinstance(url, str) and url.strip():
image_urls.append(url.strip())

return last_text, image_urls
deduped_urls: List[str] = []
seen = set()
for url in image_urls:
if url not in seen:
seen.add(url)
deduped_urls.append(url)

return last_text, deduped_urls


def _resolve_image_format(value: Optional[str]) -> str:
Expand Down
2 changes: 2 additions & 0 deletions app/services/grok/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ async def chat(
file_attachments: List[str] = None,
tool_overrides: Dict[str, Any] = None,
model_config_override: Dict[str, Any] = None,
request_overrides: Dict[str, Any] = None,
):
"""发送聊天请求"""
if stream is None:
Expand All @@ -286,6 +287,7 @@ async def chat(
file_attachments=file_attachments,
tool_overrides=tool_overrides,
model_config_override=model_config_override,
request_overrides=request_overrides,
)
logger.info(f"Chat connected: model={model}, stream={stream}")
except Exception:
Expand Down
217 changes: 194 additions & 23 deletions app/services/grok/services/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
from app.services.grok.utils.retry import pick_token, rate_limited
from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content
from app.services.grok.utils.stream import wrap_stream_with_usage
from app.services.grok.services.chat import GrokChatService
from app.services.grok.services.image_edit import (
ImageStreamProcessor as AppChatImageStreamProcessor,
ImageCollectProcessor as AppChatImageCollectProcessor,
)
from app.services.token import EffortType
from app.services.reverse.ws_imagine import ImagineWebSocketReverse

Expand All @@ -37,6 +42,18 @@ class ImageGenerationResult:
class ImageGenerationService:
"""Image generation orchestration service."""

@staticmethod
def _app_chat_request_overrides(
count: int,
enable_nsfw: Optional[bool],
) -> Dict[str, Any]:
overrides: Dict[str, Any] = {
"imageGenerationCount": max(1, int(count or 1)),
}
if enable_nsfw is not None:
overrides["enableNsfw"] = bool(enable_nsfw)
return overrides

async def generate(
self,
*,
Expand Down Expand Up @@ -87,18 +104,36 @@ async def _stream_retry() -> AsyncGenerator[str, None]:
tried_tokens.add(current_token)
yielded = False
try:
result = await self._stream_ws(
token_mgr=token_mgr,
token=current_token,
model_info=model_info,
prompt=prompt,
n=n,
response_format=response_format,
size=size,
aspect_ratio=aspect_ratio,
enable_nsfw=enable_nsfw,
chat_format=chat_format,
)
try:
result = await self._stream_app_chat(
token_mgr=token_mgr,
token=current_token,
model_info=model_info,
prompt=prompt,
n=n,
response_format=response_format,
enable_nsfw=enable_nsfw,
chat_format=chat_format,
)
except UpstreamException as app_chat_error:
if rate_limited(app_chat_error):
raise
logger.warning(
"App-chat image stream failed, falling back to ws_imagine: %s",
app_chat_error,
)
result = await self._stream_ws(
token_mgr=token_mgr,
token=current_token,
model_info=model_info,
prompt=prompt,
n=n,
response_format=response_format,
size=size,
aspect_ratio=aspect_ratio,
enable_nsfw=enable_nsfw,
chat_format=chat_format,
)
async for chunk in result.data:
yielded = True
yield chunk
Expand Down Expand Up @@ -148,17 +183,34 @@ async def _stream_retry() -> AsyncGenerator[str, None]:

tried_tokens.add(current_token)
try:
return await self._collect_ws(
token_mgr=token_mgr,
token=current_token,
model_info=model_info,
tried_tokens=tried_tokens,
prompt=prompt,
n=n,
response_format=response_format,
aspect_ratio=aspect_ratio,
enable_nsfw=enable_nsfw,
)
try:
return await self._collect_app_chat(
token_mgr=token_mgr,
token=current_token,
model_info=model_info,
prompt=prompt,
n=n,
response_format=response_format,
enable_nsfw=enable_nsfw,
)
except UpstreamException as app_chat_error:
if rate_limited(app_chat_error):
raise
logger.warning(
"App-chat image collect failed, falling back to ws_imagine: %s",
app_chat_error,
)
return await self._collect_ws(
token_mgr=token_mgr,
token=current_token,
model_info=model_info,
tried_tokens=tried_tokens,
prompt=prompt,
n=n,
response_format=response_format,
aspect_ratio=aspect_ratio,
enable_nsfw=enable_nsfw,
)
except UpstreamException as e:
last_error = e
if rate_limited(e):
Expand Down Expand Up @@ -221,6 +273,125 @@ async def _stream_ws(
)
return ImageGenerationResult(stream=True, data=stream)

async def _stream_app_chat(
self,
*,
token_mgr: Any,
token: str,
model_info: Any,
prompt: str,
n: int,
response_format: str,
enable_nsfw: Optional[bool] = None,
chat_format: bool = False,
) -> ImageGenerationResult:
response = await GrokChatService().chat(
token=token,
message=prompt,
model=model_info.grok_model,
mode=model_info.model_mode,
stream=True,
tool_overrides={"imageGen": True},
request_overrides=self._app_chat_request_overrides(n, enable_nsfw),
)
processor = AppChatImageStreamProcessor(
model_info.model_id,
token,
n=n,
response_format=response_format,
chat_format=chat_format,
)
stream = wrap_stream_with_usage(
processor.process(response),
token_mgr,
token,
model_info.model_id,
)
return ImageGenerationResult(stream=True, data=stream)

    async def _collect_app_chat(
        self,
        *,
        token_mgr: Any,
        token: str,
        model_info: Any,
        prompt: str,
        n: int,
        response_format: str,
        enable_nsfw: Optional[bool] = None,
    ) -> ImageGenerationResult:
        """Collect ``n`` generated images via the app-chat endpoint (non-stream).

        Fans out into multiple concurrent chat calls when ``n`` exceeds the
        per-call cap, deduplicates the returned images, consumes token quota
        on success, and returns a non-streaming result with a zeroed usage
        override.
        """
        # Upstream appears to cap images per chat call at 2 — TODO confirm.
        per_call = min(max(1, n), 2)
        calls_needed = max(1, int(math.ceil(n / per_call)))

        async def _call_generate(call_target: int) -> List[str]:
            # One app-chat round trip requesting `call_target` images;
            # the collect processor drains the stream into a list of images.
            response = await GrokChatService().chat(
                token=token,
                message=prompt,
                model=model_info.grok_model,
                mode=model_info.model_mode,
                stream=True,
                tool_overrides={"imageGen": True},
                request_overrides=self._app_chat_request_overrides(
                    call_target, enable_nsfw
                ),
            )
            processor = AppChatImageCollectProcessor(
                model_info.model_id,
                token,
                response_format=response_format,
            )
            return await processor.process(response)

        if calls_needed == 1:
            all_images = await _call_generate(n)
        else:
            # Split the remaining quota across concurrent calls; the last
            # call may request fewer than `per_call` images.
            tasks = []
            for i in range(calls_needed):
                remaining = n - (i * per_call)
                tasks.append(_call_generate(min(per_call, remaining)))
            results = await asyncio.gather(*tasks, return_exceptions=True)
            all_images: List[str] = []
            last_error: Optional[Exception] = None
            rate_limit_error: Optional[Exception] = None
            for result in results:
                if isinstance(result, Exception):
                    # Best-effort: tolerate partial failures as long as at
                    # least one call produced images.
                    logger.warning(f"Concurrent app-chat image call failed: {result}")
                    last_error = result
                    if rate_limited(result):
                        rate_limit_error = result
                    continue
                for image in result:
                    # Dedupe while preserving first-seen order.
                    if image not in all_images:
                        all_images.append(image)

            if not all_images:
                # All calls failed: prefer surfacing a rate-limit error so
                # the caller's retry logic can react to it specifically.
                if rate_limit_error:
                    raise rate_limit_error
                if last_error:
                    raise last_error

        if not all_images:
            # Calls "succeeded" but yielded nothing usable.
            raise UpstreamException(
                "Image generation returned no results",
                details={"error": "empty_result", "path": "app_chat"},
            )

        # Quota consumption is best-effort; a billing hiccup must not drop
        # images that were already generated.
        try:
            await token_mgr.consume(token, self._get_effort(model_info))
        except Exception as e:
            logger.warning(f"Failed to consume token: {e}")

        selected = self._select_images(all_images, n)
        # Image generation reports no token usage.
        usage_override = {
            "total_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
        }
        return ImageGenerationResult(
            stream=False, data=selected, usage_override=usage_override
        )

async def _collect_ws(
self,
*,
Expand Down
Loading
Loading