Skip to content
37 changes: 37 additions & 0 deletions packages/ai/__tests__/chat-session-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,43 @@ describe('chat-session-helpers', () => {
],
isValid: false,
},
{
history: [
{ role: 'user', parts: [{ text: 'hi' }] },
{
role: 'model',
parts: [
{ text: 'hi' },
{
text: 'thought about hi',
thought: true,
thoughtSignature: 'thought signature',
},
],
},
],
isValid: true,
},
{
history: [
{
role: 'user',
parts: [{ text: 'hi', thought: true, thoughtSignature: 'sig' }],
},
{
role: 'model',
parts: [
{ text: 'hi' },
{
text: 'thought about hi',
thought: true,
thoughtSignature: 'thought signature',
},
],
},
],
isValid: false,
},
];

TCS.forEach(tc => {
Expand Down
118 changes: 116 additions & 2 deletions packages/ai/__tests__/response-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,21 @@
* limitations under the License.
*/
import { describe, expect, it, jest, afterEach } from '@jest/globals';
import { addHelpers, formatBlockErrorMessage } from '../lib/requests/response-helpers';
import {
addHelpers,
formatBlockErrorMessage,
handlePredictResponse,
} from '../lib/requests/response-helpers';

import { BlockReason, Content, FinishReason, GenerateContentResponse } from '../lib/types';
import {
BlockReason,
Content,
FinishReason,
GenerateContentResponse,
ImagenInlineImage,
ImagenGCSImage,
} from '../lib/types';
import { getMockResponse, BackendName } from './test-utils/mock-response';

const fakeResponseText: GenerateContentResponse = {
candidates: [
Expand All @@ -31,6 +43,18 @@ const fakeResponseText: GenerateContentResponse = {
],
};

const fakeResponseThoughts: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'Some text' }, { text: 'and some thoughts', thought: true }],
},
},
],
};

const functionCallPart1 = {
functionCall: {
name: 'find_theaters',
Expand Down Expand Up @@ -129,12 +153,14 @@ describe('response-helpers methods', () => {
const enhancedResponse = addHelpers(fakeResponseText);
expect(enhancedResponse.text()).toBe('Some text and some more text');
expect(enhancedResponse.functionCalls()).toBeUndefined();
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
});

it('good response functionCall', () => {
const enhancedResponse = addHelpers(fakeResponseFunctionCall);
expect(enhancedResponse.text()).toBe('');
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]);
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
});

it('good response functionCalls', () => {
Expand All @@ -144,29 +170,41 @@ describe('response-helpers methods', () => {
functionCallPart1.functionCall,
functionCallPart2.functionCall,
]);
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
});

it('good response text/functionCall', () => {
const enhancedResponse = addHelpers(fakeResponseMixed1);
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart2.functionCall]);
expect(enhancedResponse.text()).toBe('some text');
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
});

it('good response functionCall/text', () => {
const enhancedResponse = addHelpers(fakeResponseMixed2);
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]);
expect(enhancedResponse.text()).toBe('some text');
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
});

it('good response text/functionCall/text', () => {
const enhancedResponse = addHelpers(fakeResponseMixed3);
expect(enhancedResponse.functionCalls()).toEqual([functionCallPart1.functionCall]);
expect(enhancedResponse.text()).toBe('some text and more text');
expect(enhancedResponse.thoughtSummary()).toBeUndefined();
});

it('good response text/thought', () => {
const enhancedResponse = addHelpers(fakeResponseThoughts);
expect(enhancedResponse.text()).toBe('Some text');
expect(enhancedResponse.thoughtSummary()).toBe('and some thoughts');
expect(enhancedResponse.functionCalls()).toBeUndefined();
});

it('bad response safety', () => {
const enhancedResponse = addHelpers(badFakeResponse);
expect(() => enhancedResponse.text()).toThrow('SAFETY');
expect(() => enhancedResponse.thoughtSummary()).toThrow('SAFETY');
});
});

Expand Down Expand Up @@ -233,4 +271,80 @@ describe('response-helpers methods', () => {
expect(message).toContain('Candidate was blocked due to SAFETY: unsafe candidate');
});
});

