Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions _public/static/function/js/chat.js
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@
signal: abortController.signal
});
if (!res.ok) throw new Error(t('chat.requestFailedStatus', { status: res.status }));
await handleStream(res, assistantEntry, sendSessionId);
await handleAssistantResponse(res, payload, assistantEntry, sendSessionId);
setStatus('connected', t('common.done'));
} catch (e) {
if (e && e.name === 'AbortError') {
Expand Down Expand Up @@ -1293,6 +1293,45 @@
return payload;
}

function extractPromptText(content) {
  // Normalize a message's content into a single prompt string.
  // A plain string is already the prompt.
  if (typeof content === 'string') return content;
  // Anything that is neither a string nor a content-block array yields nothing.
  if (!Array.isArray(content)) return '';
  // Collect the text of every well-formed text block, one per line.
  const pieces = [];
  for (const block of content) {
    if (block && block.type === 'text' && block.text) {
      pieces.push(String(block.text));
    }
  }
  return pieces.join('\n').trim();
}

function isImageModelSelected(model) {
  // True when the selected model id starts with "grok-imagine" (case-insensitive).
  const name = String(model || '').trim().toLowerCase();
  return name.startsWith('grok-imagine');
}

function isLikelyImagePrompt(prompt) {
  // Heuristic detection of "please make an image" requests in English and Chinese.
  const text = String(prompt || '').trim();
  if (!text) return false;
  const imageHints = [
    /generate\s+(?:a|an)?\s*(?:picture|image|photo|illustration|drawing)/i,
    /\b(?:draw|illustrate|render|create|make)\b[\s\S]{0,32}\b(?:image|picture|photo|art|illustration)\b/i,
    /生成[\s\S]{0,16}(?:图片|图像|照片|插画|绘图|一张图)/,
    /画[\s\S]{0,16}(?:图|图片|插画|照片)/,
    /来[\s\S]{0,8}(?:张|幅)(?:图|图片|插画)/,
  ];
  for (const hint of imageHints) {
    if (hint.test(text)) return true;
  }
  return false;
}

function shouldUseImageResultMode(history) {
  // An explicitly selected image model always forces image-result mode.
  if (isImageModelSelected(modelValue)) return true;
  // Otherwise the decision rests solely on the most recent user message.
  const entries = Array.isArray(history) ? history : [];
  for (let idx = entries.length - 1; idx >= 0; idx -= 1) {
    const entry = entries[idx];
    if (entry && entry.role === 'user') {
      return isLikelyImagePrompt(extractPromptText(entry.content));
    }
  }
  // No user message found: default to normal (streaming) mode.
  return false;
}

function buildPayload() {
const payload = {
model: modelValue || 'grok-3',
Expand All @@ -1301,6 +1340,13 @@
temperature: Number(tempRange ? tempRange.value : 0.8),
top_p: Number(topPRange ? topPRange.value : 0.95)
};
if (shouldUseImageResultMode(messageHistory)) {
payload.stream = false;
payload.image_config = {
response_format: 'url',
return_share_url: true
};
}
return payload;
}

Expand All @@ -1312,9 +1358,54 @@
temperature: Number(tempRange ? tempRange.value : 0.8),
top_p: Number(topPRange ? topPRange.value : 0.95)
};
if (shouldUseImageResultMode(history)) {
payload.stream = false;
payload.image_config = {
response_format: 'url',
return_share_url: true
};
}
return payload;
}

function appendShareLinks(content, payload) {
  // Append markdown share/direct links to the assistant text,
  // skipping any URL that already appears in the content.
  const body = String(content || '').trimEnd();
  const links = [];
  const share = payload && payload.share_url ? String(payload.share_url).trim() : '';
  const direct = payload && payload.share_image_url ? String(payload.share_image_url).trim() : '';

  if (share && !body.includes(share)) {
    links.push(`Share link: [${share}](${share})`);
  }
  if (direct && !body.includes(direct)) {
    links.push(`Direct link: [${direct}](${direct})`);
  }

  if (!links.length) return body;
  if (!body) return links.join('\n');
  return `${body}\n\n${links.join('\n')}`;
}

