diff --git a/app/api/v1/chat.py b/app/api/v1/chat.py index 8556b2eaf..e1a0740a9 100644 --- a/app/api/v1/chat.py +++ b/app/api/v1/chat.py @@ -72,6 +72,7 @@ class ChatCompletionRequest(BaseModel): VALID_ROLES = {"developer", "system", "user", "assistant", "tool"} USER_CONTENT_TYPES = {"text", "image_url", "input_audio", "file"} +MARKDOWN_IMAGE_RE = r"!\[[^\]]*\]\(([^)\s]+)(?:\s+\"[^\"]*\")?\)" ALLOWED_IMAGE_SIZES = { "1280x720", "720x1280", @@ -118,11 +119,20 @@ def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]] last_text = "" image_urls: List[str] = [] + def _collect_markdown_images(text: str): + if not isinstance(text, str) or not text: + return + for match in __import__("re").finditer(MARKDOWN_IMAGE_RE, text): + url = (match.group(1) or "").strip() + if url: + image_urls.append(url) + for msg in messages: role = msg.role or "user" content = msg.content if isinstance(content, str): text = content.strip() + _collect_markdown_images(text) if text: last_text = text continue @@ -137,6 +147,7 @@ def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]] if block_type == "text": text = block.get("text", "") if isinstance(text, str) and text.strip(): + _collect_markdown_images(text) last_text = text.strip() elif block_type == "image_url" and role == "user": image = block.get("image_url") or {} @@ -144,7 +155,14 @@ def _extract_prompt_images(messages: List[MessageItem]) -> tuple[str, List[str]] if isinstance(url, str) and url.strip(): image_urls.append(url.strip()) - return last_text, image_urls + deduped_urls: List[str] = [] + seen = set() + for url in image_urls: + if url not in seen: + seen.add(url) + deduped_urls.append(url) + + return last_text, deduped_urls def _resolve_image_format(value: Optional[str]) -> str: diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index b29585633..eec1e8ae8 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -263,6 +263,7 @@ async def chat( file_attachments: List[str] = None, tool_overrides: Dict[str, Any] = None, model_config_override: Dict[str, Any] = None, + request_overrides: Dict[str, Any] = None, ): """发送聊天请求""" if stream is None: @@ -286,6 +287,7 @@ async def chat( file_attachments=file_attachments, tool_overrides=tool_overrides, model_config_override=model_config_override, + request_overrides=request_overrides, ) logger.info(f"Chat connected: model={model}, stream={stream}") except Exception: diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index e60b5783e..28ea21fa6 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -20,6 +20,11 @@ from app.services.grok.utils.retry import pick_token, rate_limited from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content from app.services.grok.utils.stream import wrap_stream_with_usage +from app.services.grok.services.chat import GrokChatService +from app.services.grok.services.image_edit import ( + ImageStreamProcessor as AppChatImageStreamProcessor, + ImageCollectProcessor as AppChatImageCollectProcessor, +) from app.services.token import EffortType from app.services.reverse.ws_imagine import ImagineWebSocketReverse @@ -37,6 +42,18 @@ class ImageGenerationResult: class ImageGenerationService: """Image generation orchestration service.""" + @staticmethod + def _app_chat_request_overrides( + count: int, + enable_nsfw: Optional[bool], + ) -> Dict[str, Any]: + overrides: Dict[str, Any] = { + "imageGenerationCount": max(1, int(count or 1)), + } + if enable_nsfw is not None: + overrides["enableNsfw"] = bool(enable_nsfw) + return overrides + async def generate( self, *, @@ -87,18 +104,36 @@ async def _stream_retry() -> AsyncGenerator[str, None]: tried_tokens.add(current_token) yielded = False try: - result = await self._stream_ws( - token_mgr=token_mgr, - token=current_token, - model_info=model_info, - prompt=prompt, - n=n, - response_format=response_format, - size=size, - aspect_ratio=aspect_ratio, - enable_nsfw=enable_nsfw, - chat_format=chat_format, - ) + try: + result = await self._stream_app_chat( + token_mgr=token_mgr, + token=current_token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + enable_nsfw=enable_nsfw, + chat_format=chat_format, + ) + except UpstreamException as app_chat_error: + if rate_limited(app_chat_error): + raise + logger.warning( + "App-chat image stream failed, falling back to ws_imagine: %s", + app_chat_error, + ) + result = await self._stream_ws( + token_mgr=token_mgr, + token=current_token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + size=size, + aspect_ratio=aspect_ratio, + enable_nsfw=enable_nsfw, + chat_format=chat_format, + ) async for chunk in result.data: yielded = True yield chunk @@ -148,17 +183,34 @@ async def _stream_retry() -> AsyncGenerator[str, None]: tried_tokens.add(current_token) try: - return await self._collect_ws( - token_mgr=token_mgr, - token=current_token, - model_info=model_info, - tried_tokens=tried_tokens, - prompt=prompt, - n=n, - response_format=response_format, - aspect_ratio=aspect_ratio, - enable_nsfw=enable_nsfw, - ) + try: + return await self._collect_app_chat( + token_mgr=token_mgr, + token=current_token, + model_info=model_info, + prompt=prompt, + n=n, + response_format=response_format, + enable_nsfw=enable_nsfw, + ) + except UpstreamException as app_chat_error: + if rate_limited(app_chat_error): + raise + logger.warning( + "App-chat image collect failed, falling back to ws_imagine: %s", + app_chat_error, + ) + return await self._collect_ws( + token_mgr=token_mgr, + token=current_token, + model_info=model_info, + tried_tokens=tried_tokens, + prompt=prompt, + n=n, + response_format=response_format, + aspect_ratio=aspect_ratio, + enable_nsfw=enable_nsfw, + ) except UpstreamException as e: last_error = e if rate_limited(e): @@ -221,6 +273,125 @@ async def _stream_ws( ) return ImageGenerationResult(stream=True, data=stream) + async def _stream_app_chat( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + enable_nsfw: Optional[bool] = None, + chat_format: bool = False, + ) -> ImageGenerationResult: + response = await GrokChatService().chat( + token=token, + message=prompt, + model=model_info.grok_model, + mode=model_info.model_mode, + stream=True, + tool_overrides={"imageGen": True}, + request_overrides=self._app_chat_request_overrides(n, enable_nsfw), + ) + processor = AppChatImageStreamProcessor( + model_info.model_id, + token, + n=n, + response_format=response_format, + chat_format=chat_format, + ) + stream = wrap_stream_with_usage( + processor.process(response), + token_mgr, + token, + model_info.model_id, + ) + return ImageGenerationResult(stream=True, data=stream) + + async def _collect_app_chat( + self, + *, + token_mgr: Any, + token: str, + model_info: Any, + prompt: str, + n: int, + response_format: str, + enable_nsfw: Optional[bool] = None, + ) -> ImageGenerationResult: + per_call = min(max(1, n), 2) + calls_needed = max(1, int(math.ceil(n / per_call))) + + async def _call_generate(call_target: int) -> List[str]: + response = await GrokChatService().chat( + token=token, + message=prompt, + model=model_info.grok_model, + mode=model_info.model_mode, + stream=True, + tool_overrides={"imageGen": True}, + request_overrides=self._app_chat_request_overrides( + call_target, enable_nsfw + ), + ) + processor = AppChatImageCollectProcessor( + model_info.model_id, + token, + response_format=response_format, + ) + return await processor.process(response) + + if calls_needed == 1: + all_images = await _call_generate(n) + else: + tasks = [] + for i in range(calls_needed): + remaining = n - (i * per_call) + tasks.append(_call_generate(min(per_call, remaining))) + results = await asyncio.gather(*tasks, return_exceptions=True) + all_images: List[str] = [] + last_error: Optional[Exception] = None + rate_limit_error: Optional[Exception] = None + for result in results: + if isinstance(result, Exception): + logger.warning(f"Concurrent app-chat image call failed: {result}") + last_error = result + if rate_limited(result): + rate_limit_error = result + continue + for image in result: + if image not in all_images: + all_images.append(image) + + if not all_images: + if rate_limit_error: + raise rate_limit_error + if last_error: + raise last_error + + if not all_images: + raise UpstreamException( + "Image generation returned no results", + details={"error": "empty_result", "path": "app_chat"}, + ) + + try: + await token_mgr.consume(token, self._get_effort(model_info)) + except Exception as e: + logger.warning(f"Failed to consume token: {e}") + + selected = self._select_images(all_images, n) + usage_override = { + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, + } + return ImageGenerationResult( + stream=False, data=selected, usage_override=usage_override + ) + async def _collect_ws( self, *, diff --git a/app/services/grok/services/image_edit.py b/app/services/grok/services/image_edit.py index 0684a557d..b5731e93b 100644 --- a/app/services/grok/services/image_edit.py +++ b/app/services/grok/services/image_edit.py @@ -8,7 +8,7 @@ import re import time from dataclasses import dataclass -from typing import AsyncGenerator, AsyncIterable, List, Union, Any +from typing import AsyncGenerator, AsyncIterable, Dict, List, Tuple, Union, Any import orjson from curl_cffi.requests.errors import RequestsError @@ -32,10 +32,12 @@ from app.services.grok.utils.retry import pick_token, rate_limited from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content from app.services.grok.services.chat import GrokChatService -from app.services.grok.services.video import VideoService from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.token import EffortType +_EDIT_UPSTREAM_MODEL = "grok-4" +_EDIT_UPSTREAM_MODE = "MODEL_MODE_AUTO" + @dataclass class ImageEditResult: @@ -46,6 +48,10 @@ class ImageEditResult: class ImageEditService: """Image edit orchestration service.""" + @staticmethod + def _build_request_overrides(n: int) -> Dict[str, Any]: + return {"imageGenerationCount": max(1, int(n or 1))} + async def edit( self, *, @@ -87,35 +93,20 @@ async def edit( tried_tokens.add(current_token) try: - image_urls = await self._upload_images(images, current_token) - parent_post_id = await self._get_parent_post_id( - current_token, image_urls - ) - - model_config_override = { - "modelMap": { - "imageEditModel": "imagine", - "imageEditModelConfig": { - "imageReferences": image_urls, - }, - } - } - if parent_post_id: - model_config_override["modelMap"]["imageEditModelConfig"][ - "parentPostId" - ] = parent_post_id - - tool_overrides = {"imageGen": True} + file_attachments = await self._upload_images(images, current_token) + tool_overrides: Dict[str, Any] | None = None + request_overrides = self._build_request_overrides(n) if stream: response = await GrokChatService().chat( token=current_token, message=prompt, - model=model_info.grok_model, - mode=None, + model=_EDIT_UPSTREAM_MODEL, + mode=_EDIT_UPSTREAM_MODE, stream=True, + file_attachments=file_attachments, tool_overrides=tool_overrides, - model_config_override=model_config_override, + request_overrides=request_overrides, ) processor = ImageStreamProcessor( model_info.model_id, @@ -137,11 +128,10 @@ async def edit( images_out = await self._collect_images( token=current_token, prompt=prompt, - model_info=model_info, n=n, response_format=response_format, + file_attachments=file_attachments, tool_overrides=tool_overrides, - model_config_override=model_config_override, ) try: effort = ( @@ -177,82 +167,54 @@ async def edit( status_code=429, ) - async def _upload_images(self, images: List[str], token: str) -> List[str]: - image_urls: List[str] = [] + async def _upload_images( + self, images: List[str], token: str + ) -> List[str]: + file_attachments: List[str] = [] upload_service = UploadService() try: for image in images: - _, file_uri = await upload_service.upload_file(image, token) - if file_uri: - if file_uri.startswith("http"): - image_urls.append(file_uri) - else: - image_urls.append( - f"https://assets.grok.com/{file_uri.lstrip('/')}" - ) + file_id, _ = await upload_service.upload_file(image, token) + if file_id: + file_attachments.append(file_id) finally: await upload_service.close() - if not image_urls: + if not file_attachments: raise AppException( message="Image upload failed", error_type=ErrorType.SERVER.value, code="upload_failed", ) - return image_urls - - async def _get_parent_post_id(self, token: str, image_urls: List[str]) -> str: - parent_post_id = None - try: - media_service = VideoService() - parent_post_id = await media_service.create_image_post(token, image_urls[0]) - logger.debug(f"Parent post ID: {parent_post_id}") - except Exception as e: - logger.warning(f"Create image post failed: {e}") - - if parent_post_id: - return parent_post_id - - for url in image_urls: - match = re.search(r"/generated/([a-f0-9-]+)/", url) - if match: - parent_post_id = match.group(1) - logger.debug(f"Parent post ID: {parent_post_id}") - break - match = re.search(r"/users/[^/]+/([a-f0-9-]+)/content", url) - if match: - parent_post_id = match.group(1) - logger.debug(f"Parent post ID: {parent_post_id}") - break - - return parent_post_id or "" + return file_attachments async def _collect_images( self, *, token: str, prompt: str, - model_info: Any, n: int, response_format: str, + file_attachments: List[str], tool_overrides: dict, - model_config_override: dict, ) -> List[str]: - calls_needed = (n + 1) // 2 + per_call = 2 + calls_needed = max(1, (n + per_call - 1) // per_call) async def _call_edit(): response = await GrokChatService().chat( token=token, message=prompt, - model=model_info.grok_model, - mode=None, + model=_EDIT_UPSTREAM_MODEL, + mode=_EDIT_UPSTREAM_MODE, stream=True, + file_attachments=file_attachments, tool_overrides=tool_overrides, - model_config_override=model_config_override, + request_overrides=self._build_request_overrides(per_call), ) processor = ImageCollectProcessor( - model_info.model_id, token, response_format=response_format + "grok-imagine-1.0-edit", token, response_format=response_format ) return await processor.process(response) @@ -307,6 +269,7 @@ def __init__( self.chat_format = chat_format self._id_generated = False self._response_id = "" + self._image_ids: Dict[int, str] = {} # imageIndex → generated image_id if response_format == "url": self.response_field = "url" elif response_format == "base64": @@ -314,6 +277,12 @@ def __init__( else: self.response_field = "b64_json" + def _get_image_id(self, image_index: int) -> str: + """Get or create a stable image_id for a given image index.""" + if image_index not in self._image_ids: + self._image_ids[image_index] = f"app-chat-{int(time.time() * 1000)}-{image_index}" + return self._image_ids[image_index] + def _sse(self, event: str, data: dict) -> str: """Build SSE response.""" return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" @@ -349,6 +318,7 @@ async def process( out_index = 0 if self.n == 1 else image_index if not self.chat_format: + image_id = self._get_image_id(image_index) yield self._sse( "image_generation.partial_image", { @@ -356,6 +326,7 @@ async def process( self.response_field: "", "index": out_index, "progress": progress, + "image_id": image_id, }, ) continue @@ -421,12 +392,15 @@ async def process( ) else: # Original image_generation format + image_id = self._get_image_id(out_index) yield self._sse( "image_generation.completed", { "type": "image_generation.completed", self.response_field: img_data, "index": out_index, + "image_id": image_id, + "stage": "final", "usage": { "total_tokens": 0, "input_tokens": 0, diff --git a/app/services/reverse/app_chat.py b/app/services/reverse/app_chat.py index 40594781b..526059f0d 100644 --- a/app/services/reverse/app_chat.py +++ b/app/services/reverse/app_chat.py @@ -2,6 +2,7 @@ Reverse interface: app chat conversations. """ +import inspect import orjson from typing import Any, Dict, List, Optional from urllib.parse import urlparse @@ -52,6 +53,43 @@ def _log_proxy_state_once(base_proxy: str, normalized_proxy: str = "", scheme: s class AppChatReverse: """/rest/app-chat/conversations/new reverse interface.""" + @staticmethod + async def _read_error_body(response: Any) -> str: + """Best-effort read for non-200 upstream responses.""" + readers = ( + "text", + "atext", + "read", + "aread", + ) + for attr_name in readers: + attr = getattr(response, attr_name, None) + if attr is None: + continue + try: + value = attr() if callable(attr) else attr + if inspect.isawaitable(value): + value = await value + if value is None: + continue + if isinstance(value, bytes): + value = value.decode("utf-8", errors="ignore") + value = str(value) + if value: + return value + except Exception: + continue + + content = getattr(response, "content", None) + if content: + try: + if isinstance(content, bytes): + return content.decode("utf-8", errors="ignore") + return str(content) + except Exception: + pass + return "" + @staticmethod def _resolve_custom_personality() -> Optional[str]: """Resolve optional custom personality from app config.""" @@ -72,6 +110,7 @@ def build_payload( file_attachments: List[str] = None, tool_overrides: Dict[str, Any] = None, model_config_override: Dict[str, Any] = None, + request_overrides: Dict[str, Any] = None, ) -> Dict[str, Any]: """Build chat payload for Grok app-chat API.""" @@ -123,6 +162,9 @@ def build_payload( if model_config_override: payload["responseMetadata"]["modelConfigOverride"] = model_config_override + if request_overrides: + payload.update({k: v for k, v in request_overrides.items() if v is not None}) + import json logger.debug(f"AppChatReverse payload: {json.dumps(payload, indent=4, ensure_ascii=False)}") @@ -138,6 +180,7 @@ async def request( file_attachments: List[str] = None, tool_overrides: Dict[str, Any] = None, model_config_override: Dict[str, Any] = None, + request_overrides: Dict[str, Any] = None, ) -> Any: """Send app chat request to Grok. @@ -186,6 +229,7 @@ async def request( file_attachments=file_attachments, tool_overrides=tool_overrides, model_config_override=model_config_override, + request_overrides=request_overrides, ) payload_summary = { "model": payload.get("modelName"), @@ -237,20 +281,14 @@ async def _do_request(): ) if response.status_code != 200: + content = await AppChatReverse._read_error_body(response) + content_type = str(response.headers.get("content-type", "")) - # Get response content - content = "" - try: - content = await response.text() - except Exception: - pass - - logger.debug( - "AppChatReverse: Chat failed response body: %s", - content, - ) logger.error( - f"AppChatReverse: Chat failed, {response.status_code}", + "AppChatReverse: Chat failed, %s, content_type=%s, body=%s", + response.status_code, + content_type, + content[:500], extra={"error_type": "UpstreamException"}, ) raise UpstreamException(