Skip to content

feat (providers/xai): add xai image model support #5295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/famous-flowers-flow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/xai': patch
---

feat (providers/xai): add xai image model support
1 change: 1 addition & 0 deletions content/docs/03-ai-sdk-core/35-image-generation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ for (const file of result.files) {

| Provider | Model | Support sizes (`width x height`) or aspect ratios (`width : height`) |
| ------------------------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [xAI Grok](/providers/ai-sdk-providers/xai#image-models) | `grok-2-image` | 1024x768 (default) |
| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 |
| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 |
| [Amazon Bedrock](/providers/ai-sdk-providers/amazon-bedrock#image-models) | `amazon.nova-canvas-v1:0` | 320-4096 (multiples of 16), 1:4 to 4:1, max 4.2M pixels |
Expand Down
42 changes: 42 additions & 0 deletions content/providers/01-ai-sdk-providers/01-xai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,45 @@ The following optional settings are available for xAI chat models:
table above lists popular models. You can also pass any available provider
model ID as a string if needed.
</Note>

## Image Models

You can create xAI image models using the `.imageModel()` factory method. For more on image generation with the AI SDK see [generateImage()](/docs/reference/ai-sdk-core/generate-image).

```ts
import { xai } from '@ai-sdk/xai';
import { experimental_generateImage as generateImage } from 'ai';

const { image } = await generateImage({
model: xai.image('grok-2-image'),
prompt: 'A futuristic cityscape at sunset',
});
```

<Note>
The xAI image model does not currently support the `aspectRatio` or `size`
parameters. Image size defaults to 1024x768.
</Note>

### Model-specific options

You can customize the image generation behavior with model-specific settings:

```ts
import { xai } from '@ai-sdk/xai';
import { experimental_generateImage as generateImage } from 'ai';

const { image } = await generateImage({
model: xai.image('grok-2-image', {
maxImagesPerCall: 5, // Default is 10
}),
prompt: 'A futuristic cityscape at sunset',
n: 2, // Generate 2 images
});
```

### Model Capabilities

| Model | Sizes | Notes |
| -------------- | ------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `grok-2-image` | 1024x768 (default) | xAI's text-to-image generation model, designed to create high-quality images from text prompts. It's trained on a diverse dataset and can generate images across various styles, subjects, and settings. |
16 changes: 16 additions & 0 deletions examples/ai-core/src/generate-image/xai-many.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { xai } from '@ai-sdk/xai';
import { experimental_generateImage as generateImage } from 'ai';
import { presentImages } from '../lib/present-image';
import 'dotenv/config';

async function main() {
const { images } = await generateImage({
model: xai.image('grok-2-image'),
n: 3,
prompt: 'A chicken flying into the sunset in the style of anime.',
});

await presentImages(images);
}

main().catch(console.error);
15 changes: 15 additions & 0 deletions examples/ai-core/src/generate-image/xai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { xai } from '@ai-sdk/xai';
import { experimental_generateImage as generateImage } from 'ai';
import { presentImages } from '../lib/present-image';
import 'dotenv/config';

async function main() {
const { image } = await generateImage({
model: xai.image('grok-2-image'),
prompt: 'A salamander at dusk in a forest pond surrounded by fireflies.',
});

await presentImages([image]);
}

main().catch(console.error);
7 changes: 2 additions & 5 deletions packages/xai/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
export { createXai, xai } from './xai-provider';
export type {
XaiErrorData,
XaiProvider,
XaiProviderSettings,
} from './xai-provider';
export type { XaiProvider, XaiProviderSettings } from './xai-provider';
export type { XaiErrorData } from './xai-error';
9 changes: 9 additions & 0 deletions packages/xai/src/xai-error.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { z } from 'zod';

// Add error schema and structure
export const xaiErrorSchema = z.object({
code: z.string(),
error: z.string(),
});

export type XaiErrorData = z.infer<typeof xaiErrorSchema>;
276 changes: 276 additions & 0 deletions packages/xai/src/xai-image-model.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
import { FetchFunction } from '@ai-sdk/provider-utils';
import { createTestServer } from '@ai-sdk/provider-utils/test';
import { describe, expect, it } from 'vitest';
import { XaiImageModel } from './xai-image-model';

const prompt = 'A photorealistic astronaut riding a horse';

function createBasicModel({
headers,
fetch,
currentDate,
settings,
}: {
headers?: () => Record<string, string | undefined>;
fetch?: FetchFunction;
currentDate?: () => Date;
settings?: any;
} = {}) {
return new XaiImageModel('grok-2-image', settings ?? {}, {
provider: 'xai',
headers: headers ?? (() => ({ Authorization: 'Bearer test-key' })),
url: ({ modelId, path }) => `https://api.example.com/${modelId}${path}`,
fetch,
_internal: {
currentDate,
},
});
}

describe('XaiImageModel', () => {
const server = createTestServer({
'https://api.example.com/grok-2-image/images/generations': {
response: {
type: 'json-value',
body: {
data: [
{
b64_json: '',
},
{
b64_json: '',
},
],
},
},
},
});

describe('constructor', () => {
it('should expose correct provider and model information', () => {
const model = createBasicModel();

expect(model.provider).toBe('xai');
expect(model.modelId).toBe('grok-2-image');
expect(model.specificationVersion).toBe('v1');
expect(model.maxImagesPerCall).toBe(10);
});

it('should use maxImagesPerCall from settings', () => {
const model = createBasicModel({
settings: {
maxImagesPerCall: 5,
},
});

expect(model.maxImagesPerCall).toBe(5);
});

it('should default maxImagesPerCall to 10 when not specified', () => {
const model = createBasicModel();

expect(model.maxImagesPerCall).toBe(10);
});
});

describe('doGenerate', () => {
it('should pass the correct parameters', async () => {
const model = createBasicModel();

await model.doGenerate({
prompt,
n: 2,
size: '1024x1024',
aspectRatio: undefined,
seed: undefined,
providerOptions: { openai: { quality: 'hd' } },
headers: {},
abortSignal: undefined,
});

expect(await server.calls[0].requestBody).toStrictEqual({
model: 'grok-2-image',
prompt,
n: 2,
size: '1024x1024',
quality: 'hd',
response_format: 'b64_json',
});
});

it('should add warnings for unsupported settings', async () => {
const model = createBasicModel();

const result = await model.doGenerate({
prompt,
n: 1,
size: '1024x1024',
aspectRatio: '16:9',
seed: 123,
providerOptions: {},
headers: {},
abortSignal: undefined,
});

expect(result.warnings).toHaveLength(2);
expect(result.warnings).toContainEqual({
type: 'unsupported-setting',
setting: 'aspectRatio',
details:
'This model does not support aspect ratio. Use `size` instead.',
});
expect(result.warnings).toContainEqual({
type: 'unsupported-setting',
setting: 'seed',
});
});

it('should pass headers', async () => {
const modelWithHeaders = createBasicModel({
headers: () => ({
'Custom-Provider-Header': 'provider-header-value',
}),
});

await modelWithHeaders.doGenerate({
prompt,
n: 1,
providerOptions: {},
headers: {
'Custom-Request-Header': 'request-header-value',
},
size: '1024x1024',
seed: undefined,
aspectRatio: undefined,
abortSignal: undefined,
});

expect(server.calls[0].requestHeaders).toStrictEqual({
'content-type': 'application/json',
'custom-provider-header': 'provider-header-value',
'custom-request-header': 'request-header-value',
});
});

it('should handle API errors', async () => {
server.urls[
'https://api.example.com/grok-2-image/images/generations'
].response = {
type: 'error',
status: 400,
body: JSON.stringify({
code: 'invalid_request_error',
error: 'Invalid prompt content',
}),
};

const model = createBasicModel();
await expect(
model.doGenerate({
prompt,
n: 1,
providerOptions: {},
size: '1024x1024',
seed: undefined,
aspectRatio: undefined,
headers: {},
abortSignal: undefined,
}),
).rejects.toMatchObject({
message: 'Invalid prompt content',
statusCode: 400,
url: 'https://api.example.com/grok-2-image/images/generations',
});
});

it('should strip data URI scheme prefix from b64 content', async () => {
const model = createBasicModel();
const result = await model.doGenerate({
prompt,
n: 2,
size: '1024x1024',
providerOptions: {},
headers: {},
abortSignal: undefined,
aspectRatio: undefined,
seed: undefined,
});

expect(result.images).toHaveLength(2);
expect(result.images[0]).toBe('test1234');
expect(result.images[1]).toBe('test5678');
});

describe('response metadata', () => {
it('should include timestamp, headers and modelId in response', async () => {
const testDate = new Date('2024-01-01T00:00:00Z');
const model = createBasicModel({
currentDate: () => testDate,
});

const result = await model.doGenerate({
prompt,
n: 1,
providerOptions: {},
size: '1024x1024',
seed: undefined,
aspectRatio: undefined,
headers: {},
abortSignal: undefined,
});

expect(result.response).toStrictEqual({
timestamp: testDate,
modelId: 'grok-2-image',
headers: expect.any(Object),
});
});
});

it('should respect maxImagesPerCall setting', async () => {
const customModel = createBasicModel({
settings: { maxImagesPerCall: 5 },
});
expect(customModel.maxImagesPerCall).toBe(5);

const defaultModel = createBasicModel();
expect(defaultModel.maxImagesPerCall).toBe(10); // default for XAI models
});

it('should use real date when no custom date provider is specified', async () => {
const beforeDate = new Date();

const model = new XaiImageModel(
'grok-2-image',
{},
{
provider: 'xai',
headers: () => ({ Authorization: 'Bearer test-key' }),
url: ({ modelId, path }) =>
`https://api.example.com/${modelId}${path}`,
},
);

const result = await model.doGenerate({
prompt,
n: 1,
size: '1024x1024',
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
headers: {},
abortSignal: undefined,
});

const afterDate = new Date();

expect(result.response.timestamp.getTime()).toBeGreaterThanOrEqual(
beforeDate.getTime(),
);
expect(result.response.timestamp.getTime()).toBeLessThanOrEqual(
afterDate.getTime(),
);
expect(result.response.modelId).toBe('grok-2-image');
});
});
});
Loading
Loading