async function handleJsonResponse(res, assistantEntry, targetSessionId) {
  // Non-stream path: the entire completion arrives as one JSON document.
  const data = await res.json();
  if (data && data.error) {
    throw new Error(data.error.message || t('common.failed'));
  }
  // Pull the assistant text out of the first choice, tolerating missing parts.
  let text = '';
  if (data && data.choices && data.choices[0] && data.choices[0].message) {
    text = data.choices[0].message.content || '';
  }
  const finalText = appendShareLinks(text, data);
  updateMessage(assistantEntry, finalText, true);
  assistantEntry.committed = true;
  commitToSession(targetSessionId, finalText);
}

async function handleAssistantResponse(res, requestPayload, assistantEntry, targetSessionId) {
  // stream:false is only set for image-result requests; everything else streams.
  const wantsJson = !!requestPayload && requestPayload.stream === false;
  return wantsJson
    ? handleJsonResponse(res, assistantEntry, targetSessionId)
    : handleStream(res, assistantEntry, targetSessionId);
}

function selectModel(value) {
modelValue = value;
if (modelLabel) modelLabel.textContent = value;
Expand Down Expand Up @@ -1514,7 +1605,7 @@
throw new Error(t('chat.requestFailedStatus', { status: res.status }));
}

await handleStream(res, assistantEntry, retrySessionId);
await handleAssistantResponse(res, payload, assistantEntry, retrySessionId);
setStatus('connected', t('common.done'));
} catch (e) {
updateMessage(assistantEntry, t('chat.requestFailedStatus', { status: e.message || e }), true);
Expand Down Expand Up @@ -1597,7 +1688,7 @@
throw new Error(t('chat.requestFailedStatus', { status: res.status }));
}

