Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/feat-prefer-in-cloud.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@firebase/ai": minor
"firebase": minor
---

Added a new `InferenceMode` option for the hybrid on-device capability: `prefer_in_cloud`. When this mode is selected, the SDK will attempt to use a cloud-hosted model first. If the call to the cloud-hosted model fails with a network-related error, the SDK will fall back to the on-device model, if it's available.
1 change: 1 addition & 0 deletions common/api-review/ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ export const InferenceMode: {
readonly PREFER_ON_DEVICE: "prefer_on_device";
readonly ONLY_ON_DEVICE: "only_on_device";
readonly ONLY_IN_CLOUD: "only_in_cloud";
readonly PREFER_IN_CLOUD: "prefer_in_cloud";
};

// @public
Expand Down
3 changes: 3 additions & 0 deletions docs-devsite/ai.md
Original file line number Diff line number Diff line change
Expand Up @@ -624,13 +624,16 @@ ImagenSafetyFilterLevel: {

<b>(EXPERIMENTAL)</b> Determines whether inference happens on-device or in-cloud.

<b>PREFER\_ON\_DEVICE:</b> Attempt to make inference calls using an on-device model. If on-device inference is not available, the SDK will fall back to using a cloud-hosted model. <br/> <b>ONLY\_ON\_DEVICE:</b> Only attempt to make inference calls using an on-device model. The SDK will not fall back to a cloud-hosted model. If on-device inference is not available, inference methods will throw. <br/> <b>ONLY\_IN\_CLOUD:</b> Only attempt to make inference calls using a cloud-hosted model. The SDK will not fall back to an on-device model. <br/> <b>PREFER\_IN\_CLOUD:</b> Attempt to make inference calls to a cloud-hosted model first. If the call to the cloud-hosted model fails with a network-related error, the SDK will fall back to the on-device model, if it's available.

<b>Signature:</b>

```typescript
InferenceMode: {
readonly PREFER_ON_DEVICE: "prefer_on_device";
readonly ONLY_ON_DEVICE: "only_on_device";
readonly ONLY_IN_CLOUD: "only_in_cloud";
readonly PREFER_IN_CLOUD: "prefer_in_cloud";
}
```

Expand Down
26 changes: 9 additions & 17 deletions packages/ai/src/methods/count-tokens.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,24 +196,16 @@ describe('countTokens()', () => {
);
});
});
it('on-device', async () => {
const chromeAdapter = fakeChromeAdapter;
const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
const mockResponse = getMockResponse(
'vertexAI',
'unary-success-total-tokens.json'
);
const countTokensStub = stub(chromeAdapter, 'countTokens').resolves(
mockResponse as Response
it('throws if mode is ONLY_ON_DEVICE', async () => {
const chromeAdapter = new ChromeAdapterImpl(
// @ts-expect-error
undefined,
InferenceMode.ONLY_ON_DEVICE
);
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams,
chromeAdapter
await expect(
countTokens(fakeApiSettings, 'model', fakeRequestParams, chromeAdapter)
).to.be.rejectedWith(
/countTokens\(\) is not supported for on-device models/
);
expect(result.totalTokens).eq(6);
expect(isAvailableStub).to.be.called;
expect(countTokensStub).to.be.calledWith(fakeRequestParams);
});
});
16 changes: 12 additions & 4 deletions packages/ai/src/methods/count-tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@
* limitations under the License.
*/

import { AIError } from '../errors';
import {
CountTokensRequest,
CountTokensResponse,
RequestOptions
InferenceMode,
RequestOptions,
AIErrorCode
} from '../types';
import { Task, makeRequest } from '../requests/request';
import { ApiSettings } from '../types/internal';
import * as GoogleAIMapper from '../googleai-mappers';
import { BackendType } from '../public-types';
import { ChromeAdapter } from '../types/chrome-adapter';
import { ChromeAdapterImpl } from './chrome-adapter';

export async function countTokensOnCloud(
apiSettings: ApiSettings,
Expand Down Expand Up @@ -57,9 +61,13 @@ export async function countTokens(
chromeAdapter?: ChromeAdapter,
requestOptions?: RequestOptions
): Promise<CountTokensResponse> {
if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
return (await chromeAdapter.countTokens(params)).json();
if (
(chromeAdapter as ChromeAdapterImpl)?.mode === InferenceMode.ONLY_ON_DEVICE
) {
throw new AIError(
AIErrorCode.UNSUPPORTED,
'countTokens() is not supported for on-device models.'
);
}

return countTokensOnCloud(apiSettings, model, params, requestOptions);
}
36 changes: 14 additions & 22 deletions packages/ai/src/methods/generate-content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { ApiSettings } from '../types/internal';
import * as GoogleAIMapper from '../googleai-mappers';
import { BackendType } from '../public-types';
import { ChromeAdapter } from '../types/chrome-adapter';
import { callCloudOrDevice } from '../requests/hybrid-helpers';

async function generateContentStreamOnCloud(
apiSettings: ApiSettings,
Expand Down Expand Up @@ -56,17 +57,13 @@ export async function generateContentStream(
chromeAdapter?: ChromeAdapter,
requestOptions?: RequestOptions
): Promise<GenerateContentStreamResult> {
let response;
if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
response = await chromeAdapter.generateContentStream(params);
} else {
response = await generateContentStreamOnCloud(
apiSettings,
model,
params,
requestOptions
);
}
const response = await callCloudOrDevice(
params,
chromeAdapter,
() => chromeAdapter!.generateContentStream(params),
() =>
generateContentStreamOnCloud(apiSettings, model, params, requestOptions)
);
return processStream(response, apiSettings); // TODO: Map streaming responses
}

Expand Down Expand Up @@ -96,17 +93,12 @@ export async function generateContent(
chromeAdapter?: ChromeAdapter,
requestOptions?: RequestOptions
): Promise<GenerateContentResult> {
let response;
if (chromeAdapter && (await chromeAdapter.isAvailable(params))) {
response = await chromeAdapter.generateContent(params);
} else {
response = await generateContentOnCloud(
apiSettings,
model,
params,
requestOptions
);
}
const response = await callCloudOrDevice(
params,
chromeAdapter,
() => chromeAdapter!.generateContent(params),
() => generateContentOnCloud(apiSettings, model, params, requestOptions)
);
const generateContentResponse = await processGenerateContentResponse(
response,
apiSettings
Expand Down
Loading
Loading