diff --git a/_public/static/function/js/chat.js b/_public/static/function/js/chat.js index 77f427e0f..7bdd40906 100644 --- a/_public/static/function/js/chat.js +++ b/_public/static/function/js/chat.js @@ -999,7 +999,7 @@ signal: abortController.signal }); if (!res.ok) throw new Error(t('chat.requestFailedStatus', { status: res.status })); - await handleStream(res, assistantEntry, sendSessionId); + await handleAssistantResponse(res, payload, assistantEntry, sendSessionId); setStatus('connected', t('common.done')); } catch (e) { if (e && e.name === 'AbortError') { @@ -1293,6 +1293,45 @@ return payload; } + function extractPromptText(content) { + if (typeof content === 'string') return content; + if (!Array.isArray(content)) return ''; + return content + .filter((block) => block && block.type === 'text' && block.text) + .map((block) => String(block.text)) + .join('\n') + .trim(); + } + + function isImageModelSelected(model) { + return /^grok-imagine/i.test(String(model || '').trim()); + } + + function isLikelyImagePrompt(prompt) { + const value = String(prompt || '').trim(); + if (!value) return false; + const patterns = [ + /generate\s+(?:a|an)?\s*(?:picture|image|photo|illustration|drawing)/i, + /\b(?:draw|illustrate|render|create|make)\b[\s\S]{0,32}\b(?:image|picture|photo|art|illustration)\b/i, + /生成[\s\S]{0,16}(?:图片|图像|照片|插画|绘图|一张图)/, + /画[\s\S]{0,16}(?:图|图片|插画|照片)/, + /来[\s\S]{0,8}(?:张|幅)(?:图|图片|插画)/, + ]; + return patterns.some((pattern) => pattern.test(value)); + } + + function shouldUseImageResultMode(history) { + if (isImageModelSelected(modelValue)) return true; + const items = Array.isArray(history) ? history : []; + for (let i = items.length - 1; i >= 0; i -= 1) { + const item = items[i]; + if (!item || item.role !== 'user') continue; + const prompt = extractPromptText(item.content); + return isLikelyImagePrompt(prompt); + } + return false; + } + function buildPayload() { const payload = { model: modelValue || 'grok-3', @@ -1301,6 +1340,13 @@ temperature: Number(tempRange ? tempRange.value : 0.8), top_p: Number(topPRange ? topPRange.value : 0.95) }; + if (shouldUseImageResultMode(messageHistory)) { + payload.stream = false; + payload.image_config = { + response_format: 'url', + return_share_url: true + }; + } return payload; } @@ -1312,9 +1358,54 @@ temperature: Number(tempRange ? tempRange.value : 0.8), top_p: Number(topPRange ? topPRange.value : 0.95) }; + if (shouldUseImageResultMode(history)) { + payload.stream = false; + payload.image_config = { + response_format: 'url', + return_share_url: true + }; + } return payload; } + function appendShareLinks(content, payload) { + const base = String(content || '').trimEnd(); + const extras = []; + const shareUrl = payload && payload.share_url ? String(payload.share_url).trim() : ''; + const shareImageUrl = payload && payload.share_image_url ? String(payload.share_image_url).trim() : ''; + + if (shareUrl && !base.includes(shareUrl)) { + extras.push(`Share link: [${shareUrl}](${shareUrl})`); + } + if (shareImageUrl && !base.includes(shareImageUrl)) { + extras.push(`Direct link: [${shareImageUrl}](${shareImageUrl})`); + } + + if (!extras.length) return base; + return base ? `${base}\n\n${extras.join('\n')}` : extras.join('\n'); + } + + async function handleJsonResponse(res, assistantEntry, targetSessionId) { + const payload = await res.json(); + if (payload && payload.error) { + throw new Error(payload.error.message || t('common.failed')); + } + const content = payload && payload.choices && payload.choices[0] && payload.choices[0].message + ? (payload.choices[0].message.content || '') + : ''; + const assistantText = appendShareLinks(content, payload); + updateMessage(assistantEntry, assistantText, true); + assistantEntry.committed = true; + commitToSession(targetSessionId, assistantText); + } + + async function handleAssistantResponse(res, requestPayload, assistantEntry, targetSessionId) { + if (requestPayload && requestPayload.stream === false) { + return handleJsonResponse(res, assistantEntry, targetSessionId); + } + return handleStream(res, assistantEntry, targetSessionId); + } + function selectModel(value) { modelValue = value; if (modelLabel) modelLabel.textContent = value; @@ -1514,7 +1605,7 @@ throw new Error(t('chat.requestFailedStatus', { status: res.status })); } - await handleStream(res, assistantEntry, retrySessionId); + await handleAssistantResponse(res, payload, assistantEntry, retrySessionId); setStatus('connected', t('common.done')); } catch (e) { updateMessage(assistantEntry, t('chat.requestFailedStatus', { status: e.message || e }), true); @@ -1597,7 +1688,7 @@ throw new Error(t('chat.requestFailedStatus', { status: res.status })); } - await handleStream(res, assistantEntry, sendSessionId); + await handleAssistantResponse(res, payload, assistantEntry, sendSessionId); setStatus('connected', t('common.done')); } catch (e) { if (e && e.name === 'AbortError') { diff --git a/_public/static/function/pages/chat.html b/_public/static/function/pages/chat.html index f0b203517..285f29b35 100644 --- a/_public/static/function/pages/chat.html +++ b/_public/static/function/pages/chat.html @@ -152,7 +152,7 @@

Chat 聊天

- + diff --git a/app/api/v1/chat.py b/app/api/v1/chat.py index e1a0740a9..6fe63d25f 100644 --- a/app/api/v1/chat.py +++ b/app/api/v1/chat.py @@ -18,7 +18,8 @@ from app.services.grok.services.image_edit import ImageEditService from app.services.grok.services.model import ModelService from app.services.grok.services.video import VideoService -from app.services.grok.utils.response import make_chat_response +from app.api.v1.image import append_share_payload +from app.services.grok.utils.response import make_chat_response, wrap_image_content from app.services.token import get_token_manager from app.core.config import get_config from app.core.exceptions import ValidationException, AppException, ErrorType @@ -49,6 +50,9 @@ class ImageConfig(BaseModel): n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)") size: Optional[str] = Field("1024x1024", description="图片尺寸") response_format: Optional[str] = Field(None, description="响应格式") + return_share_url: Optional[bool] = Field( + False, description="是否返回 Grok 分享页链接" + ) class ChatCompletionRequest(BaseModel): @@ -186,7 +190,9 @@ def _image_field(response_format: str) -> str: return "b64_json" -def _imagine_fast_server_image_config() -> ImageConfig: +def _imagine_fast_server_image_config( + client_config: Optional[ImageConfig] = None, +) -> ImageConfig: """Load server-side image generation parameters for grok-imagine-1.0-fast.""" n = int(get_config("imagine_fast.n", 1) or 1) size = str(get_config("imagine_fast.size", "1024x1024") or "1024x1024") @@ -194,7 +200,12 @@ def _imagine_fast_server_image_config() -> ImageConfig: get_config("imagine_fast.response_format", get_config("app.image_format") or "url") or "url" ) - return ImageConfig(n=n, size=size, response_format=response_format) + return ImageConfig( + n=n, + size=size, + response_format=response_format, + return_share_url=bool(client_config and client_config.return_share_url), + ) async def _safe_sse_stream(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: @@ -274,6 +285,12 @@ def _validate_image_config(image_conf: ImageConfig, *, stream: bool): param="image_config.response_format", code="invalid_response_format", ) + if stream and image_conf.return_share_url: + raise ValidationException( + message="return_share_url is only supported when stream=false", + param="image_config.return_share_url", + code="share_url_stream_not_supported", + ) if image_conf.size and image_conf.size not in ALLOWED_IMAGE_SIZES: raise ValidationException( message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}", @@ -601,7 +618,11 @@ def validate_request(request: ChatCompletionRequest): param="messages", code="empty_prompt", ) - image_conf = _imagine_fast_server_image_config() if request.model == IMAGINE_FAST_MODEL_ID else (request.image_config or ImageConfig()) + image_conf = ( + _imagine_fast_server_image_config(request.image_config) + if request.model == IMAGINE_FAST_MODEL_ID + else (request.image_config or ImageConfig()) + ) n = image_conf.n or 1 if not (1 <= n <= 10): raise ValidationException( @@ -765,6 +786,7 @@ async def chat_completions(request: ChatCompletionRequest): ) content = result.data[0] if result.data else "" + content = wrap_image_content(content, response_format) return JSONResponse( content=make_chat_response(request.model, content) ) @@ -775,7 +797,11 @@ async def chat_completions(request: ChatCompletionRequest): is_stream = ( request.stream if request.stream is not None else get_config("app.stream") ) - image_conf = _imagine_fast_server_image_config() if request.model == IMAGINE_FAST_MODEL_ID else (request.image_config or ImageConfig()) + image_conf = ( + _imagine_fast_server_image_config(request.image_config) + if request.model == IMAGINE_FAST_MODEL_ID + else (request.image_config or ImageConfig()) + ) _validate_image_config(image_conf, stream=bool(is_stream)) response_format = _resolve_image_format(image_conf.response_format) response_field = _image_field(response_format) @@ -818,6 +844,7 @@ async def chat_completions(request: ChatCompletionRequest): aspect_ratio=aspect_ratio, stream=bool(is_stream), chat_format=True, + return_share_url=bool(image_conf.return_share_url), ) if result.stream: @@ -828,10 +855,10 @@ async def chat_completions(request: ChatCompletionRequest): ) content = result.data[0] if result.data else "" + content = wrap_image_content(content, response_format) usage = result.usage_override - return JSONResponse( - content=make_chat_response(request.model, content, usage=usage) - ) + payload = make_chat_response(request.model, content, usage=usage) + return JSONResponse(content=append_share_payload(payload, result)) if model_info and model_info.is_video: # 提取视频配置 (默认值在 Pydantic 模型中处理) diff --git a/app/api/v1/image.py b/app/api/v1/image.py index 88d643f02..70716f7fe 100644 --- a/app/api/v1/image.py +++ b/app/api/v1/image.py @@ -14,6 +14,7 @@ from app.services.grok.services.image import ImageGenerationService from app.services.grok.services.image_edit import ImageEditService from app.services.grok.services.model import ModelService +from app.services.grok.utils.share_resolver import resolve_grok_share_image from app.services.token import get_token_manager from app.core.exceptions import ValidationException, AppException, ErrorType from app.core.config import get_config @@ -53,6 +54,9 @@ class ImageGenerationRequest(BaseModel): response_format: Optional[str] = Field(None, description="响应格式") style: Optional[str] = Field(None, description="风格 (暂不支持)") stream: Optional[bool] = Field(False, description="是否流式输出") + return_share_url: Optional[bool] = Field( + False, description="是否返回 Grok 分享页链接" + ) class ImageEditRequest(BaseModel): @@ -72,6 +76,30 @@ class ImageEditRequest(BaseModel): stream: Optional[bool] = Field(False, description="是否流式输出") +class ShareImageResolveRequest(BaseModel): + share_url: str = Field(..., description="Grok 分享页链接") + + +def append_share_payload(payload: dict, result) -> dict: + share_url = getattr(result, "share_url", "") or "" + if share_url: + payload["share_url"] = share_url + + share_image_url = getattr(result, "share_image_url", "") or "" + if share_image_url: + payload["share_image_url"] = share_image_url + + share_image_source = getattr(result, "share_image_source", "") or "" + if share_image_source: + payload["share_image_source"] = share_image_source + + share_image_expires_at = getattr(result, "share_image_expires_at", "") or "" + if share_image_expires_at: + payload["share_image_expires_at"] = share_image_expires_at + + return payload + + def _validate_common_request( request: Union[ImageGenerationRequest, ImageEditRequest], *, @@ -108,6 +136,13 @@ def _validate_common_request( code="invalid_response_format", ) + if request.stream and getattr(request, "return_share_url", False): + raise ValidationException( + message="return_share_url is only supported when stream=false", + param="return_share_url", + code="share_url_stream_not_supported", + ) + if request.response_format: allowed_formats = {"b64_json", "base64", "url"} if request.response_format not in allowed_formats: @@ -127,16 +162,8 @@ def _validate_common_request( def validate_generation_request(request: ImageGenerationRequest): """验证图片生成请求参数""" - if request.model != "grok-imagine-1.0": - raise ValidationException( - message="The model `grok-imagine-1.0` is required for image generation.", - param="model", - code="model_not_supported", - ) - # 验证模型 - 通过 is_image 检查 model_info = ModelService.get(request.model) if not model_info or not model_info.is_image: - # 获取支持的图片模型列表 image_models = [m.model_id for m in ModelService.MODELS if m.is_image] raise ValidationException( message=( @@ -289,6 +316,7 @@ async def create_image(request: ImageGenerationRequest): size=request.size, aspect_ratio=aspect_ratio, stream=bool(request.stream), + return_share_url=bool(request.return_share_url), ) if result.stream: @@ -306,13 +334,28 @@ async def create_image(request: ImageGenerationRequest): "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, } - return JSONResponse( - content={ - "created": int(time.time()), - "data": data, - "usage": usage, - } - ) + payload = { + "created": int(time.time()), + "data": data, + "usage": usage, + } + return JSONResponse(content=append_share_payload(payload, result)) + + +@router.post("/images/share/resolve") +async def resolve_share_image(request: ShareImageResolveRequest): + result = await resolve_grok_share_image(request.share_url) + payload = { + "share_url": result.share_url, + "resolved": bool(result.image_url), + } + if result.image_url: + payload["share_image_url"] = result.image_url + if result.source: + payload["share_image_source"] = result.source + if result.expires_at: + payload["share_image_expires_at"] = result.expires_at + return JSONResponse(content=payload) @router.post("/images/edits") diff --git a/app/services/grok/services/chat.py b/app/services/grok/services/chat.py index eec1e8ae8..3e92ea9c7 100644 --- a/app/services/grok/services/chat.py +++ b/app/services/grok/services/chat.py @@ -5,7 +5,7 @@ import asyncio import re import uuid -from typing import Dict, List, Any, AsyncGenerator, AsyncIterable +from typing import Dict, List, Any, AsyncGenerator, AsyncIterable, Optional import orjson from curl_cffi.requests.errors import RequestsError @@ -24,7 +24,9 @@ from app.services.grok.utils import process as proc_base from app.services.grok.utils.retry import pick_token, rate_limited, transient_upstream from app.services.reverse.app_chat import AppChatReverse +from app.services.reverse.app_chat_share import AppChatShareReverse from app.services.reverse.utils.session import ResettableSession +from app.services.grok.utils.share_resolver import resolve_grok_share_image from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.grok.utils.tool_call import ( build_tool_prompt, @@ -40,6 +42,111 @@ _CHAT_SEM_VALUE = None +def _pick_str(value: Any) -> str: + if isinstance(value, str): + return value.strip() + return "" + + +def _extract_app_chat_share_context(data: Dict[str, Any]) -> tuple[str, str]: + if not isinstance(data, dict): + return "", "" + + result = data.get("result") + if not isinstance(result, dict): + return "", "" + + response = result.get("response") + if not isinstance(response, dict): + response = result + + conversation_id = "" + for value in ( + result.get("conversationId"), + response.get("conversationId"), + ((result.get("conversation") or {}).get("conversationId")) + if isinstance(result.get("conversation"), dict) + else "", + ((response.get("conversation") or {}).get("conversationId")) + if isinstance(response.get("conversation"), dict) + else "", + ): + conversation_id = _pick_str(value) + if conversation_id: + break + + response_id = "" + model_resp = response.get("modelResponse") + if isinstance(model_resp, dict): + response_id = _pick_str(model_resp.get("responseId")) + + if not response_id: + for source in (response, result): + if not isinstance(source, dict): + continue + candidate = _pick_str(source.get("responseId")) + if candidate: + response_id = candidate + break + + return conversation_id, response_id + + +def _build_app_chat_share_url(payload: Any) -> str: + if not isinstance(payload, dict): + return "" + + for key in ("shareLink", "shareUrl", "shareURL"): + value = _pick_str(payload.get(key)) + if value: + return value + + share_link_id = _pick_str(payload.get("shareLinkId")) or _pick_str( + payload.get("publicId") + ) + if not share_link_id: + return "" + + return f"https://grok.com/share/{share_link_id}" + + +async def _create_chat_share_link( + token: str, + conversation_id: str, + response_id: str, +) -> str: + if not token or not conversation_id or not response_id: + return "" + + try: + async with ResettableSession() as session: + response = await AppChatShareReverse.request( + session, + token, + conversation_id, + response_id, + ) + payload = response.json() if response is not None else {} + return _build_app_chat_share_url(payload) + except Exception as e: + logger.warning(f"Chat share link failed: {e}") + return "" + + +async def _resolve_share_image_details( + share_url: str, +) -> tuple[str, str, str]: + if not share_url: + return "", "", "" + + try: + resolved = await resolve_grok_share_image(share_url) + return resolved.image_url, resolved.source, resolved.expires_at + except Exception as e: + logger.warning(f"Chat share image resolve failed: {e}") + return "", "", "" + + def extract_tool_text(raw: str, rollout_id: str = "") -> str: if not raw: return "" @@ -559,6 +666,8 @@ def __init__( self.prompt_tokens = max(0, int(prompt_tokens or 0)) self._completion_parts: list[str] = [] self._completion_tool_calls: list[dict[str, Any]] = [] + self._pending_image_candidates: dict[str, proc_base.ImageCandidate] = {} + self._emitted_image_keys: set[str] = set() def _record_content(self, content: str) -> None: if content: @@ -568,6 +677,46 @@ def _record_tool_call(self, tool_call: Any) -> None: if isinstance(tool_call, dict): self._completion_tool_calls.append(tool_call) + def _image_alt_text(self, url: str) -> str: + path = proc_base._image_candidate_path(url) + parts = path.split("/") + if len(parts) >= 2 and parts[-2]: + return parts[-2] + return "image" + + def _buffer_image_candidate(self, candidate: proc_base.ImageCandidate) -> None: + if not candidate or candidate.key in self._emitted_image_keys: + return + existing = self._pending_image_candidates.get(candidate.key) + self._pending_image_candidates[candidate.key] = ( + proc_base._pick_preferred_image_candidate(existing, candidate) + ) + + async def _render_image_candidate( + self, candidate: proc_base.ImageCandidate + ) -> str: + if not candidate or candidate.key in self._emitted_image_keys: + return "" + dl_service = self._get_dl() + rendered = await dl_service.render_image( + candidate.url, self.token, self._image_alt_text(candidate.url) + ) + if not rendered: + return "" + self._emitted_image_keys.add(candidate.key) + self._pending_image_candidates.pop(candidate.key, None) + return f"{rendered}\n" + + async def _flush_pending_images(self) -> list[str]: + rendered_images: list[str] = [] + for candidate in sorted( + self._pending_image_candidates.values(), key=lambda item: item.order + ): + rendered = await self._render_image_candidate(candidate) + if rendered: + rendered_images.append(rendered) + return rendered_images + def _with_tool_index(self, tool_call: Any) -> Any: if not isinstance(tool_call, dict): return tool_call @@ -818,15 +967,8 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N self.think_opened = False self.think_closed_once = True self.image_think_active = False - for url in proc_base._collect_images(mr): - parts = url.split("/") - img_id = parts[-2] if len(parts) >= 2 else "image" - dl_service = self._get_dl() - rendered = await dl_service.render_image( - url, self.token, img_id - ) - self._record_content(f"{rendered}\n") - yield self._sse(f"{rendered}\n") + for candidate in proc_base._collect_image_candidates(mr): + self._buffer_image_candidate(candidate) if ( (meta := mr.get("metadata", {})) @@ -837,24 +979,16 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N continue if card := resp.get("cardAttachment"): - json_data = card.get("jsonData") - if isinstance(json_data, str) and json_data.strip(): - try: - card_data = orjson.loads(json_data) - except orjson.JSONDecodeError: - card_data = None - if isinstance(card_data, dict): - image = card_data.get("image") or {} - original = image.get("original") - title = image.get("title") or "" - if original: - title_safe = title.replace("\n", " ").strip() - if title_safe: - self._record_content(f"![{title_safe}]({original})\n") - yield self._sse(f"![{title_safe}]({original})\n") - else: - self._record_content(f"![image]({original})\n") - yield self._sse(f"![image]({original})\n") + for candidate in proc_base._collect_image_candidates( + {"cardAttachment": card} + ): + self._buffer_image_candidate(candidate) + if proc_base._is_preview_image_url(candidate.url): + continue + rendered = await self._render_image_candidate(candidate) + if rendered: + self._record_content(rendered) + yield self._sse(rendered) continue if (token := resp.get("token")) is not None: @@ -903,6 +1037,10 @@ async def process(self, response: AsyncIterable[bytes]) -> AsyncGenerator[str, N yield self._sse("\n") self.think_closed_once = True + for rendered in await self._flush_pending_images(): + self._record_content(rendered) + yield self._sse(rendered) + if self._tool_stream_enabled: for kind, payload in self._flush_tool_stream(): if kind == "text": @@ -1020,8 +1158,11 @@ def _filter_content(self, content: str) -> str: async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: """Process and collect full response.""" response_id = "" + conversation_id = "" fingerprint = "" content = "" + rendered_image_keys: set[str] = set() + has_image_output = False idle_timeout = get_config("chat.stream_timeout") try: @@ -1037,12 +1178,21 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: continue resp = data.get("result", {}).get("response", {}) + current_conversation_id, current_response_id = ( + _extract_app_chat_share_context(data) + ) + if current_conversation_id: + conversation_id = current_conversation_id + if current_response_id: + response_id = current_response_id if (llm := resp.get("llmInfo")) and not fingerprint: fingerprint = llm.get("modelHash", "") if mr := resp.get("modelResponse"): - response_id = mr.get("responseId", "") + mr_response_id = _pick_str(mr.get("responseId")) + if mr_response_id: + response_id = mr_response_id content = mr.get("message", "") card_map: dict[str, tuple[str, str]] = {} @@ -1064,12 +1214,16 @@ async def process(self, response: AsyncIterable[bytes]) -> dict[str, Any]: card_map[card_id] = (title, original) if content and card_map: + has_image_output = True def _render_card(match: re.Match) -> str: card_id = match.group(1) item = card_map.get(card_id) if not item: return "" title, original = item + rendered_image_keys.add( + proc_base._image_candidate_key(original) + ) title_safe = title.replace("\n", " ").strip() or "image" prefix = "" if match.start() > 0: @@ -1085,16 +1239,22 @@ def _render_card(match: re.Match) -> str: flags=re.DOTALL, ) - if urls := proc_base._collect_images(mr): - content += "\n" - for url in urls: - parts = url.split("/") - img_id = parts[-2] if len(parts) >= 2 else "image" - dl_service = self._get_dl() - rendered = await dl_service.render_image( - url, self.token, img_id - ) - content += f"{rendered}\n" + extra_images: list[str] = [] + for candidate in proc_base._collect_image_candidates(mr): + if candidate.key in rendered_image_keys: + continue + has_image_output = True + dl_service = self._get_dl() + rendered = await dl_service.render_image( + candidate.url, + self.token, + candidate.key or "image", + ) + if rendered: + extra_images.append(rendered) + + if extra_images: + content += "\n" + "\n".join(extra_images) + "\n" if ( (meta := mr.get("metadata", {})) @@ -1161,7 +1321,7 @@ def _render_card(match: re.Match) -> str: if tool_calls_result: message_obj["tool_calls"] = tool_calls_result - return { + result = { "id": response_id, "object": "chat.completion", "created": self.created, @@ -1181,6 +1341,28 @@ def _render_card(match: re.Match) -> str: ), } + if has_image_output and conversation_id and response_id: + share_url = await _create_chat_share_link( + self.token, + conversation_id, + response_id, + ) + if share_url: + result["share_url"] = share_url + ( + share_image_url, + share_image_source, + share_image_expires_at, + ) = await _resolve_share_image_details(share_url) + if share_image_url: + result["share_image_url"] = share_image_url + if share_image_source: + result["share_image_source"] = share_image_source + if share_image_expires_at: + result["share_image_expires_at"] = share_image_expires_at + + return result + __all__ = [ "GrokChatService", diff --git a/app/services/grok/services/image.py b/app/services/grok/services/image.py index 28ea21fa6..c5673bcad 100644 --- a/app/services/grok/services/image.py +++ b/app/services/grok/services/image.py @@ -11,21 +11,31 @@ from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union import orjson +from curl_cffi.requests.errors import RequestsError from app.core.config import get_config from app.core.logger import logger from app.core.storage import DATA_DIR from app.core.exceptions import AppException, ErrorType, UpstreamException -from app.services.grok.utils.process import BaseProcessor +from app.services.grok.utils.process import ( + BaseProcessor, + _collect_image_candidates, + _is_http2_error, + _normalize_line, + _pick_preferred_image_candidate, + _with_idle_timeout, +) from app.services.grok.utils.retry import pick_token, rate_limited from app.services.grok.utils.response import make_response_id, make_chat_chunk, wrap_image_content +from app.services.grok.utils.share_resolver import resolve_grok_share_image from app.services.grok.utils.stream import wrap_stream_with_usage from app.services.grok.services.chat import GrokChatService from app.services.grok.services.image_edit import ( ImageStreamProcessor as AppChatImageStreamProcessor, - ImageCollectProcessor as AppChatImageCollectProcessor, ) from app.services.token import EffortType +from app.services.reverse.app_chat_share import AppChatShareReverse +from app.services.reverse.utils.session import ResettableSession from app.services.reverse.ws_imagine import ImagineWebSocketReverse @@ -37,6 +47,386 @@ class ImageGenerationResult: stream: bool data: Union[AsyncGenerator[str, None], List[str]] usage_override: Optional[dict] = None + share_url: str = "" + share_image_url: str = "" + share_image_source: str = "" + share_image_expires_at: str = "" + + +@dataclass +class AppChatImageCollectPayload: + images: List[str] + post_id: str = "" + post_id_rank: int = 999 + conversation_id: str = "" + response_id: str = "" + + +def _pick_str(value: Any) -> str: + if isinstance(value, str): + return value.strip() + return "" + + +def _maybe_extract_uuid(value: Any) -> str: + text = _pick_str(value) + if not text: + return "" + + if ( + len(text) in (32, 36) + and text.replace("-", "").isalnum() + and text.count("-") in (0, 4) + ): + return text + + import re + + match = re.search(r"/generated/([0-9a-fA-F-]{32,36})/", text) + if match: + return match.group(1) + + match = re.search(r"\b([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})\b", text) + if match: + return match.group(1) + + return "" + + +def _append_post_id_candidate( + candidates: List[tuple[int, str]], + rank: int, + value: Any, +): + post_id = _maybe_extract_uuid(value) + if post_id: + candidates.append((rank, post_id)) + + +def _collect_post_id_candidates(resp: Dict[str, Any]) -> List[tuple[int, str]]: + candidates: List[tuple[int, str]] = [] + + post = resp.get("post") + if isinstance(post, dict): + _append_post_id_candidate(candidates, 1, post.get("id")) + + for key in ("postId", "post_id"): + _append_post_id_candidate(candidates, 2, resp.get(key)) + + for key in ("parentPostId", "parent_post_id", "originalPostId", "original_post_id"): + _append_post_id_candidate(candidates, 3, resp.get(key)) + + image_resp = resp.get("streamingImageGenerationResponse") + if isinstance(image_resp, dict): + for key in ("postId", "parentPostId", "originalPostId"): + _append_post_id_candidate(candidates, 4, image_resp.get(key)) + + model_resp = resp.get("modelResponse") + if isinstance(model_resp, dict): + file_attachments = model_resp.get("fileAttachments") + if isinstance(file_attachments, list): + for value in file_attachments: + _append_post_id_candidate(candidates, 5, value) + elif file_attachments: + _append_post_id_candidate(candidates, 5, file_attachments) + + for key in ("postId", "parentPostId", "originalPostId"): + _append_post_id_candidate(candidates, 5, model_resp.get(key)) + + metadata = model_resp.get("metadata") + if isinstance(metadata, dict): + for key in ("postId", "parentPostId", "originalPostId"): + _append_post_id_candidate(candidates, 6, metadata.get(key)) + + raw_cards = model_resp.get("cardAttachmentsJson") or [] + if isinstance(raw_cards, list): + for raw in raw_cards: + if not isinstance(raw, str) or not raw.strip(): + continue + try: + card = orjson.loads(raw) + except orjson.JSONDecodeError: + continue + if not isinstance(card, dict): + continue + for key in ("postId", "parentPostId", "originalPostId"): + _append_post_id_candidate(candidates, 6, card.get(key)) + + card_attachment = resp.get("cardAttachment") + if isinstance(card_attachment, dict): + raw = card_attachment.get("jsonData") + if isinstance(raw, str) and raw.strip(): + try: + card = orjson.loads(raw) + except orjson.JSONDecodeError: + card = None + if isinstance(card, dict): + for key in ("postId", "parentPostId", "originalPostId"): + _append_post_id_candidate(candidates, 6, card.get(key)) + + return candidates + + +def _pick_best_post_id(candidates: List[tuple[int, str]]) -> tuple[str, int]: + best_rank = 999 + best_post_id = "" + for rank, value in candidates: + if value and rank < best_rank: + best_rank = rank + best_post_id = value + return best_post_id, best_rank + + +def _new_session() -> ResettableSession: + browser = get_config("proxy.browser") + if browser: + return ResettableSession(impersonate=browser) + return ResettableSession() + + +def _extract_app_chat_result_payload(data: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(data, dict): + return {} + result = data.get("result") + if not isinstance(result, dict): + return {} + response = result.get("response") + if isinstance(response, dict): + return response + return result + + +def _extract_app_chat_share_context(data: Dict[str, Any]) -> tuple[str, str]: + if not isinstance(data, dict): + return "", "" + + result = data.get("result") + if not isinstance(result, dict): + return "", "" + + response = result.get("response") + if not isinstance(response, dict): + response = result + + conversation_id = "" + for value in ( + result.get("conversationId"), + response.get("conversationId"), + ((result.get("conversation") or {}).get("conversationId")) + if isinstance(result.get("conversation"), dict) + else "", + ((response.get("conversation") or {}).get("conversationId")) + if isinstance(response.get("conversation"), dict) + else "", + ): + conversation_id = _pick_str(value) + if conversation_id: + break + + response_id = "" + model_resp = response.get("modelResponse") + if isinstance(model_resp, dict): + response_id = _pick_str(model_resp.get("responseId")) + + if not response_id: + for source in (response, result): + if not isinstance(source, dict): + continue + candidate = _pick_str(source.get("responseId")) + if not candidate: + continue + if any( + key in source + for key in ( + "modelResponse", + "cardAttachment", + "token", + "finalMetadata", + "progressReport", + "uiLayout", + "llmInfo", + "streamingImageGenerationResponse", + ) + ): + response_id = candidate + break + + if not response_id: + user_response = result.get("userResponse") + if isinstance(user_response, dict): + response_id = _pick_str(user_response.get("responseId")) + + return conversation_id, response_id + + +def _build_app_chat_share_url(payload: Any) -> str: + if not isinstance(payload, dict): + return "" + + for key in ("shareLink", "shareUrl", "shareURL"): + value = _pick_str(payload.get(key)) + if value: + return value + + share_link_id = _pick_str(payload.get("shareLinkId")) or _pick_str( + payload.get("publicId") + ) + if not share_link_id: + return "" + + return f"https://grok.com/share/{share_link_id}" + + +async def _create_image_share_link( + token: str, + conversation_id: str, + response_id: str, +) -> str: + if not token or not conversation_id or not response_id: + return "" + + try: + async with _new_session() as session: + response = await AppChatShareReverse.request( + session, + token, + conversation_id, + response_id, + ) + payload = response.json() if response is not None else {} + share_link = _build_app_chat_share_url(payload) + if share_link: + logger.info(f"Image share link created: {share_link}") + return share_link + except Exception as e: + logger.warning(f"Image share link failed: {e}") + + return "" + + +async def _resolve_share_image_details( + share_url: str, +) -> tuple[str, str, str]: + if not share_url: + return "", "", "" + + try: + resolved = await resolve_grok_share_image(share_url) + return ( + resolved.image_url, + resolved.source, + resolved.expires_at, + ) + except Exception as e: + logger.warning(f"Share image resolve failed: {e}") + return "", "", "" + + +class ImageAppChatCollectProcessor(BaseProcessor): + """App-chat image non-stream processor with share-context collection.""" + + def __init__(self, model: str, token: str = "", response_format: str = "b64_json"): + if response_format == "base64": + response_format = "b64_json" + super().__init__(model, token) + self.response_format = response_format + + async def _process_image_url(self, url: str) -> str: + if self.response_format == "url": + return await self.process_url(url, "image") + + try: + dl_service = self._get_dl() + base64_data = await dl_service.parse_b64(url, self.token, "image") + if base64_data: + if "," in base64_data: + return base64_data.split(",", 1)[1] + return base64_data + except Exception as e: + logger.warning( + f"Failed to convert image to base64, falling back to URL: {e}" + ) + return await self.process_url(url, "image") + + return "" + + async def process(self, response: AsyncIterable[bytes]) -> AppChatImageCollectPayload: + best_candidates: dict[str, Any] = {} + post_id = "" + post_id_rank = 999 + conversation_id = "" + response_id = "" + idle_timeout = get_config("image.stream_timeout") + + try: + async for line in _with_idle_timeout(response, idle_timeout, self.model): + line = _normalize_line(line) + if not line: + continue + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + continue + + resp = _extract_app_chat_result_payload(data) + if not resp: + continue + + current_conversation_id, current_response_id = ( + _extract_app_chat_share_context(data) + ) + if current_conversation_id: + conversation_id = current_conversation_id + if current_response_id: + response_id = current_response_id + + current_post_id, current_rank = _pick_best_post_id( + _collect_post_id_candidates(resp) + ) + if current_post_id and current_rank < post_id_rank: + post_id = current_post_id + post_id_rank = current_rank + + if mr := resp.get("modelResponse"): + if candidates := _collect_image_candidates(mr): + for candidate in candidates: + fallback_post_id = _maybe_extract_uuid(candidate.url) + if fallback_post_id and post_id_rank > 7: + post_id = fallback_post_id + post_id_rank = 7 + existing = best_candidates.get(candidate.key) + best_candidates[candidate.key] = ( + _pick_preferred_image_candidate(existing, candidate) + ) + + except asyncio.CancelledError: + logger.debug("Image collect cancelled by client") + except RequestsError as e: + if _is_http2_error(e): + logger.warning(f"HTTP/2 stream error in image collect: {e}") + else: + logger.error(f"Image collect request error: {e}") + except Exception as e: + logger.error( + f"Image collect processing error: {e}", + extra={"error_type": type(e).__name__}, + ) + finally: + await self.close() + + images: List[str] = [] + for candidate in sorted(best_candidates.values(), key=lambda item: item.order): + processed = await self._process_image_url(candidate.url) + if processed: + images.append(processed) + + return AppChatImageCollectPayload( + images=images, + post_id=post_id, + post_id_rank=post_id_rank, + conversation_id=conversation_id, + response_id=response_id, + ) class ImageGenerationService: @@ -68,6 +458,7 @@ async def generate( stream: bool, enable_nsfw: Optional[bool] = None, chat_format: bool = False, + return_share_url: bool = False, ) -> ImageGenerationResult: max_token_retries = int(get_config("retry.max_retry") or 3) tried_tokens: set[str] = set() @@ -192,6 +583,7 @@ async def _stream_retry() -> AsyncGenerator[str, None]: n=n, response_format=response_format, enable_nsfw=enable_nsfw, + return_share_url=return_share_url, ) except UpstreamException as app_chat_error: if rate_limited(app_chat_error): @@ -319,11 +711,12 @@ async def _collect_app_chat( n: int, response_format: str, enable_nsfw: Optional[bool] = None, + return_share_url: bool = False, ) -> ImageGenerationResult: per_call = min(max(1, n), 2) calls_needed = max(1, int(math.ceil(n / per_call))) - async def _call_generate(call_target: int) -> List[str]: + async def _call_generate(call_target: int) -> AppChatImageCollectPayload: response = await GrokChatService().chat( token=token, message=prompt, @@ -335,7 +728,7 @@ async def _call_generate(call_target: int) -> List[str]: call_target, enable_nsfw ), ) - processor = AppChatImageCollectProcessor( + processor = ImageAppChatCollectProcessor( model_info.model_id, token, response_format=response_format, @@ -343,7 +736,10 @@ async def _call_generate(call_target: int) -> List[str]: return await processor.process(response) if calls_needed == 1: - all_images = await _call_generate(n) + payload = await _call_generate(n) + all_images = payload.images + share_conversation_id = payload.conversation_id + share_response_id = payload.response_id else: tasks = [] for i in range(calls_needed): @@ -351,6 +747,8 @@ async def _call_generate(call_target: int) -> List[str]: tasks.append(_call_generate(min(per_call, remaining))) results = await asyncio.gather(*tasks, return_exceptions=True) all_images: List[str] = [] + share_conversation_id = "" + share_response_id = "" last_error: Optional[Exception] = None rate_limit_error: Optional[Exception] = None for result in results: @@ -360,7 +758,14 @@ async def _call_generate(call_target: int) -> List[str]: if rate_limited(result): rate_limit_error = result continue - for image in result: + if ( + not share_conversation_id + and result.conversation_id + and result.response_id + ): + share_conversation_id = result.conversation_id + share_response_id = result.response_id + for image in result.images: if image not in all_images: all_images.append(image) @@ -382,6 +787,22 @@ async def _call_generate(call_target: int) -> List[str]: logger.warning(f"Failed to consume token: {e}") selected = self._select_images(all_images, n) + share_url = "" + share_image_url = "" + share_image_source = "" + share_image_expires_at = "" + if return_share_url and share_conversation_id and share_response_id: + share_url = await _create_image_share_link( + token, + share_conversation_id, + share_response_id, + ) + if share_url: + ( + share_image_url, + share_image_source, + share_image_expires_at, + ) = await _resolve_share_image_details(share_url) usage_override = { "total_tokens": 0, "input_tokens": 0, @@ -389,7 +810,13 @@ async def _call_generate(call_target: int) -> List[str]: "input_tokens_details": {"text_tokens": 0, "image_tokens": 0}, } return ImageGenerationResult( - stream=False, data=selected, usage_override=usage_override + stream=False, + data=selected, + usage_override=usage_override, + share_url=share_url, + share_image_url=share_image_url, + share_image_source=share_image_source, + share_image_expires_at=share_image_expires_at, ) async def _collect_ws( diff --git a/app/services/grok/utils/process.py b/app/services/grok/utils/process.py index 69353c651..11377fb7c 100644 --- a/app/services/grok/utils/process.py +++ b/app/services/grok/utils/process.py @@ -3,8 +3,12 @@ """ import asyncio +import json +import re import time +from dataclasses import dataclass from typing import Any, AsyncGenerator, Optional, AsyncIterable, List, TypeVar +from urllib.parse import urlparse from app.core.config import get_config from app.core.logger import logger @@ -14,6 +18,20 @@ T = TypeVar("T") +_GENERATED_PATH_RE = re.compile(r"/generated/([^/?#]+)/") +_GENERATED_PART_SUFFIX_RE = re.compile(r"-part-\d+$") + + +@dataclass +class ImageCandidate: + """Normalized image candidate extracted from Grok responses.""" + + url: str + key: str + priority: int + order: int + source: str + def _is_http2_error(e: Exception) -> bool: """检查是否为 HTTP/2 流错误""" @@ -39,16 +57,135 @@ def _normalize_line(line: Any) -> Optional[str]: return text -def _collect_images(obj: Any) -> List[str]: - """递归收集响应中的图片 URL""" - urls: List[str] = [] - seen = set() +def _image_candidate_path(url: str) -> str: + if not isinstance(url, str): + return "" + value = url.strip() + if not value: + return "" + if value.startswith("http://") or value.startswith("https://"): + parsed = urlparse(value) + path = parsed.path or "" + if parsed.query: + path = f"{path}?{parsed.query}" + return path or value + return value + + +def _image_candidate_key(url: str) -> str: + path = _image_candidate_path(url) + match = _GENERATED_PATH_RE.search(path) + if match: + return _GENERATED_PART_SUFFIX_RE.sub("", match.group(1)) + return path or url.strip() + + +def _is_preview_image_url(url: str) -> bool: + path = _image_candidate_path(url) + if not path: + return False + if re.search(r"/generated/[^/?#]+-part-\d+/", path): + return True + lowered = path.lower() + return any(flag in lowered for flag in ("/preview/", "/thumbnail/", "/thumb/")) + + +def _is_final_image_url(url: str) -> bool: + path = _image_candidate_path(url) + if not path: + return False + if _is_preview_image_url(url): + return False + return bool(_GENERATED_PATH_RE.search(path)) + + +def _image_candidate_priority(url: str, source: str) -> int: + priority = { + "card_original": 500, + "original": 450, + "generated_list": 300, + "generic_original": 280, + "generic_image_url": 220, + "image_chunk": 120, + }.get(source, 200) + + if _is_final_image_url(url): + priority += 40 + if _is_preview_image_url(url): + priority -= 120 + if "assets.grok.com" in url or "assets.grokusercontent.com" in url: + priority += 20 + return priority + + +def _pick_preferred_image_candidate( + existing: Optional[ImageCandidate], incoming: ImageCandidate +) -> ImageCandidate: + if existing is None: + return incoming + + if incoming.priority > existing.priority: + incoming.order = existing.order + return incoming + if incoming.priority < existing.priority: + return existing + + if not _is_preview_image_url(incoming.url) and _is_preview_image_url(existing.url): + incoming.order = existing.order + return incoming + return existing + + +def _collect_image_candidates(obj: Any) -> List[ImageCandidate]: + """Recursively collect image candidates and keep the best candidate per image.""" + best_by_key: dict[str, ImageCandidate] = {} + next_order = 0 - def add(url: str): - if not url or url in seen: + def add(url: str, source: str): + nonlocal next_order + if not isinstance(url, str): + return + normalized = url.strip() + if not normalized: + return + + candidate = ImageCandidate( + url=normalized, + key=_image_candidate_key(normalized), + priority=_image_candidate_priority(normalized, source), + order=next_order, + source=source, + ) + existing = best_by_key.get(candidate.key) + preferred = _pick_preferred_image_candidate(existing, candidate) + if existing is None: + best_by_key[candidate.key] = preferred + next_order += 1 + return + if preferred is not existing: + best_by_key[candidate.key] = preferred + + def collect_card_attachment(value: Any, *, from_json_data: bool = False): + if not isinstance(value, dict): + return + + image_chunk = value.get("image_chunk") or {} + if isinstance(image_chunk, dict): + image_url = image_chunk.get("imageUrl") + add(image_url, "image_chunk") + + image = value.get("image") or {} + if isinstance(image, dict): + original = image.get("original") + add(original, "card_original" if from_json_data else "original") + + def parse_card_json(raw: Any): + if not isinstance(raw, str) or not raw.strip(): + return + try: + collect_card_attachment(json.loads(raw), from_json_data=True) + except Exception: return - seen.add(url) - urls.append(url) def walk(value: Any): if isinstance(value, dict): @@ -56,10 +193,26 @@ def walk(value: Any): if key in {"generatedImageUrls", "imageUrls", "imageURLs"}: if isinstance(item, list): for url in item: - if isinstance(url, str): - add(url) - elif isinstance(item, str): - add(item) + add(url, "generated_list") + else: + add(item, "generated_list") + continue + if key == "cardAttachmentsJson" and isinstance(item, list): + for raw in item: + parse_card_json(raw) + continue + if key == "cardAttachment" and isinstance(item, dict): + json_data = item.get("jsonData") + if isinstance(json_data, str) and json_data.strip(): + parse_card_json(json_data) + else: + collect_card_attachment(item) + continue + if key == "imageUrl": + add(item, "generic_image_url") + continue + if key == "original": + add(item, "generic_original") continue walk(item) elif isinstance(value, list): @@ -67,7 +220,12 @@ def walk(value: Any): walk(item) walk(obj) - return urls + return sorted(best_by_key.values(), key=lambda item: item.order) + + +def _collect_images(obj: Any) -> List[str]: + """Collect best image URLs after candidate prioritization.""" + return [candidate.url for candidate in _collect_image_candidates(obj)] async def _with_idle_timeout( @@ -144,9 +302,15 @@ async def process_url(self, path: str, media_type: str = "image") -> str: __all__ = [ + "ImageCandidate", "BaseProcessor", "_with_idle_timeout", "_normalize_line", + "_collect_image_candidates", "_collect_images", + "_image_candidate_key", + "_is_final_image_url", + "_is_preview_image_url", "_is_http2_error", + "_pick_preferred_image_candidate", ] diff --git a/app/services/grok/utils/share_resolver.py b/app/services/grok/utils/share_resolver.py new file mode 100644 index 000000000..c68261f4f --- /dev/null +++ b/app/services/grok/utils/share_resolver.py @@ -0,0 +1,413 @@ +""" +Grok share-page image resolver. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from html import unescape +from html.parser import HTMLParser +import inspect +import re +from typing import Any, List, Optional +from urllib.parse import parse_qs, urljoin, urlparse + +from app.core.config import get_config +from app.core.exceptions import UpstreamException, ValidationException +from app.core.logger import logger +from app.core.proxy_pool import build_http_proxies, get_current_proxy_from +from app.services.grok.utils.process import _collect_images +from app.services.reverse.utils.session import ResettableSession + +_VALID_SHARE_HOSTS = {"grok.com", "www.grok.com"} +_ASSET_CANDIDATE_PATTERNS = ( + ("assets", re.compile(r"https://assets\.grok\.com/users/[^\s\"'<>]+", re.I)), + ( + "shared_assets", + re.compile(r"https://assets\.grokusercontent\.com/users/[^\s\"'<>]+", re.I), + ), +) +_PREVIEW_PATH_RE = re.compile(r"/(opengraph-image|twitter-image)/", re.I) + + +@dataclass +class ShareImageResolution: + share_url: str + image_url: str = "" + source: str = "" + expires_at: str = "" + + +class _ShareMetaParser(HTMLParser): + def __init__(self, base_url: str): + super().__init__(convert_charrefs=True) + self.base_url = base_url + self.meta: dict[str, str] = {} + self.preload_images: List[str] = [] + + def handle_starttag(self, tag: str, attrs): + attr_map = {str(k).lower(): str(v) for k, v in attrs if k and v is not None} + if tag.lower() == "meta": + key = (attr_map.get("property") or attr_map.get("name") or "").strip().lower() + content = (attr_map.get("content") or "").strip() + if key and content: + self.meta[key] = urljoin(self.base_url, content) + return + + if tag.lower() == "link": + rel = (attr_map.get("rel") or "").strip().lower() + as_type = (attr_map.get("as") or "").strip().lower() + href = (attr_map.get("href") or "").strip() + if rel == "preload" and as_type == "image" and href: + self.preload_images.append(urljoin(self.base_url, href)) + + +def normalize_grok_share_url(raw_url: str) -> str: + value = (raw_url or "").strip() + if not value: + raise ValidationException( + message="share_url cannot be empty", + param="share_url", + code="empty_share_url", + ) + + parsed = urlparse(value) + if parsed.scheme not in {"http", "https"} or parsed.hostname not in _VALID_SHARE_HOSTS: + raise ValidationException( + message="share_url must be a Grok share URL", + param="share_url", + code="invalid_share_url", + ) + + parts = [part for part in parsed.path.split("/") if part] + if len(parts) < 2 or parts[0] != "share": + raise ValidationException( + message="share_url must point to /share/", + param="share_url", + code="invalid_share_url", + ) + + share_id = parts[1].strip() + if not share_id: + raise ValidationException( + message="share_url is missing share id", + param="share_url", + code="invalid_share_url", + ) + + return f"https://grok.com/share/{share_id}" + + +def _to_rfc3339(value: datetime) -> str: + return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") + + +def _extract_signed_expiry(raw_url: str) -> str: + parsed = urlparse((raw_url or "").strip()) + if not parsed.scheme or not parsed.netloc: + return "" + + query = parse_qs(parsed.query) + + se = (query.get("se") or [""])[0].strip() + if se: + try: + return _to_rfc3339(datetime.fromisoformat(se.replace("Z", "+00:00"))) + except ValueError: + pass + + for key in ("Expires", "expires", "exp"): + raw = (query.get(key) or [""])[0].strip() + if not raw: + continue + try: + unix = int(raw) + except ValueError: + continue + if unix > 0: + return _to_rfc3339(datetime.fromtimestamp(unix, tz=timezone.utc)) + + x_amz_date = (query.get("X-Amz-Date") or [""])[0].strip() + x_amz_expires = (query.get("X-Amz-Expires") or [""])[0].strip() + if x_amz_date and x_amz_expires: + try: + base = datetime.strptime(x_amz_date, "%Y%m%dT%H%M%SZ").replace( + tzinfo=timezone.utc + ) + ttl = int(x_amz_expires) + if ttl > 0: + return _to_rfc3339(base + timedelta(seconds=ttl)) + except ValueError: + pass + + return "" + + +def _decode_escaped_text(html: str) -> str: + return ( + unescape(html) + .replace("\\u002F", "/") + .replace("\\u002f", "/") + .replace("\\u002E", ".") + .replace("\\u002e", ".") + .replace("\\/", "/") + ) + + +def _share_id_from_url(share_url: str) -> str: + return share_url.rstrip("/").rsplit("/", 1)[-1].strip() + + +def _normalize_asset_url(raw_url: str) -> str: + value = (raw_url or "").strip() + if not value: + return "" + if value.startswith(("http://", "https://")): + return value + return f"https://assets.grok.com/{value.lstrip('/')}" + + +def _is_preview_url(url: str) -> bool: + return bool(_PREVIEW_PATH_RE.search((url or "").strip())) + + +def _candidate_priority(source: str, image_url: str) -> int: + url = (image_url or "").strip() + source_key = (source or "").strip().lower() + if not url: + return 999 + if "assets.grok.com/users/" in url.lower(): + return 0 + if "assets.grokusercontent.com/users/" in url.lower(): + return 1 + if source_key == "public_json": + return 2 + if source_key in {"assets", "shared_assets"}: + return 3 + if source_key == "og:image": + return 4 + if source_key == "twitter:image": + return 5 + if source_key == "preload": + return 6 + if not _is_preview_url(url): + return 7 + return 8 + + +def _pick_better_resolution( + first: Optional[ShareImageResolution], + second: Optional[ShareImageResolution], +) -> ShareImageResolution: + if first and first.image_url and not second: + return first + if second and second.image_url and not first: + return second + if not first: + return second or ShareImageResolution(share_url="") + if not second: + return first + first_rank = _candidate_priority(first.source, first.image_url) + second_rank = _candidate_priority(second.source, second.image_url) + if second_rank < first_rank: + return second + return first + + +def _pick_best_candidate(share_url: str, html: str) -> ShareImageResolution: + parser = _ShareMetaParser(share_url) + parser.feed(html) + + normalized = _decode_escaped_text(html) + seen: set[str] = set() + candidates: List[tuple[int, str, str]] = [] + + for source, pattern in _ASSET_CANDIDATE_PATTERNS: + for match in pattern.finditer(normalized): + url = match.group(0).rstrip("\\") + if url and url not in seen: + seen.add(url) + candidates.append((0, source, url)) + + for source, url in ( + ("og:image", parser.meta.get("og:image", "")), + ("twitter:image", parser.meta.get("twitter:image", "")), + ): + if url and url not in seen: + seen.add(url) + candidates.append((1 if source == "og:image" else 2, source, url)) + + for url in parser.preload_images: + if url and url not in seen: + seen.add(url) + candidates.append((3, "preload", url)) + + if not candidates: + return ShareImageResolution(share_url=share_url) + + candidates.sort(key=lambda item: item[0]) + _, source, image_url = candidates[0] + return ShareImageResolution( + share_url=share_url, + image_url=image_url, + source=source, + expires_at=_extract_signed_expiry(image_url), + ) + + +async def _fetch_public_share_payload(share_url: str) -> dict[str, Any]: + timeout = get_config("image.timeout") or get_config("chat.timeout") or 30 + browser = get_config("proxy.browser") + user_agent = str(get_config("proxy.user_agent") or "").strip() or ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" + ) + share_id = _share_id_from_url(share_url) + api = f"https://grok.com/rest/app-chat/share_links/{share_id}" + headers = { + "Accept": "application/json, text/plain, */*", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Cache-Control": "no-cache", + "Pragma": "no-cache", + "Origin": "https://grok.com", + "Referer": share_url, + "User-Agent": user_agent, + } + + async with ResettableSession(impersonate=browser) as session: + _, proxy_url = get_current_proxy_from("proxy.base_proxy_url") + proxies = build_http_proxies(proxy_url) + response = await session.get( + api, + headers=headers, + proxies=proxies, + timeout=timeout, + allow_redirects=True, + impersonate=browser, + ) + if response.status_code != 200: + body = "" + try: + text_value = getattr(response, "text", "") + if callable(text_value): + text_value = text_value() + if inspect.isawaitable(text_value): + text_value = await text_value + if isinstance(text_value, str): + body = text_value + except Exception: + pass + raise UpstreamException( + message=f"Share public API fetch failed, {response.status_code}", + details={"status": response.status_code, "body": body[:1000]}, + code="share_public_api_fetch_failed", + ) + payload = response.json() + if isinstance(payload, dict): + return payload + return {} + + +async def _resolve_share_image_via_public_api(share_url: str) -> ShareImageResolution: + payload = await _fetch_public_share_payload(share_url) + for image_url in _collect_images(payload): + direct_url = _normalize_asset_url(image_url) + if not direct_url: + continue + return ShareImageResolution( + share_url=share_url, + image_url=direct_url, + source="public_json", + expires_at=_extract_signed_expiry(direct_url), + ) + return ShareImageResolution(share_url=share_url, source="public_json") + + +async def _fetch_share_html(share_url: str) -> str: + timeout = get_config("image.timeout") or get_config("chat.timeout") or 30 + browser = get_config("proxy.browser") + user_agent = str(get_config("proxy.user_agent") or "").strip() or ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" + ) + headers = { + "Accept": "text/html,application/xhtml+xml", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + "Cache-Control": "no-cache", + "Pragma": "no-cache", + "Referer": "https://grok.com/", + "User-Agent": user_agent, + } + + async with ResettableSession(impersonate=browser) as session: + _, proxy_url = get_current_proxy_from("proxy.base_proxy_url") + proxies = build_http_proxies(proxy_url) + response = await session.get( + share_url, + headers=headers, + proxies=proxies, + timeout=timeout, + allow_redirects=True, + impersonate=browser, + ) + if response.status_code != 200: + body = "" + try: + text_value = getattr(response, "text", "") + if callable(text_value): + text_value = text_value() + if inspect.isawaitable(text_value): + text_value = await text_value + if isinstance(text_value, str): + body = text_value + except Exception: + pass + raise UpstreamException( + message=f"Share page fetch failed, {response.status_code}", + details={"status": response.status_code, "body": body[:1000]}, + code="share_page_fetch_failed", + ) + text_value = getattr(response, "text", "") + if callable(text_value): + text_value = text_value() + if inspect.isawaitable(text_value): + text_value = await text_value + if isinstance(text_value, str): + return text_value + + content = getattr(response, "content", b"") + if isinstance(content, (bytes, bytearray)): + return bytes(content).decode("utf-8", "ignore") + return str(content or "") + + +async def resolve_grok_share_image(raw_share_url: str) -> ShareImageResolution: + share_url = normalize_grok_share_url(raw_share_url) + public_result = ShareImageResolution(share_url=share_url) + try: + public_result = await _resolve_share_image_via_public_api(share_url) + if _candidate_priority(public_result.source, public_result.image_url) <= 1: + return public_result + except Exception as exc: + logger.debug("Share resolver public API path failed for {}: {}", share_url, exc) + + result = public_result + try: + html = await _fetch_share_html(share_url) + result = _pick_better_resolution(result, _pick_best_candidate(share_url, html)) + except Exception: + if result.image_url: + return result + raise + + if not result.image_url: + logger.warning("No share image candidate found for {}", share_url) + return result + + +__all__ = [ + "ShareImageResolution", + "normalize_grok_share_url", + "resolve_grok_share_image", +] diff --git a/app/services/reverse/app_chat_share.py b/app/services/reverse/app_chat_share.py new file mode 100644 index 000000000..596c1559d --- /dev/null +++ b/app/services/reverse/app_chat_share.py @@ -0,0 +1,131 @@ +""" +Reverse interface: app chat share link creation. +""" + +import orjson +from typing import Any + +from curl_cffi.requests import AsyncSession + +from app.core.config import get_config +from app.core.exceptions import UpstreamException +from app.core.logger import logger +from app.core.proxy_pool import ( + build_http_proxies, + get_current_proxy_from, + rotate_proxy, + should_rotate_proxy, +) +from app.services.reverse.utils.headers import build_headers +from app.services.reverse.utils.retry import retry_on_status +from app.services.token.service import TokenService + + +class AppChatShareReverse: + """/rest/app-chat/conversations/{conversationId}/share reverse interface.""" + + @staticmethod + async def request( + session: AsyncSession, + token: str, + conversation_id: str, + response_id: str, + allow_indexing: bool = True, + ) -> Any: + api = f"https://grok.com/rest/app-chat/conversations/{conversation_id}/share" + + try: + referer = f"https://grok.com/c/{conversation_id}" + if response_id: + referer = f"{referer}?rid={response_id}" + + headers = build_headers( + cookie_token=token, + content_type="application/json", + origin="https://grok.com", + referer=referer, + ) + + payload = { + "responseId": response_id, + "allowIndexing": bool(allow_indexing), + } + + timeout = get_config("image.timeout") or get_config("chat.timeout") + browser = get_config("proxy.browser") + active_proxy_key = None + + async def _do_request(): + nonlocal active_proxy_key + active_proxy_key, proxy_url = get_current_proxy_from( + "proxy.base_proxy_url" + ) + proxies = build_http_proxies(proxy_url) + response = await session.post( + api, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout, + proxies=proxies, + impersonate=browser, + ) + + if response.status_code != 200: + content = "" + try: + content = await response.text() + except Exception: + pass + logger.error( + "AppChatShareReverse: Share create failed, %s", + response.status_code, + extra={"error_type": "UpstreamException"}, + ) + raise UpstreamException( + message=( + "AppChatShareReverse: Share create failed, " + f"{response.status_code}" + ), + details={"status": response.status_code, "body": content}, + ) + + return response + + async def _on_retry( + attempt: int, + status_code: int, + error: Exception, + delay: float, + ): + if active_proxy_key and should_rotate_proxy(status_code): + rotate_proxy(active_proxy_key) + + return await retry_on_status(_do_request, on_retry=_on_retry) + + except Exception as e: + if isinstance(e, UpstreamException): + status = None + if e.details and "status" in e.details: + status = e.details["status"] + else: + status = getattr(e, "status_code", None) + if status == 401: + try: + await TokenService.record_fail( + token, status, "app_chat_share_auth_failed" + ) + except Exception: + pass + raise + + logger.error( + f"AppChatShareReverse: Share create failed, {str(e)}", + extra={"error_type": type(e).__name__}, + ) + raise UpstreamException( + message=f"AppChatShareReverse: Share create failed, {str(e)}", + details={"status": 502, "error": str(e)}, + ) + + +__all__ = ["AppChatShareReverse"] diff --git a/tests/test_image_share.py b/tests/test_image_share.py new file mode 100644 index 000000000..4c59f7443 --- /dev/null +++ b/tests/test_image_share.py @@ -0,0 +1,410 @@ +import asyncio +import json +from types import SimpleNamespace + +import orjson + +from app.api.v1 import chat as chat_api +from app.api.v1.chat import ImageConfig, _imagine_fast_server_image_config +from app.api.v1 import image as image_api +from app.api.v1.image import ( + ImageGenerationRequest, + ShareImageResolveRequest, + append_share_payload, + validate_generation_request, +) +from app.core.exceptions import ValidationException +from app.services.grok.services.image import ( + _build_app_chat_share_url, + _collect_post_id_candidates, + _create_image_share_link, + _extract_app_chat_share_context, + _pick_best_post_id, +) +from app.services.grok.utils.share_resolver import ( + ShareImageResolution, + normalize_grok_share_url, + resolve_grok_share_image, +) + + +def test_validate_generation_request_allows_fast_model(): + request = ImageGenerationRequest( + prompt="draw a cat", + model="grok-imagine-1.0-fast", + n=1, + size="1024x1024", + stream=False, + ) + + validate_generation_request(request) + + +def test_validate_generation_request_rejects_stream_share_url(): + request = ImageGenerationRequest( + prompt="draw a cat", + model="grok-imagine-1.0-fast", + n=1, + size="1024x1024", + stream=True, + return_share_url=True, + ) + + try: + validate_generation_request(request) + except ValidationException as exc: + assert exc.code == "share_url_stream_not_supported" + else: + raise AssertionError("expected ValidationException") + + +def test_collect_post_id_candidates_prefers_response_post(): + payload = { + "post": {"id": "11111111-2222-3333-4444-555555555555"}, + "postId": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "modelResponse": { + "fileAttachments": [ + "users/demo/generated/99999999-8888-7777-6666-555555555555/image.jpg" + ] + }, + } + + post_id, rank = _pick_best_post_id(_collect_post_id_candidates(payload)) + + assert post_id == "11111111-2222-3333-4444-555555555555" + assert rank == 1 + + +def test_imagine_fast_server_image_config_preserves_return_share_url(monkeypatch): + monkeypatch.setattr( + "app.api.v1.chat.get_config", + lambda key, default=None: { + "imagine_fast.n": 2, + "imagine_fast.size": "1024x1024", + "imagine_fast.response_format": "url", + }.get(key, default), + ) + + config = _imagine_fast_server_image_config( + ImageConfig(return_share_url=True, response_format="b64_json") + ) + + assert config.n == 2 + assert config.size == "1024x1024" + assert config.response_format == "url" + assert config.return_share_url is True + + +def test_extract_app_chat_share_context_supports_nested_result(): + payload = { + "result": { + "conversation": { + "conversationId": "conv-nested", + }, + "response": { + "responseId": "resp-nested", + "modelResponse": { + "responseId": "resp-model", + "cardAttachmentsJson": [ + '{"image_chunk":{"imageUrl":"users/demo/generated/test/image.jpg"}}' + ], + }, + }, + } + } + + conversation_id, response_id = _extract_app_chat_share_context(payload) + + assert conversation_id == "conv-nested" + assert response_id == "resp-model" + + +def test_extract_app_chat_share_context_supports_flat_result_events(): + payload = { + "result": { + "conversationId": "conv-flat", + "responseId": "resp-flat", + "cardAttachment": { + "jsonData": '{"image_chunk":{"imageUrl":"users/demo/generated/test/image.jpg"}}' + }, + } + } + + conversation_id, response_id = _extract_app_chat_share_context(payload) + + assert conversation_id == "conv-flat" + assert response_id == "resp-flat" + + +def test_build_app_chat_share_url_prefers_share_link_id(): + payload = { + "shareLinkId": "c2hhcmQtMg_demo", + "publicId": "should-not-win", + } + + assert _build_app_chat_share_url(payload) == "https://grok.com/share/c2hhcmQtMg_demo" + + +def test_normalize_grok_share_url_strips_extra_path(): + raw = ( + "https://grok.com/share/c2hhcmQtMg_demo/" + "opengraph-image/c2hhcmQtMg_demo?cache=1" + ) + + assert normalize_grok_share_url(raw) == "https://grok.com/share/c2hhcmQtMg_demo" + + +def test_resolve_grok_share_image_prefers_assets_candidate(monkeypatch): + html = """ + + + + + + + + + """ + + async def fake_public(share_url): + assert share_url == "https://grok.com/share/demo" + return {} + + async def fake_fetch(share_url): + assert share_url == "https://grok.com/share/demo" + return html + + monkeypatch.setattr( + "app.services.grok.utils.share_resolver._fetch_public_share_payload", + fake_public, + ) + monkeypatch.setattr( + "app.services.grok.utils.share_resolver._fetch_share_html", + fake_fetch, + ) + + resolved = asyncio.run(resolve_grok_share_image("https://grok.com/share/demo")) + + assert resolved.share_url == "https://grok.com/share/demo" + assert ( + resolved.image_url + == "https://assets.grok.com/users/demo/generated/final/image.jpg" + ) + assert resolved.source == "assets" + + +def test_resolve_grok_share_image_prefers_public_json_assets(monkeypatch): + payload = { + "responses": [ + { + "cardAttachmentsJson": [ + json.dumps( + { + "image_chunk": { + "imageUrl": "users/demo/generated/demo-image-part-0/image.jpg" + } + } + ), + json.dumps( + { + "image_chunk": { + "imageUrl": "users/demo/generated/demo-image/image.jpg" + } + } + ), + ] + } + ] + } + + async def fake_public(share_url): + assert share_url == "https://grok.com/share/demo-public" + return payload + + async def fake_fetch(share_url): + assert share_url == "https://grok.com/share/demo-public" + return """ + + + + + + """ + + monkeypatch.setattr( + "app.services.grok.utils.share_resolver._fetch_public_share_payload", + fake_public, + ) + monkeypatch.setattr( + "app.services.grok.utils.share_resolver._fetch_share_html", + fake_fetch, + ) + + resolved = asyncio.run(resolve_grok_share_image("https://grok.com/share/demo-public")) + + assert ( + resolved.image_url + == "https://assets.grok.com/users/demo/generated/demo-image/image.jpg" + ) + assert resolved.source == "public_json" + + +def test_append_share_payload_includes_resolved_fields(): + payload = {"created": 1} + result = SimpleNamespace( + share_url="https://grok.com/share/demo", + share_image_url="https://assets.grok.com/users/demo/generated/final/image.jpg", + share_image_source="assets", + share_image_expires_at="2026-03-29T00:00:00Z", + ) + + merged = append_share_payload(payload, result) + + assert merged["share_url"] == "https://grok.com/share/demo" + assert ( + merged["share_image_url"] + == "https://assets.grok.com/users/demo/generated/final/image.jpg" + ) + assert merged["share_image_source"] == "assets" + assert merged["share_image_expires_at"] == "2026-03-29T00:00:00Z" + + +def test_create_image_share_link_uses_app_chat_share(monkeypatch): + class DummyResponse: + def json(self): + return {"shareLinkId": "c2hhcmQtMg_demo"} + + class DummySession: + async def __aenter__(self): + return object() + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def fake_request(session, token, conversation_id, response_id): + assert token == "token-demo" + assert conversation_id == "conv-demo" + assert response_id == "resp-demo" + return DummyResponse() + + monkeypatch.setattr( + "app.services.grok.services.image._new_session", + lambda: DummySession(), + ) + monkeypatch.setattr( + "app.services.grok.services.image.AppChatShareReverse.request", + fake_request, + ) + + share_url = asyncio.run( + _create_image_share_link("token-demo", "conv-demo", "resp-demo") + ) + + assert share_url == "https://grok.com/share/c2hhcmQtMg_demo" + + +def test_chat_completions_imagine_fast_preserves_return_share_url(monkeypatch): + captured = {} + + class DummyTokenMgr: + async def reload_if_stale(self): + return None + + def get_token(self, pool_name): + return "token-demo" + + async def fake_get_token_manager(): + return DummyTokenMgr() + + async def fake_generate(self, **kwargs): + captured.update(kwargs) + return SimpleNamespace( + stream=False, + data=["http://localhost/image.jpg"], + usage_override={"total_tokens": 0}, + share_url="https://grok.com/share/demo-fast", + share_image_url="https://assets.grok.com/users/demo/generated/final/image.jpg", + share_image_source="assets", + share_image_expires_at="2026-03-29T00:00:00Z", + ) + + monkeypatch.setattr(chat_api, "get_config", lambda key, default=None: { + "imagine_fast.n": 2, + "imagine_fast.size": "1024x1024", + "imagine_fast.response_format": "url", + }.get(key, default)) + monkeypatch.setattr(chat_api, "get_token_manager", fake_get_token_manager) + monkeypatch.setattr(chat_api.ModelService, "valid", lambda model: True) + monkeypatch.setattr( + chat_api.ModelService, + "get", + lambda model: SimpleNamespace( + is_image=True, + is_image_edit=False, + is_video=False, + model_id=model, + grok_model="grok-4-image", + model_mode="auto", + ), + ) + monkeypatch.setattr( + chat_api.ModelService, + "pool_candidates_for_model", + lambda model: ["demo-pool"], + ) + monkeypatch.setattr( + "app.api.v1.chat.ImageGenerationService.generate", + fake_generate, + ) + + request = chat_api.ChatCompletionRequest( + model="grok-imagine-1.0-fast", + stream=False, + messages=[chat_api.MessageItem(role="user", content="draw a dog")], + image_config=chat_api.ImageConfig(return_share_url=True), + ) + + response = asyncio.run(chat_api.chat_completions(request)) + payload = orjson.loads(response.body) + + assert captured["return_share_url"] is True + assert payload["share_url"] == "https://grok.com/share/demo-fast" + assert payload["choices"][0]["message"]["content"] == "![image](http://localhost/image.jpg)" + assert ( + payload["share_image_url"] + == "https://assets.grok.com/users/demo/generated/final/image.jpg" + ) + assert payload["share_image_source"] == "assets" + assert payload["share_image_expires_at"] == "2026-03-29T00:00:00Z" + + +def test_resolve_share_image_endpoint_returns_current_direct_link(monkeypatch): + async def fake_resolve(share_url): + assert share_url == "https://grok.com/share/demo" + return ShareImageResolution( + share_url=share_url, + image_url="https://assets.grok.com/users/demo/generated/final/image.jpg", + source="assets", + expires_at="2026-03-29T00:00:00Z", + ) + + monkeypatch.setattr( + "app.api.v1.image.resolve_grok_share_image", + fake_resolve, + ) + + response = asyncio.run( + image_api.resolve_share_image( + ShareImageResolveRequest(share_url="https://grok.com/share/demo") + ) + ) + payload = orjson.loads(response.body) + + assert payload["resolved"] is True + assert payload["share_url"] == "https://grok.com/share/demo" + assert ( + payload["share_image_url"] + == "https://assets.grok.com/users/demo/generated/final/image.jpg" + ) diff --git a/tests/test_openai_usage.py b/tests/test_openai_usage.py index 27c122367..571574fb8 100644 --- a/tests/test_openai_usage.py +++ b/tests/test_openai_usage.py @@ -109,6 +109,232 @@ async def _run(): asyncio.run(_run()) +def test_stream_processor_prefers_final_image_over_preview(monkeypatch): + monkeypatch.setattr( + "app.services.grok.services.chat.get_config", + lambda key, default=None: 0 if key == "chat.stream_timeout" else [], + ) + + class DummyDownloadService: + async def render_image(self, url, token, image_id="image"): + return f"![{image_id}]({url})" + + preview = ( + "users/demo/generated/11111111-2222-3333-4444-555555555555-part-0/image.jpg" + ) + original = ( + "https://assets.grok.com/users/demo/generated/" + "11111111-2222-3333-4444-555555555555/image.jpg" + ) + + monkeypatch.setattr( + StreamProcessor, + "_get_dl", + lambda self: DummyDownloadService(), + ) + + async def _run(): + processor = StreamProcessor("grok-4.20-beta", prompt_tokens=9, show_think=True) + chunks = [] + async for chunk in processor.process( + _iter_lines( + [ + _json_line( + { + "result": { + "response": { + "responseId": "resp_stream_image", + "streamingImageGenerationResponse": { + "imageIndex": 0, + "progress": 42, + }, + } + } + } + ), + _json_line( + { + "result": { + "response": { + "responseId": "resp_stream_image", + "modelResponse": { + "cardAttachmentsJson": [ + orjson.dumps( + { + "id": "card-demo", + "image_chunk": { + "imageUrl": preview + }, + } + ).decode() + ] + }, + } + } + } + ), + _json_line( + { + "result": { + "response": { + "responseId": "resp_stream_image", + "cardAttachment": { + "jsonData": orjson.dumps( + { + "id": "card-demo", + "image": {"original": original}, + } + ).decode() + }, + } + } + } + ), + ] + ) + ): + chunks.append(chunk) + + combined = "".join(chunks) + assert preview not in combined + assert combined.count(original) == 1 + + asyncio.run(_run()) + + +def test_stream_processor_falls_back_to_preview_when_final_missing(monkeypatch): + monkeypatch.setattr( + "app.services.grok.services.chat.get_config", + lambda key, default=None: 0 if key == "chat.stream_timeout" else [], + ) + + class DummyDownloadService: + async def render_image(self, url, token, image_id="image"): + return f"![{image_id}]({url})" + + preview = ( + "users/demo/generated/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee-part-0/image.jpg" + ) + + monkeypatch.setattr( + StreamProcessor, + "_get_dl", + lambda self: DummyDownloadService(), + ) + + async def _run(): + processor = StreamProcessor("grok-4.20-beta", prompt_tokens=7, show_think=True) + chunks = [] + async for chunk in processor.process( + _iter_lines( + [ + _json_line( + { + "result": { + "response": { + "responseId": "resp_stream_preview", + "modelResponse": { + "cardAttachmentsJson": [ + orjson.dumps( + { + "id": "card-demo", + "image_chunk": { + "imageUrl": preview + }, + } + ).decode() + ] + }, + } + } + } + ) + ] + ) + ): + chunks.append(chunk) + + combined = "".join(chunks) + assert combined.count(preview) == 1 + + asyncio.run(_run()) + + +def test_collect_processor_image_result_includes_share_fields(monkeypatch): + monkeypatch.setattr( + "app.services.grok.services.chat.get_config", + lambda key, default=None: 0 if key == "chat.stream_timeout" else [], + ) + + class DummyDownloadService: + async def render_image(self, url, token, image_id="image"): + return f"![{image_id}]({url})" + + original = ( + "https://assets.grok.com/users/demo/generated/" + "11111111-2222-3333-4444-555555555555/image.jpg" + ) + + monkeypatch.setattr( + CollectProcessor, + "_get_dl", + lambda self: DummyDownloadService(), + ) + monkeypatch.setattr( + "app.services.grok.services.chat._create_chat_share_link", + lambda token, conversation_id, response_id: asyncio.sleep(0, result="https://grok.com/share/demo-share"), + ) + monkeypatch.setattr( + "app.services.grok.services.chat._resolve_share_image_details", + lambda share_url: asyncio.sleep( + 0, + result=( + "https://grok.com/share/demo-share/opengraph-image/demo-share?cache=1", + "og:image", + "", + ), + ), + ) + + async def _run(): + processor = CollectProcessor("grok-4.20-beta", token="token-demo", prompt_tokens=9) + result = await processor.process( + _iter_lines( + [ + _json_line( + { + "result": { + "conversationId": "conv-share", + "response": { + "modelResponse": { + "responseId": "resp-share", + "message": "", + "cardAttachmentsJson": [ + orjson.dumps( + { + "id": "card-demo", + "image": {"original": original}, + } + ).decode() + ], + } + }, + } + } + ) + ] + ) + ) + assert result["share_url"] == "https://grok.com/share/demo-share" + assert ( + result["share_image_url"] + == "https://grok.com/share/demo-share/opengraph-image/demo-share?cache=1" + ) + assert result["share_image_source"] == "og:image" + + asyncio.run(_run()) + + def test_responses_stream_completed_event_uses_chat_usage(monkeypatch): async def fake_chat_completions(**kwargs): async def _gen(): diff --git a/tests/test_process_images.py b/tests/test_process_images.py new file mode 100644 index 000000000..21009fdd7 --- /dev/null +++ b/tests/test_process_images.py @@ -0,0 +1,198 @@ +import asyncio + +import orjson + +from app.services.grok.services.image import ImageAppChatCollectProcessor +from app.services.grok.utils.process import _collect_images + + +def _json_line(payload: dict) -> bytes: + return orjson.dumps(payload) + + +async def _iter_lines(lines): + for line in lines: + yield line + + +def test_collect_images_reads_card_attachments_json(): + payload = { + "cardAttachmentsJson": [ + '{"id":"abc","image_chunk":{"imageUrl":"users/demo/generated/test/image.jpg"}}' + ] + } + + images = _collect_images(payload) + + assert images == ["users/demo/generated/test/image.jpg"] + + +def test_collect_images_reads_card_attachment_json_data(): + payload = { + "cardAttachment": { + "jsonData": '{"image":{"original":"https://assets.grok.com/users/demo/generated/test/image.jpg"}}' + } + } + + images = _collect_images(payload) + + assert images == ["https://assets.grok.com/users/demo/generated/test/image.jpg"] + + +def test_collect_images_prefers_original_over_preview_for_same_image(): + image_id = "11111111-2222-3333-4444-555555555555" + payload = { + "cardAttachmentsJson": [ + orjson.dumps( + { + "id": "abc", + "image_chunk": { + "imageUrl": f"users/demo/generated/{image_id}-part-0/image.jpg" + }, + "image": { + "original": f"https://assets.grok.com/users/demo/generated/{image_id}/image.jpg" + }, + } + ).decode() + ] + } + + images = _collect_images(payload) + + assert images == [ + f"https://assets.grok.com/users/demo/generated/{image_id}/image.jpg" + ] + + +def test_image_app_chat_collect_processor_prefers_final_across_events(monkeypatch): + image_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + preview = f"users/demo/generated/{image_id}-part-0/image.jpg" + original = f"https://assets.grok.com/users/demo/generated/{image_id}/image.jpg" + + monkeypatch.setattr( + "app.services.grok.services.image.get_config", + lambda key, default=None: 0 if key == "image.stream_timeout" else default, + ) + + async def fake_process_image_url(self, url): + return f"rendered:{url}" + + monkeypatch.setattr( + ImageAppChatCollectProcessor, + "_process_image_url", + fake_process_image_url, + ) + + async def _run(): + processor = ImageAppChatCollectProcessor( + "grok-imagine-1.0-fast", + token="token-demo", + response_format="url", + ) + payload = await processor.process( + _iter_lines( + [ + _json_line( + { + "result": { + "response": { + "modelResponse": { + "cardAttachmentsJson": [ + orjson.dumps( + { + "id": "card-demo", + "image_chunk": { + "imageUrl": preview + }, + } + ).decode() + ] + } + } + } + } + ), + _json_line( + { + "result": { + "response": { + "modelResponse": { + "cardAttachmentsJson": [ + orjson.dumps( + { + "id": "card-demo", + "image": { + "original": original + }, + } + ).decode() + ] + } + } + } + } + ), + ] + ) + ) + + assert payload.images == [f"rendered:{original}"] + + asyncio.run(_run()) + + +def test_image_app_chat_collect_processor_falls_back_to_preview(monkeypatch): + preview = ( + "users/demo/generated/ffffffff-1111-2222-3333-444444444444-part-0/image.jpg" + ) + + monkeypatch.setattr( + "app.services.grok.services.image.get_config", + lambda key, default=None: 0 if key == "image.stream_timeout" else default, + ) + + async def fake_process_image_url(self, url): + return f"rendered:{url}" + + monkeypatch.setattr( + ImageAppChatCollectProcessor, + "_process_image_url", + fake_process_image_url, + ) + + async def _run(): + processor = ImageAppChatCollectProcessor( + "grok-imagine-1.0-fast", + token="token-demo", + response_format="url", + ) + payload = await processor.process( + _iter_lines( + [ + _json_line( + { + "result": { + "response": { + "modelResponse": { + "cardAttachmentsJson": [ + orjson.dumps( + { + "id": "card-demo", + "image_chunk": { + "imageUrl": preview + }, + } + ).decode() + ] + } + } + } + } + ) + ] + ) + ) + + assert payload.images == [f"rendered:{preview}"] + + asyncio.run(_run())