await handleStream(res, assistantEntry, sendSessionId);
await handleAssistantResponse(res, payload, assistantEntry, sendSessionId);
setStatus('connected', t('common.done'));
} catch (e) {
if (e && e.name === 'AbortError') {
Expand Down
2 changes: 1 addition & 1 deletion _public/static/function/pages/chat.html
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ <h2 class="chat-title-text" data-i18n="chat.title">Chat 聊天</h2>
<script src="/static/common/js/function-header.js?v=1.6.2"></script>
<script src="/static/common/js/footer.js?v=1.6.2"></script>
<script src="/static/common/js/toast.js"></script>
<script src="/static/function/js/chat.js?v=1.6.2"></script>
<script src="/static/function/js/chat.js?v=1.6.3-share"></script>
</body>

</html>
43 changes: 35 additions & 8 deletions app/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from app.services.grok.services.image_edit import ImageEditService
from app.services.grok.services.model import ModelService
from app.services.grok.services.video import VideoService
from app.services.grok.utils.response import make_chat_response
from app.api.v1.image import append_share_payload
from app.services.grok.utils.response import make_chat_response, wrap_image_content
from app.services.token import get_token_manager
from app.core.config import get_config
from app.core.exceptions import ValidationException, AppException, ErrorType
Expand Down Expand Up @@ -49,6 +50,9 @@ class ImageConfig(BaseModel):
n: Optional[int] = Field(1, ge=1, le=10, description="生成数量 (1-10)")
size: Optional[str] = Field("1024x1024", description="图片尺寸")
response_format: Optional[str] = Field(None, description="响应格式")
return_share_url: Optional[bool] = Field(
False, description="是否返回 Grok 分享页链接"
)


class ChatCompletionRequest(BaseModel):
Expand Down Expand Up @@ -186,15 +190,22 @@ def _image_field(response_format: str) -> str:
return "b64_json"


def _imagine_fast_server_image_config() -> ImageConfig:
def _imagine_fast_server_image_config(
client_config: Optional[ImageConfig] = None,
) -> ImageConfig:
"""Load server-side image generation parameters for grok-imagine-1.0-fast."""
n = int(get_config("imagine_fast.n", 1) or 1)
size = str(get_config("imagine_fast.size", "1024x1024") or "1024x1024")
response_format = str(
get_config("imagine_fast.response_format", get_config("app.image_format") or "url")
or "url"
)
return ImageConfig(n=n, size=size, response_format=response_format)
return ImageConfig(
n=n,
size=size,
response_format=response_format,
return_share_url=bool(client_config and client_config.return_share_url),
)


async def _safe_sse_stream(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]:
Expand Down Expand Up @@ -274,6 +285,12 @@ def _validate_image_config(image_conf: ImageConfig, *, stream: bool):
param="image_config.response_format",
code="invalid_response_format",
)
if stream and image_conf.return_share_url:
raise ValidationException(
message="return_share_url is only supported when stream=false",
param="image_config.return_share_url",
code="share_url_stream_not_supported",
)
if image_conf.size and image_conf.size not in ALLOWED_IMAGE_SIZES:
raise ValidationException(
message=f"size must be one of {sorted(ALLOWED_IMAGE_SIZES)}",
Expand Down Expand Up @@ -601,7 +618,11 @@ def validate_request(request: ChatCompletionRequest):
param="messages",
code="empty_prompt",
)
image_conf = _imagine_fast_server_image_config() if request.model == IMAGINE_FAST_MODEL_ID else (request.image_config or ImageConfig())
image_conf = (
_imagine_fast_server_image_config(request.image_config)
if request.model == IMAGINE_FAST_MODEL_ID
else (request.image_config or ImageConfig())
)
n = image_conf.n or 1
if not (1 <= n <= 10):
raise ValidationException(
Expand Down Expand Up @@ -765,6 +786,7 @@ async def chat_completions(request: ChatCompletionRequest):
)

content = result.data[0] if result.data else ""
content = wrap_image_content(content, response_format)
return JSONResponse(
content=make_chat_response(request.model, content)
)
Expand All @@ -775,7 +797,11 @@ async def chat_completions(request: ChatCompletionRequest):
is_stream = (
request.stream if request.stream is not None else get_config("app.stream")
)
image_conf = _imagine_fast_server_image_config() if request.model == IMAGINE_FAST_MODEL_ID else (request.image_config or ImageConfig())
image_conf = (
_imagine_fast_server_image_config(request.image_config)
if request.model == IMAGINE_FAST_MODEL_ID
else (request.image_config or ImageConfig())
)
_validate_image_config(image_conf, stream=bool(is_stream))
response_format = _resolve_image_format(image_conf.response_format)
response_field = _image_field(response_format)
Expand Down Expand Up @@ -818,6 +844,7 @@ async def chat_completions(request: ChatCompletionRequest):
aspect_ratio=aspect_ratio,
stream=bool(is_stream),
chat_format=True,
return_share_url=bool(image_conf.return_share_url),
)

if result.stream:
Expand All @@ -828,10 +855,10 @@ async def chat_completions(request: ChatCompletionRequest):
)

content = result.data[0] if result.data else ""
content = wrap_image_content(content, response_format)
usage = result.usage_override
return JSONResponse(
content=make_chat_response(request.model, content, usage=usage)
)
payload = make_chat_response(request.model, content, usage=usage)
return JSONResponse(content=append_share_payload(payload, result))

if model_info and model_info.is_video:
# 提取视频配置 (默认值在 Pydantic 模型中处理)
Expand Down
73 changes: 58 additions & 15 deletions app/api/v1/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from app.services.grok.services.image import ImageGenerationService
from app.services.grok.services.image_edit import ImageEditService
from app.services.grok.services.model import ModelService
from app.services.grok.utils.share_resolver import resolve_grok_share_image
from app.services.token import get_token_manager
from app.core.exceptions import ValidationException, AppException, ErrorType
from app.core.config import get_config
Expand Down Expand Up @@ -53,6 +54,9 @@ class ImageGenerationRequest(BaseModel):
response_format: Optional[str] = Field(None, description="响应格式")
style: Optional[str] = Field(None, description="风格 (暂不支持)")
stream: Optional[bool] = Field(False, description="是否流式输出")
return_share_url: Optional[bool] = Field(
False, description="是否返回 Grok 分享页链接"
)