describe('handlePredictResponse', () => {
it('returns base64 images', async () => {
const mockResponse = getMockResponse(
BackendName.VertexAI,
'unary-success-generate-images-base64.json',
) as Response;
const res = await handlePredictResponse<ImagenInlineImage>(mockResponse);
expect(res.filteredReason).toBeUndefined();
expect(res.images.length).toBe(4);
res.images.forEach(image => {
expect(image.mimeType).toBe('image/png');
expect(image.bytesBase64Encoded.length).toBeGreaterThan(0);
});
});

it('returns GCS images', async () => {
const mockResponse = getMockResponse(
BackendName.VertexAI,
'unary-success-generate-images-gcs.json',
) as Response;
const res = await handlePredictResponse<ImagenGCSImage>(mockResponse);
expect(res.filteredReason).toBeUndefined();
expect(res.images.length).toBe(4);
res.images.forEach((image, i) => {
expect(image.mimeType).toBe('image/jpeg');
expect(image.gcsURI).toBe(
`gs://test-project-id-1234.firebasestorage.app/images/1234567890123/sample_${i}.jpg`,
);
});
});

it('has filtered reason and no images if all images were filtered', async () => {
const mockResponse = getMockResponse(
BackendName.VertexAI,
'unary-failure-generate-images-all-filtered.json',
) as Response;
const res = await handlePredictResponse<ImagenInlineImage>(mockResponse);
expect(res.filteredReason).toBe(
"Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback. Support codes: 39322892, 29310472",
);
expect(res.images.length).toBe(0);
});

it('has filtered reason and no images if all base64 images were filtered', async () => {
const mockResponse = getMockResponse(
BackendName.VertexAI,
'unary-failure-generate-images-base64-some-filtered.json',
) as Response;
const res = await handlePredictResponse<ImagenInlineImage>(mockResponse);
expect(res.filteredReason).toBe(
'Your current safety filter threshold filtered out 2 generated images. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback.',
);
expect(res.images.length).toBe(2);
res.images.forEach(image => {
expect(image.mimeType).toBe('image/png');
expect(image.bytesBase64Encoded.length).toBeGreaterThan(0);
});
});

it('has filtered reason and no images if all GCS images were filtered', async () => {
const mockResponse = getMockResponse(
BackendName.VertexAI,
'unary-failure-generate-images-gcs-some-filtered.json',
) as Response;
const res = await handlePredictResponse<ImagenGCSImage>(mockResponse);
expect(res.filteredReason).toBe(
'Your current safety filter threshold filtered out 2 generated images. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback.',
);
expect(res.images.length).toBe(2);
res.images.forEach(image => {
expect(image.mimeType).toBe('image/jpeg');
expect(image.gcsURI.length).toBeGreaterThan(0);
});
});
});
});
5 changes: 1 addition & 4 deletions packages/ai/lib/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@ export const AI_TYPE = 'AI';

export const DEFAULT_LOCATION = 'us-central1';

export const DEFAULT_BASE_URL = 'https://firebasevertexai.googleapis.com';
export const DEFAULT_DOMAIN = 'firebasevertexai.googleapis.com';

// This is the default API version for the VertexAI API. At some point, should be able to change when the feature becomes available.
// `v1beta` & `stable` available: https://cloud.google.com/vertex-ai/docs/reference#versions
export const DEFAULT_API_VERSION = 'v1beta';

export const PACKAGE_VERSION = version;

export const LANGUAGE_TAG = 'gl-rn';

// Timeout is 180s by default
export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000;
29 changes: 18 additions & 11 deletions packages/ai/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import './polyfills';
import { getApp, ReactNativeFirebase } from '@react-native-firebase/app';
import { GoogleAIBackend, VertexAIBackend } from './backend';
import { Backend, GoogleAIBackend, VertexAIBackend } from './backend';
import { AIErrorCode, ModelParams, RequestOptions } from './types';
import { AI, AIOptions, ImagenModelParams } from './public-types';
import { AIError } from './errors';
Expand All @@ -27,8 +27,9 @@ import { AIModel, ImagenModel } from './models';
export * from './public-types';
export { ChatSession } from './methods/chat-session';
export * from './requests/schema-builder';
export { GoogleAIBackend, VertexAIBackend } from './backend';
export { GenerativeModel, AIError, AIModel };
export { ImagenImageFormat } from './requests/imagen-image-format';
export { Backend, GoogleAIBackend, VertexAIBackend } from './backend';
export { GenerativeModel, AIError, AIModel, ImagenModel };

/**
* Returns the default {@link AI} instance that is associated with the provided
Expand Down Expand Up @@ -58,16 +59,22 @@ export { GenerativeModel, AIError, AIModel };
*
* @public
*/
export function getAI(
app: ReactNativeFirebase.FirebaseApp = getApp(),
options: AIOptions = { backend: new GoogleAIBackend() },
): AI {
export function getAI(app: ReactNativeFirebase.FirebaseApp = getApp(), options?: AIOptions): AI {
const backend: Backend = options?.backend ?? new GoogleAIBackend();

const finalOptions: Omit<AIOptions, 'backend'> = {
useLimitedUseAppCheckTokens: options?.useLimitedUseAppCheckTokens ?? false,
appCheck: options?.appCheck || null,
auth: options?.auth || null,
};

return {
app,
backend: options.backend,
location: (options.backend as VertexAIBackend)?.location || '',
appCheck: options.appCheck || null,
auth: options.auth || null,
backend,
options: finalOptions,
location: (backend as VertexAIBackend)?.location || '',
appCheck: options?.appCheck || null,
auth: options?.auth || null,
} as AI;
}

Expand Down
6 changes: 5 additions & 1 deletion packages/ai/lib/methods/chat-session-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ const VALID_PART_FIELDS: Array<keyof Part> = [
'inlineData',
'functionCall',
'functionResponse',
'thought',
'thoughtSignature',
];

const VALID_PARTS_PER_ROLE: { [key in Role]: Array<keyof Part> } = {
user: ['text', 'inlineData'],
function: ['functionResponse'],
model: ['text', 'functionCall'],
model: ['text', 'functionCall', 'thought', 'thoughtSignature'],
// System instructions shouldn't be in history anyway.
system: ['text'],
};
Expand Down Expand Up @@ -78,6 +80,8 @@ export function validateChatHistory(history: Content[]): void {
inlineData: 0,
functionCall: 0,
functionResponse: 0,
thought: 0,
thoughtSignature: 0,
};

for (const part of parts) {
Expand Down
Loading
Loading