Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 0.2.0 #51

Merged
merged 13 commits into from
Feb 29, 2024
55 changes: 44 additions & 11 deletions examples/gemini/celest/functions/gemini.dart
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Cloud functions are top-level Dart functions defined in the `functions/`
// folder of your Celest project.

import 'dart:convert';

import 'package:celest/celest.dart';
import 'package:celest_backend/models.dart';
import 'package:google_generative_ai/google_generative_ai.dart';

import '../resources.dart';
Expand All @@ -24,8 +25,7 @@ const _availableModels = [
Future<String> generateContent({
required String modelName,
required String prompt,
ModelParameters parameters = const ModelParameters(),
@env.geminiApiKey required String apiKey,
@Env.geminiApiKey required String apiKey,
}) async {
if (!_availableModels.contains(modelName)) {
throw BadRequestException('Invalid model: $modelName');
Expand All @@ -35,19 +35,52 @@ Future<String> generateContent({
final request = [
Content.text(prompt),
];
final response = await model.generateContent(
request,
generationConfig: GenerationConfig(
maxOutputTokens: parameters.maxTokens,
temperature: parameters.temperature,
),
);
print('Sending prompt: $prompt');

final response = await model.generateContent(request);
print('Got response: ${_prettyJson(response.toJson())}');

switch (response.text) {
case final text?:
print('Gemini response: $text');
print('Selected answer: $text');
return text;
case _:
throw InternalServerException('Failed to generate content');
}
}

extension on GenerateContentResponse {
Map<String, Object?> toJson() => {
'promptFeedback': {
'blockReason': promptFeedback?.blockReason?.name,
'blockReasonMessage': promptFeedback?.blockReasonMessage,
'safetyRatings': [
for (final rating
in promptFeedback?.safetyRatings ?? <SafetyRating>[])
{
'category': rating.category.name,
'probability': rating.probability.name,
},
],
},
'candidates': [
for (final candidate in candidates)
{
'content': candidate.content.toJson(),
'safetyRatings': [
for (final rating
in candidate.safetyRatings ?? <SafetyRating>[])
{
'category': rating.category.name,
'probability': rating.probability.name,
},
],
'finishReason': candidate.finishReason?.name,
'finishMessage': candidate.finishMessage,
},
],
};
}

String _prettyJson(Object? json) =>
const JsonEncoder.withIndent(' ').convert(json);
53 changes: 42 additions & 11 deletions examples/gemini/celest/lib/client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,60 @@
// it can be checked into version control.
// ignore_for_file: type=lint, unused_local_variable, unnecessary_cast, unnecessary_import

library;
library; // ignore_for_file: no_leading_underscores_for_library_prefixes

import 'dart:io';
import 'dart:io' as _$io;

import 'package:celest/celest.dart';
import 'package:celest_core/src/util/globals.dart';
import 'package:http/http.dart' as http;
import 'package:http/http.dart' as _$http;

import 'src/client/functions.dart';
import 'src/client/serializers.dart';

final Celest celest = Celest();

enum CelestEnvironment {
local;

Uri get baseUri => switch (this) {
local => kIsWeb || !_$io.Platform.isAndroid
? Uri.parse('http://localhost:7781')
: Uri.parse('http://10.0.2.2:7781'),
};
}

class Celest {
late http.Client httpClient = http.Client();
var _initialized = false;

late CelestEnvironment _currentEnvironment;

late _$http.Client httpClient = _$http.Client();

late Uri _baseUri;

final _functions = CelestFunctions();

T _checkInitialized<T>(T Function() value) {
if (!_initialized) {
throw StateError(
'Celest has not been initialized. Make sure to call `celest.init()` at the start of your `main` method.');
}
return value();
}

CelestEnvironment get currentEnvironment =>
_checkInitialized(() => _currentEnvironment);

late final Uri baseUri = kIsWeb || !Platform.isAndroid
? Uri.parse('http://localhost:7777')
: Uri.parse('http://10.0.2.2:7777');
Uri get baseUri => _checkInitialized(() => _baseUri);

final functions = CelestFunctions();
CelestFunctions get functions => _checkInitialized(() => _functions);

void init() {
Serializers.instance.put(const ModelParametersSerializer());
void init({CelestEnvironment environment = CelestEnvironment.local}) {
_currentEnvironment = environment;
_baseUri = environment.baseUri;
if (!_initialized) {
initSerializers();
}
_initialized = true;
}
}
2 changes: 0 additions & 2 deletions examples/gemini/celest/lib/exceptions.dart

This file was deleted.

13 changes: 0 additions & 13 deletions examples/gemini/celest/lib/models.dart

This file was deleted.

95 changes: 55 additions & 40 deletions examples/gemini/celest/lib/src/client/functions.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// it can be checked into version control.
// ignore_for_file: type=lint, unused_local_variable, unnecessary_cast, unnecessary_import

library;
library; // ignore_for_file: no_leading_underscores_for_library_prefixes

import 'dart:convert';
import 'dart:convert' as _$convert;

import 'package:celest/celest.dart';
import 'package:celest_backend/models.dart';
import 'package:celest_core/src/exception/cloud_exception.dart';
import 'package:celest_core/src/exception/serialization_exception.dart';
import 'package:google_generative_ai/src/error.dart' as _$error;
import 'package:http/src/exception.dart' as _$exception;

import '../../client.dart';

Expand All @@ -17,18 +19,10 @@ class CelestFunctions {
}

class CelestFunctionsGemini {
/// Returns a list of available models.
Future<List<String>> availableModels() async {
final $response = await celest.httpClient.post(
celest.baseUri.resolve('/gemini/available-models'),
headers: const {'Content-Type': 'application/json; charset=utf-8'},
);
final $body = (jsonDecode($response.body) as Map<String, Object?>);
if ($response.statusCode == 200) {
return ($body['response'] as Iterable<Object?>)
.map((el) => (el as String))
.toList();
}
Never _throwError({
required int $statusCode,
required Map<String, Object?> $body,
}) {
final $error = ($body['error'] as Map<String, Object?>);
final $code = ($error['code'] as String);
final $details = ($error['details'] as Map<String, Object?>?);
Expand All @@ -38,8 +32,25 @@ class CelestFunctionsGemini {
case r'InternalServerException':
throw Serializers.instance
.deserialize<InternalServerException>($details);
case r'SerializationException':
throw Serializers.instance
.deserialize<SerializationException>($details);
case r'GenerativeAIException':
throw Serializers.instance
.deserialize<_$error.GenerativeAIException>($details);
case r'InvalidApiKey':
throw Serializers.instance.deserialize<_$error.InvalidApiKey>($details);
case r'UnsupportedUserLocation':
throw Serializers.instance
.deserialize<_$error.UnsupportedUserLocation>($details);
case r'ServerException':
throw Serializers.instance
.deserialize<_$error.ServerException>($details);
case r'ClientException':
throw Serializers.instance
.deserialize<_$exception.ClientException>($details);
case _:
switch ($response.statusCode) {
switch ($statusCode) {
case 400:
throw BadRequestException($code);
case _:
Expand All @@ -48,44 +59,48 @@ class CelestFunctionsGemini {
}
}

/// Returns a list of available models.
Future<List<String>> availableModels() async {
final $response = await celest.httpClient.post(
celest.baseUri.resolve('/gemini/available-models'),
headers: const {'Content-Type': 'application/json; charset=utf-8'},
);
final $body =
(_$convert.jsonDecode($response.body) as Map<String, Object?>);
if ($response.statusCode != 200) {
_throwError(
$statusCode: $response.statusCode,
$body: $body,
);
}
return ($body['response'] as Iterable<Object?>)
.map((el) => (el as String))
.toList();
}

/// Prompts the Gemini [modelName] with the given [prompt] and [parameters].
///
/// Returns the generated text.
Future<String> generateContent({
required String modelName,
required String prompt,
ModelParameters parameters = const ModelParameters(),
}) async {
final $response = await celest.httpClient.post(
celest.baseUri.resolve('/gemini/generate-content'),
headers: const {'Content-Type': 'application/json; charset=utf-8'},
body: jsonEncode({
body: _$convert.jsonEncode({
r'modelName': modelName,
r'prompt': prompt,
r'parameters':
Serializers.instance.serialize<ModelParameters>(parameters),
}),
);
final $body = (jsonDecode($response.body) as Map<String, Object?>);
if ($response.statusCode == 200) {
return ($body['response'] as String);
}
final $error = ($body['error'] as Map<String, Object?>);
final $code = ($error['code'] as String);
final $details = ($error['details'] as Map<String, Object?>?);
switch ($code) {
case r'BadRequestException':
throw Serializers.instance.deserialize<BadRequestException>($details);
case r'InternalServerException':
throw Serializers.instance
.deserialize<InternalServerException>($details);
case _:
switch ($response.statusCode) {
case 400:
throw BadRequestException($code);
case _:
throw InternalServerException($code);
}
final $body =
(_$convert.jsonDecode($response.body) as Map<String, Object?>);
if ($response.statusCode != 200) {
_throwError(
$statusCode: $response.statusCode,
$body: $body,
);
}
return ($body['response'] as String);
}
}
Loading
Loading