class ImageEditRequest(BaseModel):
Expand All @@ -72,6 +76,30 @@ class ImageEditRequest(BaseModel):
stream: Optional[bool] = Field(False, description="是否流式输出")


class ShareImageResolveRequest(BaseModel):
    """Request body for POST /images/share/resolve."""

    # Public Grok share-page URL to resolve into a direct image URL.
    share_url: str = Field(..., description="Grok 分享页链接")


def append_share_payload(payload: dict, result) -> dict:
    """Copy share-related attributes from *result* onto *payload*.

    The original body repeated the same getattr/if pattern four times;
    iterating a field tuple keeps the behavior identical while removing
    the duplication.

    Args:
        payload: Response dict to enrich (mutated in place).
        result: Any object; each share field is read with ``getattr`` and
            treated as absent when missing.

    Returns:
        The same ``payload`` dict, for convenient chaining.
    """
    share_fields = (
        "share_url",
        "share_image_url",
        "share_image_source",
        "share_image_expires_at",
    )
    for field in share_fields:
        # Only truthy values are copied, so callers never see empty fields.
        value = getattr(result, field, "") or ""
        if value:
            payload[field] = value
    return payload


def _validate_common_request(
request: Union[ImageGenerationRequest, ImageEditRequest],
*,
Expand Down Expand Up @@ -108,6 +136,13 @@ def _validate_common_request(
code="invalid_response_format",
)

if request.stream and getattr(request, "return_share_url", False):
raise ValidationException(
message="return_share_url is only supported when stream=false",
param="return_share_url",
code="share_url_stream_not_supported",
)

if request.response_format:
allowed_formats = {"b64_json", "base64", "url"}
if request.response_format not in allowed_formats:
Expand All @@ -127,16 +162,8 @@ def _validate_common_request(

def validate_generation_request(request: ImageGenerationRequest):
"""验证图片生成请求参数"""
if request.model != "grok-imagine-1.0":
raise ValidationException(
message="The model `grok-imagine-1.0` is required for image generation.",
param="model",
code="model_not_supported",
)
# 验证模型 - 通过 is_image 检查
model_info = ModelService.get(request.model)
if not model_info or not model_info.is_image:
# 获取支持的图片模型列表
image_models = [m.model_id for m in ModelService.MODELS if m.is_image]
raise ValidationException(
message=(
Expand Down Expand Up @@ -289,6 +316,7 @@ async def create_image(request: ImageGenerationRequest):
size=request.size,
aspect_ratio=aspect_ratio,
stream=bool(request.stream),
return_share_url=bool(request.return_share_url),
)

if result.stream:
Expand All @@ -306,13 +334,28 @@ async def create_image(request: ImageGenerationRequest):
"input_tokens_details": {"text_tokens": 0, "image_tokens": 0},
}

return JSONResponse(
content={
"created": int(time.time()),
"data": data,
"usage": usage,
}
)
payload = {
"created": int(time.time()),
"data": data,
"usage": usage,
}
return JSONResponse(content=append_share_payload(payload, result))


@router.post("/images/share/resolve")
async def resolve_share_image(request: ShareImageResolveRequest):
    """Resolve a Grok share-page URL into a direct image URL."""
    result = await resolve_grok_share_image(request.share_url)
    # "resolved" reflects whether an image URL could be extracted.
    payload = {
        "share_url": result.share_url,
        "resolved": bool(result.image_url),
    }
    # Optional fields are included only when the resolver produced them.
    optional_fields = (
        ("share_image_url", result.image_url),
        ("share_image_source", result.source),
        ("share_image_expires_at", result.expires_at),
    )
    for key, value in optional_fields:
        if value:
            payload[key] = value
    return JSONResponse(content=payload)


@router.post("/images/edits")
Expand Down
Loading