chore: move to byorg framework #60

Merged
merged 10 commits on Nov 5, 2024
Changes from 1 commit
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
chore: move to byorg framework
Q1w1N committed Nov 4, 2024
commit abc68aced353a84a5e5b5004e5a692cdfddee05e
4 changes: 4 additions & 0 deletions package.json
@@ -41,7 +41,11 @@
"registry": "https://registry.npmjs.org/"
},
"dependencies": {
"@ai-sdk/anthropic": "^0.0.54",
"@ai-sdk/mistral": "^0.0.46",
"@ai-sdk/openai": "^0.0.71",
"@anthropic-ai/sdk": "^0.26.1",
"@callstack/byorg-core": "0.1.2",
"@inkjs/ui": "^1.0.0",
"@mistralai/mistralai": "^1.0.2",
"chalk": "^5.3.0",
12 changes: 8 additions & 4 deletions src/commands/chat/providers.tsx
@@ -2,10 +2,14 @@ import openAi from '../../engine/providers/open-ai.js';
import anthropic from '../../engine/providers/anthropic.js';
import perplexity from '../../engine/providers/perplexity.js';
import mistral from '../../engine/providers/mistral.js';
import { getProvider, type Provider, type ProviderName } from '../../engine/providers/provider.js';
import {
getProvider,
type ProviderInfo,
type ProviderName,
} from '../../engine/providers/provider-info.js';
import type { ConfigFile } from '../../config-file.js';

export const providerOptionMapping: Record<string, Provider> = {
export const providerOptionMapping: Record<string, ProviderInfo> = {
openai: openAi,
anthropic,
anth: anthropic,
@@ -16,7 +20,7 @@ export const providerOptionMapping: Record<string, Provider> = {

export const providerOptions = Object.keys(providerOptionMapping);

export function resolveProviderFromOption(providerOption: string): Provider {
export function resolveProviderInfoFromOption(providerOption: string): ProviderInfo {
const provider = providerOptionMapping[providerOption];
if (!provider) {
throw new Error(`Provider not found: ${providerOption}.`);
@@ -25,7 +29,7 @@ export function resolveProviderFromOption(providerOption: string): Provider {
return provider;
}

export function getDefaultProvider(config: ConfigFile): Provider {
export function getDefaultProviderInfo(config: ConfigFile): ProviderInfo {
const providerNames = Object.keys(config.providers) as ProviderName[];
const providerName = providerNames ? providerNames[0] : undefined;
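Taken together, the renamed helpers keep the old lookup behavior; a minimal sketch of how they would be used, assuming a ConfigFile whose providers map contains at least one entry (mirroring initChatState below):

```ts
import { getDefaultProviderInfo, resolveProviderInfoFromOption } from './providers.js';
import type { ConfigFile } from '../../config-file.js';

// Hypothetical usage: an explicit --provider flag wins, otherwise the first
// provider configured in .airc.json is used.
function pickProviderInfo(config: ConfigFile, providerOption?: string) {
  return providerOption
    ? resolveProviderInfoFromOption(providerOption) // e.g. 'anth' resolves to Anthropic
    : getDefaultProviderInfo(config);
}
```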

35 changes: 16 additions & 19 deletions src/commands/chat/state/actions.ts
@@ -1,39 +1,36 @@
import type { ModelResponse } from '../../../engine/inference.js';
import { type AssistantMessage, type AssistantResponse } from '@callstack/byorg-core';
import { useChatState, type MessageLevel } from './state.js';

export function addUserMessage(text: string) {
export function addUserMessage(content: string) {
useChatState.setState((state) => {
return {
activeView: null,
contextMessages: [...state.contextMessages, { role: 'user', content: text }],
chatMessages: [...state.chatMessages, { type: 'user', text }],
contextMessages: [...state.contextMessages, { role: 'user', content }],
chatMessages: [...state.chatMessages, { role: 'user', content }],
};
});
}

export function addAiResponse(response: ModelResponse) {
export function addAssistantResponse(response: AssistantResponse) {
useChatState.setState((state) => {
const outputMessages = {
type: 'ai',
text: response.message.content,
responseTime: response.responseTime,
usage: response.usage,
data: response.data,
} as const;
const message: AssistantMessage = {
role: 'assistant',
content: response.content,
};

return {
activeView: null,
contextMessages: [...state.contextMessages, response.message],
chatMessages: [...state.chatMessages, outputMessages],
contextMessages: [...state.contextMessages, message],
chatMessages: [...state.chatMessages, response],
};
});
}

export function addProgramMessage(text: string, level: MessageLevel = 'info') {
export function addProgramMessage(content: string, level: MessageLevel = 'info') {
useChatState.setState((state) => {
return {
activeView: null,
chatMessages: [...state.chatMessages, { type: 'program', level, text }],
chatMessages: [...state.chatMessages, { role: 'program', level, content }],
};
});
}
@@ -57,7 +54,7 @@ export function forgetContextMessages() {
contextMessages: [],
chatMessages: [
...state.chatMessages,
{ type: 'program', level: 'info', text: 'AI will forget previous messages.' },
{ role: 'program', level: 'info', content: 'AI will forget previous messages.' },
],
};
});
@@ -92,7 +89,7 @@ export function triggerExit() {
return {
activeView: null,
shouldExit: true,
chatMessages: [...state.chatMessages, { type: 'program', level: 'info', text: 'Bye! 👋' }],
chatMessages: [...state.chatMessages, { role: 'program', level: 'info', content: 'Bye! 👋' }],
};
});
}
@@ -103,7 +100,7 @@ export function setVerbose(verbose: boolean) {
verbose,
chatMessages: [
...state.chatMessages,
{ type: 'program', level: 'info', text: `Verbose mode: ${verbose ? 'on' : 'off'}` },
{ role: 'program', level: 'info', content: `Verbose mode: ${verbose ? 'on' : 'off'}` },
],
};
});
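The net effect is that every stored message is now discriminated by role rather than type; a small sketch of the three shapes, assuming string content as used throughout this diff:

```ts
import type { AssistantResponse, UserMessage } from '@callstack/byorg-core';
import type { ProgramOutput } from './state.js';

// Old `{ type: 'user', text }` becomes a byorg-core UserMessage:
const user: UserMessage = { role: 'user', content: 'Hello' };

// Assistant turns keep the full AssistantResponse (content plus usage) in the
// chat log, while only a plain assistant message is pushed into the context:
declare const response: AssistantResponse;
const contextEntry = { role: 'assistant' as const, content: response.content };

// Program output stays CLI-local and keeps its severity level:
const note: ProgramOutput = { role: 'program', level: 'info', content: 'Saved.' };
```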
54 changes: 31 additions & 23 deletions src/commands/chat/state/init.ts
@@ -1,9 +1,9 @@
import { createApp, loggingPlugin, type Message } from '@callstack/byorg-core';
import { type ConfigFile } from '../../../config-file.js';
import { DEFAULT_SYSTEM_PROMPT } from '../../../default-config.js';
import type { ResponseStyle } from '../../../engine/providers/config.js';
import type { Message } from '../../../engine/inference.js';
import type { PromptOptions } from '../prompt-options.js';
import { getDefaultProvider, resolveProviderFromOption } from '../providers.js';
import { getDefaultProviderInfo, resolveProviderInfoFromOption } from '../providers.js';
import { filterOutApiKey, handleInputFile } from '../utils.js';
import { useChatState, type ChatMessage, type ChatState } from './state.js';

@@ -12,19 +12,19 @@ export function initChatState(
configFile: ConfigFile,
initialPrompt: string,
) {
const provider = options.provider
? resolveProviderFromOption(options.provider)
: getDefaultProvider(configFile);
const providerInfo = options.provider
? resolveProviderInfoFromOption(options.provider)
: getDefaultProviderInfo(configFile);

const providerFileConfig = configFile.providers[provider.name];
const providerFileConfig = configFile.providers[providerInfo.name];
if (!providerFileConfig) {
throw new Error(`Provider config not found: ${provider.name}.`);
throw new Error(`Provider config not found: ${providerInfo.name}.`);
}

const modelOrAlias = options.model ?? providerFileConfig.model;
const model = modelOrAlias
? (provider.modelAliases[modelOrAlias] ?? modelOrAlias)
: provider.defaultModel;
? (providerInfo.modelAliases[modelOrAlias] ?? modelOrAlias)
: providerInfo.defaultModel;

const systemPrompt = providerFileConfig.systemPrompt ?? DEFAULT_SYSTEM_PROMPT;

@@ -40,47 +40,55 @@ export function initChatState(

if (modelOrAlias != null && modelOrAlias !== model) {
outputMessages.push({
type: 'program',
role: 'program',
level: 'debug',
text: `Resolved model alias "${modelOrAlias}" to "${model}".`,
content: `Resolved model alias "${modelOrAlias}" to "${model}".`,
});
}

outputMessages.push({
type: 'program',
role: 'program',
level: 'debug',
text: `Loaded config: ${JSON.stringify(providerConfig, filterOutApiKey, 2)}`,
content: `Loaded config: ${JSON.stringify(providerConfig, filterOutApiKey, 2)}`,
});

if (options.file) {
const { systemMessage, costWarning, costInfo } = handleInputFile(
options.file,
providerConfig,
provider,
);
const {
systemPrompt: fileSystemPrompt,
costWarning,
costInfo,
} = handleInputFile(options.file, providerConfig, providerInfo);

providerConfig.systemPrompt += `\n\n${fileSystemPrompt}`;

contextMessages.push(systemMessage);
if (costWarning) {
outputMessages.push({ type: 'program', level: 'warning', text: costWarning });
outputMessages.push({ role: 'program', level: 'warning', content: costWarning });
} else if (costInfo) {
outputMessages.push({ type: 'program', level: 'info', text: costInfo });
outputMessages.push({ role: 'program', level: 'info', content: costInfo });
}
}

if (initialPrompt) {
contextMessages.push({ role: 'user', content: initialPrompt });
outputMessages.push({ type: 'user', text: initialPrompt });
outputMessages.push({ role: 'user', content: initialPrompt });
}

const app = createApp({
chatModel: providerInfo.getChatModel(providerConfig),
systemPrompt: () => providerConfig.systemPrompt,
plugins: [loggingPlugin],
});

const state: ChatState = {
activeView: null,
provider,
provider: providerInfo,
providerConfig,
contextMessages,
chatMessages: outputMessages,
verbose: options.verbose ?? false,
shouldExit: false,
stream: options.stream ?? true,
app,
};

useChatState.setState(state);
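The key structural change here is that inference is delegated to a byorg-core Application created once at startup; a condensed sketch of that wiring, using only the pieces named in this diff:

```ts
import { createApp, loggingPlugin } from '@callstack/byorg-core';
import type { ProviderConfig } from '../../../engine/providers/config.js';
import type { ProviderInfo } from '../../../engine/providers/provider-info.js';

declare const providerInfo: ProviderInfo;
declare const providerConfig: ProviderConfig;

// getChatModel() adapts the provider's Vercel AI SDK model to byorg-core's
// ChatModel interface; systemPrompt is a function so later mutations (such as
// the appended file prompt above) are picked up on every request.
const app = createApp({
  chatModel: providerInfo.getChatModel(providerConfig),
  systemPrompt: () => providerConfig.systemPrompt,
  plugins: [loggingPlugin],
});
```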
31 changes: 9 additions & 22 deletions src/commands/chat/state/state.ts
@@ -1,43 +1,30 @@
import { create } from 'zustand';
import type { Application, AssistantResponse, Message, UserMessage } from '@callstack/byorg-core';
import type { ProviderConfig } from '../../../engine/providers/config.js';
import type { Provider } from '../../../engine/providers/provider.js';
import type { Message, ModelUsage } from '../../../engine/inference.js';
import type { ProviderInfo } from '../../../engine/providers/provider-info.js';

export interface ChatState {
provider: Provider;
provider: ProviderInfo;
providerConfig: ProviderConfig;
contextMessages: Message[];
chatMessages: ChatMessage[];
activeView: ActiveView;
verbose: boolean;
shouldExit: boolean;
stream: boolean;
app: Application;
}

export type ChatMessage = UserChatMessage | AiChatMessage | ProgramChatMessage;
export type ChatMessage = UserMessage | AssistantResponse | ProgramOutput;

export interface UserChatMessage {
type: 'user';
text: string;
}

export interface AiChatMessage {
type: 'ai';
text: string;
responseTime: number;
usage: ModelUsage;
cost?: number;
data?: unknown;
export interface ProgramOutput {
role: 'program';
content: string;
level: MessageLevel;
}

export type MessageLevel = 'debug' | 'info' | 'warning' | 'error';

export interface ProgramChatMessage {
type: 'program';
level: MessageLevel;
text: string;
}

type ActiveView = 'info' | 'help' | null;

// @ts-expect-error lazy init
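With ChatMessage re-based on byorg-core types, role becomes the single discriminant; a minimal narrowing sketch, mirroring renderMessage in ChatMessageList.tsx further down:

```ts
import type { ChatMessage } from './state.js';

function describe(message: ChatMessage): string {
  switch (message.role) {
    case 'user':
      return `me: ${message.content}`;
    case 'assistant':
      // AssistantResponse also exposes usage (tokens, responseTime, model).
      return `ai: ${message.content}`;
    case 'program':
      return `[${message.level}] ${message.content}`;
  }
}
```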
src/commands/chat/ui/AiResponseLoader.tsx → src/commands/chat/ui/AssistantResponseLoader.tsx
@@ -4,11 +4,11 @@ import { TextSpinner } from '../../../components/TextSpinner.js';
import { colors } from '../../../theme/colors.js';
import { texts } from '../texts.js';

interface AiResponseLoaderProps {
interface AssistantResponseLoaderProps {
text?: string;
}

export function AiResponseLoader({ text }: AiResponseLoaderProps) {
export function AssistantResponseLoader({ text }: AssistantResponseLoaderProps) {
return (
<Text color={colors.assistant}>
<Text color={colors.assistant}>{texts.assistantLabel}</Text>
32 changes: 17 additions & 15 deletions src/commands/chat/ui/ChatUi.tsx
@@ -1,11 +1,10 @@
import React, { useEffect, useState } from 'react';
import { Box } from 'ink';
import { ExitApp } from '../../../components/ExitApp.js';
import type { ModelResponseUpdate } from '../../../engine/inference.js';
import { extractErrorMessage } from '../../../output.js';
import { processChatCommand } from '../chat-commands.js';
import {
addAiResponse,
addAssistantResponse,
addProgramMessage,
addUserMessage,
hideActiveView,
@@ -18,7 +17,7 @@ import { HelpOutput } from './HelpOutput.js';
import { StatusBar } from './StatusBar.js';
import { InfoOutput } from './InfoOutput.js';
import { ChatMessageList } from './list/ChatMessageList.js';
import { AiResponseLoader } from './AiResponseLoader.js';
import { AssistantResponseLoader } from './AssistantResponseLoader.js';

export function ChatUi() {
const contextMessages = useChatState((state) => state.contextMessages);
@@ -59,7 +58,7 @@ export function ChatUi() {
{activeView === 'help' && <HelpOutput />}
{activeView === 'info' && <InfoOutput />}

{isLoading ? <AiResponseLoader text={loadedResponse} /> : null}
{isLoading ? <AssistantResponseLoader text={loadedResponse} /> : null}
{showInput && <UserMessageInput onSubmit={handleSubmit} />}

<StatusBar />
@@ -69,8 +68,7 @@ export function ChatUi() {
}

function useAiResponse() {
const provider = useChatState((state) => state.provider);
const providerConfig = useChatState((state) => state.providerConfig);
const app = useChatState((state) => state.app);
const stream = useChatState((state) => state.stream);

const [isLoading, setLoading] = useState(false);
@@ -82,16 +80,20 @@ function useAiResponse() {
setLoadedResponse(undefined);

const messages = useChatState.getState().contextMessages;
if (stream && provider.getChatCompletionStream != null) {
const response = await provider.getChatCompletionStream(
providerConfig,
messages,
(update: ModelResponseUpdate) => setLoadedResponse(update.content),
);
addAiResponse(response);
if (stream) {
const { response } = await app.processMessages(messages, {
onPartialResponse: (update) => setLoadedResponse(update),
});

if (response.role == 'assistant') {
addAssistantResponse(response);
} else {
addProgramMessage(response.content);
}
} else {
const response = await provider.getChatCompletion(providerConfig, messages);
addAiResponse(response);
throw new Error('Non-Stream mode is not supported yet');
// const response = await provider.getChatCompletion(providerConfig, messages);
// addAiResponse(response);
}
} catch (error) {
// We cannot leave unanswered user message in context, as there is no AI response for it.
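Request handling now funnels through the shared Application instead of per-provider completion calls; a trimmed sketch of the new streaming path, assuming the byorg-core API exactly as exercised in this diff:

```ts
import type { Application, Message } from '@callstack/byorg-core';

declare const app: Application;
declare const messages: Message[];
declare function setLoadedResponse(text: string): void;

async function requestResponse() {
  // Partial text arrives through onPartialResponse while the model streams;
  // the final response is role-tagged rather than provider-specific.
  const { response } = await app.processMessages(messages, {
    onPartialResponse: (update) => setLoadedResponse(update),
  });
  return response; // role 'assistant' carries content plus usage
}
```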
48 changes: 28 additions & 20 deletions src/commands/chat/ui/StatusBar.tsx
@@ -1,57 +1,65 @@
import React, { useMemo } from 'react';
import { Box, Text } from 'ink';
import type { ModelUsage } from '../../../engine/inference.js';
import { formatCost, formatSpeed, formatTokenCount } from '../../../format.js';
import { calculateUsageCost } from '../../../engine/session.js';
import { useChatState, type ChatMessage } from '../state/state.js';

type TotalUsage = {
inputTokens: number;
outputTokens: number;
requests: number;
responseTime: number;
model: string | undefined;
};

export function StatusBar() {
const verbose = useChatState((state) => state.verbose);
const items = useChatState((state) => state.chatMessages);
const provider = useChatState((state) => state.provider);
const providerConfig = useChatState((state) => state.providerConfig);

const totalUsage = useMemo(() => calculateTotalUsage(items), [items]);
const totalTime = useMemo(() => calculateTotalResponseTime(items), [items]);

const modelPricing = provider.modelPricing[providerConfig.model];
const model = totalUsage.model || providerConfig.model;
const modelPricing = provider.modelPricing[model] ?? provider.modelPricing[providerConfig.model];
const totalCost = calculateUsageCost(totalUsage, modelPricing) ?? 0;

return (
<Box flexDirection="row" marginTop={1}>
<Text color={'gray'}>
LLM: {provider.label}/{providerConfig.model} - Total Cost:{' '}
{verbose ? formatVerboseStats(totalCost, totalUsage, totalTime) : formatCost(totalCost)}
LLM: {provider.label}/{model} - Total Cost:{' '}
{verbose ? formatVerboseStats(totalCost, totalUsage) : formatCost(totalCost)}
</Text>
</Box>
);
}

function formatVerboseStats(cost: number, usage: ModelUsage, time: number) {
function formatVerboseStats(cost: number, usage: TotalUsage) {
const usageOutput = usage
? ` (tokens: ${formatTokenCount(usage.inputTokens)} in + ${formatTokenCount(usage.outputTokens)} out, requests: ${usage.requests}, speed: ${formatSpeed(usage.outputTokens, time)})`
? ` (tokens: ${formatTokenCount(usage.inputTokens)} in + ${formatTokenCount(
usage.outputTokens,
)} out, requests: ${usage.requests}, speed: ${formatSpeed(usage.outputTokens, usage.responseTime)})`
: '';
return `${formatCost(cost)}${usageOutput}`;
}

function calculateTotalUsage(messages: ChatMessage[]) {
const usage: ModelUsage = { inputTokens: 0, outputTokens: 0, requests: 0 };
function calculateTotalUsage(messages: ChatMessage[]): TotalUsage {
const usage: TotalUsage = {
inputTokens: 0,
outputTokens: 0,
requests: 0,
responseTime: 0,
model: undefined,
};
messages.forEach((message) => {
if (message.type === 'ai') {
if (message.role === 'assistant') {
usage.inputTokens += message.usage?.inputTokens ?? 0;
usage.outputTokens += message.usage?.outputTokens ?? 0;
usage.requests += message.usage?.requests ?? 0;
usage.responseTime += message.usage?.responseTime ?? 0;
usage.model = message.usage?.model ?? usage.model;
}
});
return usage;
}

function calculateTotalResponseTime(messages: ChatMessage[]) {
let total = 0;
messages.forEach((message) => {
if (message.type === 'ai') {
total += message.responseTime ?? 0;
}
});
return total;
return usage;
}
src/commands/chat/ui/list/AiChatMessageItem.tsx → src/commands/chat/ui/list/AssistantResponseItem.tsx
@@ -1,32 +1,32 @@
import React from 'react';
import { Text } from 'ink';
import type { AssistantResponse } from '@callstack/byorg-core';
import { formatSpeed, formatTime } from '../../../../format.js';
import { colors } from '../../../../theme/colors.js';
import { useChatState, type AiChatMessage } from '../../state/state.js';
import { useChatState } from '../../state/state.js';
import { texts } from '../../texts.js';

interface AiChatMessageItemProps {
message: AiChatMessage;
interface AssistantResponseItemProps {
message: AssistantResponse;
}

export function AiChatMessageItem({ message }: AiChatMessageItemProps) {
export function AssistantResponseItem({ message }: AssistantResponseItemProps) {
const verbose = useChatState((state) => state.verbose);

return (
<>
<Text color={colors.assistant}>
<Text>{texts.assistantLabel}</Text>
<Text>{message.text}</Text>
<Text>{message.content}</Text>

{verbose && message.responseTime != null ? (
{verbose && message.usage.responseTime != null ? (
<Text color={colors.info}>
{' '}
({formatTime(message.responseTime)},{' '}
{formatSpeed(message.usage?.outputTokens, message.responseTime)})
({formatTime(message.usage.responseTime)},{' '}
{formatSpeed(message.usage?.outputTokens, message.usage.responseTime)})
</Text>
) : null}
</Text>
{verbose ? <Text color={colors.debug}>{JSON.stringify(message.data, null, 2)}</Text> : null}
<Text> {/* Add a newline */}</Text>
</>
);
18 changes: 9 additions & 9 deletions src/commands/chat/ui/list/ChatMessageList.tsx
@@ -1,9 +1,9 @@
import * as React from 'react';
import { Box, Static } from 'ink';
import { useChatState, type ChatMessage } from '../../state/state.js';
import { ProgramChatMessageItem } from './ProgramChatMessageItem.js';
import { UserChatMessageItem } from './UserChatMessageItem.js';
import { AiChatMessageItem } from './AiChatMessageItem.js';
import { UserMessageItem } from './UserChatMessageItem.js';
import { AssistantResponseItem } from './AssistantResponseItem.js';
import { ProgramOutputItem } from './ProgramChatMessageItem.js';

export function ChatMessageList() {
const messages = useChatState((state) => state.chatMessages);
@@ -21,16 +21,16 @@ export function ChatMessageList() {
}

function renderMessage(message: ChatMessage, index: number): React.ReactNode {
if (message.type === 'user') {
return <UserChatMessageItem key={index} message={message} />;
if (message.role === 'user') {
return <UserMessageItem key={index} message={message} />;
}

if (message.type === 'ai') {
return <AiChatMessageItem key={index} message={message} />;
if (message.role === 'assistant') {
return <AssistantResponseItem key={index} message={message} />;
}

if (message.type === 'program') {
return <ProgramChatMessageItem key={index} output={message} />;
if (message.role === 'program') {
return <ProgramOutputItem key={index} output={message} />;
}

return null;
16 changes: 8 additions & 8 deletions src/commands/chat/ui/list/ProgramChatMessageItem.tsx
@@ -1,29 +1,29 @@
import React from 'react';
import { Text } from 'ink';
import { colors } from '../../../../theme/colors.js';
import { useChatState, type ProgramChatMessage } from '../../state/state.js';
import { useChatState, type ProgramOutput } from '../../state/state.js';

interface ProgramChatMessageItemProps {
output: ProgramChatMessage;
interface ProgramOutputItemProps {
output: ProgramOutput;
}

export function ProgramChatMessageItem({ output }: ProgramChatMessageItemProps) {
export function ProgramOutputItem({ output }: ProgramOutputItemProps) {
const verbose = useChatState((state) => state.verbose);

if (output.level === 'error') {
return <Text color={colors.error}>{output.text}</Text>;
return <Text color={colors.error}>{output.content}</Text>;
}

if (output.level === 'warning') {
return <Text color={colors.warning}>{output.text}</Text>;
return <Text color={colors.warning}>{output.content}</Text>;
}

if (output.level === 'info') {
return <Text color={colors.info}>{output.text}</Text>;
return <Text color={colors.info}>{output.content}</Text>;
}

if (verbose && output.level === 'debug') {
return <Text color={colors.debug}>{output.text}</Text>;
return <Text color={colors.debug}>{output.content}</Text>;
}

return null;
8 changes: 4 additions & 4 deletions src/commands/chat/ui/list/UserChatMessageItem.tsx
@@ -1,18 +1,18 @@
import React from 'react';
import { Text } from 'ink';
import type { UserMessage } from '@callstack/byorg-core';
import { colors } from '../../../../theme/colors.js';
import { type UserChatMessage } from '../../state/state.js';
import { texts } from '../../texts.js';

interface UserChatMessageItemProps {
message: UserChatMessage;
message: UserMessage;
}

export function UserChatMessageItem({ message }: UserChatMessageItemProps) {
export function UserMessageItem({ message }: UserChatMessageItemProps) {
return (
<Text color={colors.user}>
<Text>{texts.userLabel}</Text>
<Text>{message.text}</Text>
<Text>{message.content}</Text>
</Text>
);
}
17 changes: 8 additions & 9 deletions src/commands/chat/utils.ts
@@ -1,33 +1,34 @@
import * as fs from 'fs';
import * as path from 'path';
import * as os from 'os';
import type { Message } from '@callstack/byorg-core';
import {
DEFAULT_FILE_PROMPT,
FILE_COST_WARNING,
FILE_TOKEN_COUNT_WARNING,
} from '../../default-config.js';
import { calculateUsageCost } from '../../engine/session.js';
import { getTokensCount } from '../../engine/tokenizer.js';
import type { Message, SystemMessage } from '../../engine/inference.js';
import type { ProviderConfig } from '../../engine/providers/config.js';
import type { Provider } from '../../engine/providers/provider.js';
import type { ProviderInfo } from '../../engine/providers/provider-info.js';
import { formatCost, formatTokenCount } from '../../format.js';
import {
getConversationStoragePath,
getDefaultFilename,
getUniqueFilename,
} from '../../file-utils.js';
import { texts } from './texts.js';

interface HandleInputFileResult {
systemMessage: SystemMessage;
systemPrompt: string;
costWarning: string | null;
costInfo: string | null;
}

export function handleInputFile(
inputFile: string,
config: ProviderConfig,
provider: Provider,
provider: ProviderInfo,
): HandleInputFileResult {
const filePath = path.resolve(inputFile.replace('~', os.homedir()));

@@ -59,7 +60,7 @@ export function handleInputFile(
);

return {
systemMessage: { role: 'system', content },
systemPrompt: content,
costWarning,
costInfo,
};
@@ -92,11 +93,9 @@ export function saveConversation(messages: Message[]) {
function roleToLabel(role: Message['role']): string {
switch (role) {
case 'user':
return 'me';
return texts.userLabel;
case 'assistant':
return 'ai';
case 'system':
return 'system';
return texts.assistantLabel;
default:
return role;
}
6 changes: 3 additions & 3 deletions src/commands/init/ui/InitUi.tsx
@@ -1,6 +1,6 @@
import React, { useState } from 'react';
import { Box, Newline, Text } from 'ink';
import { type Provider } from '../../../engine/providers/provider.js';
import { type ProviderInfo } from '../../../engine/providers/provider-info.js';
import { writeConfigFile } from '../../../config-file.js';
import { colors } from '../../../theme/colors.js';
import { ExitApp } from '../../../components/ExitApp.js';
@@ -17,10 +17,10 @@ export function InitUi({ hasConfig }: InitUiProps) {
const [step, setStep] = useState(0);

const [overwriteConfig, setOverwriteConfig] = useState(true);
const [provider, setProvider] = useState<Provider>();
const [provider, setProvider] = useState<ProviderInfo>();
const [hasKey, setHasKey] = useState<boolean>();

const writeConfig = (provider?: Provider, apiKey?: string) => {
const writeConfig = (provider?: ProviderInfo, apiKey?: string) => {
if (!provider || !apiKey) {
return;
}
8 changes: 4 additions & 4 deletions src/commands/init/ui/SelectProviderStep.tsx
@@ -4,20 +4,20 @@ import { Select } from '@inkjs/ui';
import {
getProvider,
providers,
type Provider,
type ProviderInfo,
type ProviderName,
} from '../../../engine/providers/provider.js';
} from '../../../engine/providers/provider-info.js';
import { colors } from '../../../theme/colors.js';

interface SelectProviderStepProps {
label: string;
onSelect: (provider: Provider) => void;
onSelect: (provider: ProviderInfo) => void;
}

const providerItems = providers.map((p) => ({ label: p.label, value: p.name }));

export function SelectProviderStep({ label, onSelect }: SelectProviderStepProps) {
const [value, setValue] = useState<Provider>();
const [value, setValue] = useState<ProviderInfo>();

const handleSelect = (name: ProviderName) => {
const provider = getProvider(name);
2 changes: 1 addition & 1 deletion src/config-file.ts
@@ -2,7 +2,7 @@ import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { z } from 'zod';
import { providerNames } from './engine/providers/provider.js';
import { providerNames } from './engine/providers/provider-info.js';

const LEGACY_CONFIG_FILENAME = '.airc';
const CONFIG_FILENAME = '.airc.json';
34 changes: 0 additions & 34 deletions src/engine/inference.ts

This file was deleted.

107 changes: 9 additions & 98 deletions src/engine/providers/anthropic.ts
@@ -1,16 +1,9 @@
import AnthropicAPI from '@anthropic-ai/sdk';
import {
type AiMessage,
type Message,
type ModelResponseUpdate,
type UserMessage,
} from '../inference.js';
import { responseStyles, type ProviderConfig } from './config.js';
import type { Provider } from './provider.js';
import { createAnthropic } from '@ai-sdk/anthropic';
import { VercelChatModelAdapter } from '@callstack/byorg-core';
import { type ProviderConfig } from './config.js';
import type { ProviderInfo } from './provider-info.js';

type ModelMessage = UserMessage | AiMessage;

const Anthropic: Provider = {
const Anthropic: ProviderInfo = {
label: 'Anthropic',
name: 'anthropic',
apiKeyUrl: 'https://console.anthropic.com/settings/keys',
@@ -39,93 +32,11 @@ const Anthropic: Provider = {
opus: 'claude-3-opus-20240229',
},

getChatCompletion: async (config: ProviderConfig, messages: Message[]) => {
const api = new AnthropicAPI({
apiKey: config.apiKey,
});

const nonSystemMessages = messages.filter((m) => m.role !== 'system') as ModelMessage[];

const systemMessagesContent = messages.filter((m) => m.role === 'system').map((m) => m.content);
const systemPrompt = [config.systemPrompt, ...systemMessagesContent].join('\n\n');

const startTime = performance.now();
const response = await api.messages.create({
messages: nonSystemMessages,
model: config.model,
max_tokens: 1024,
system: systemPrompt,
...responseStyles[config.responseStyle],
getChatModel: (config: ProviderConfig) => {
const client = createAnthropic({ apiKey: config.apiKey });
return new VercelChatModelAdapter({
languageModel: client.languageModel(config.model),
});
const responseTime = performance.now() - startTime;

const firstContent = response.content[0];
if (firstContent?.type !== 'text') {
throw new Error(`Received unexpected content type from Anthropic API: ${firstContent?.type}`);
}

return {
message: {
role: 'assistant',
content: firstContent.text,
},
usage: {
inputTokens: response.usage.input_tokens,
outputTokens: response.usage.output_tokens,
requests: 1,
},
responseTime,
responseModel: response.model,
data: response,
};
},

getChatCompletionStream: async function (
config: ProviderConfig,
messages: Message[],
onResponseUpdate: (update: ModelResponseUpdate) => void,
) {
const api = new AnthropicAPI({
apiKey: config.apiKey,
});

const nonSystemMessages = messages.filter((m) => m.role !== 'system') as ModelMessage[];

const systemMessagesContent = messages.filter((m) => m.role === 'system').map((m) => m.content);
const systemPrompt = [config.systemPrompt, ...systemMessagesContent].join('\n\n');

const startTime = performance.now();
let content = '';
const stream = api.messages
.stream({
messages: nonSystemMessages,
model: config.model,
max_tokens: 1024,
system: systemPrompt,
...responseStyles[config.responseStyle],
})
.on('text', (text) => {
content += text;
onResponseUpdate({ content });
});

const response = await stream.finalMessage();
const responseTime = performance.now() - startTime;

return {
message: {
role: 'assistant',
content,
},
usage: {
inputTokens: response.usage.input_tokens,
outputTokens: response.usage.output_tokens,
requests: 1,
},
responseTime,
responseModel: response.model,
data: response,
};
},
};
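Roughly ninety lines of hand-rolled completion and streaming code collapse into a model factory; a sketch of the shared pattern, shown with the Anthropic pieces from this diff:

```ts
import { createAnthropic } from '@ai-sdk/anthropic';
import { VercelChatModelAdapter } from '@callstack/byorg-core';
import type { ProviderConfig } from './config.js';

// Request handling, streaming, and usage accounting now live in byorg-core;
// the provider only constructs a Vercel AI SDK language model and wraps it.
function getChatModel(config: ProviderConfig) {
  const client = createAnthropic({ apiKey: config.apiKey });
  return new VercelChatModelAdapter({
    languageModel: client.languageModel(config.model),
  });
}
```

Mistral, OpenAI, and Perplexity below follow the same shape, differing only in how the client is created (Perplexity reuses createOpenAI with a custom baseURL).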

112 changes: 8 additions & 104 deletions src/engine/providers/mistral.ts
@@ -1,11 +1,8 @@
import { Mistral as MistralClient } from '@mistralai/mistralai';
import type { CompletionEvent } from '@mistralai/mistralai/models/components/completionevent.js';
import { type Message, type ModelResponseUpdate } from '../inference.js';
import { estimateInputTokens, estimateOutputTokens } from '../tokenizer.js';
import { responseStyles, type ProviderConfig } from './config.js';
import type { Provider } from './provider.js';
import { createMistral } from '@ai-sdk/mistral';
import { VercelChatModelAdapter } from '@callstack/byorg-core';
import type { ProviderInfo } from './provider-info.js';

const Mistral: Provider = {
const Mistral: ProviderInfo = {
label: 'Mistral',
name: 'mistral',
apiKeyUrl: 'https://console.mistral.ai/api-keys/',
@@ -40,105 +37,12 @@ const Mistral: Provider = {
codestral: 'codestral-latest',
},

getChatCompletion: async (config: ProviderConfig, messages: Message[]) => {
const api = new MistralClient({ apiKey: config.apiKey });
const allMessages = getMessages(config, messages);

const startTime = performance.now();
const response = await api.chat.complete({
model: config.model,
messages: allMessages,
...getMistralResponseStyle(config),
getChatModel(config) {
const client = createMistral({ apiKey: config.apiKey });
return new VercelChatModelAdapter({
languageModel: client.languageModel(config.model),
});
const responseTime = performance.now() - startTime;

return {
message: {
role: 'assistant',
content: response.choices?.[0]?.message?.content ?? '',
},
usage: {
inputTokens: response.usage.promptTokens,
outputTokens: response.usage.completionTokens,
requests: 1,
},
responseTime,
responseModel: response.model,
data: response,
};
},

getChatCompletionStream: async function (
config: ProviderConfig,
messages: Message[],
onResponseUpdate: (update: ModelResponseUpdate) => void,
) {
const api = new MistralClient({ apiKey: config.apiKey });
const allMessages = getMessages(config, messages);

const startTime = performance.now();
const stream = await api.chat.stream({
messages: allMessages,
model: config.model,
...getMistralResponseStyle(config),
});

let lastChunk: CompletionEvent | null = null;
let content = '';
for await (const chunk of stream) {
lastChunk = chunk;
content += chunk.data.choices[0]?.delta?.content || '';
onResponseUpdate({ content });
}

const responseTime = performance.now() - startTime;

return {
message: {
role: 'assistant',
content,
},
usage: {
inputTokens: lastChunk?.data.usage?.promptTokens ?? estimateInputTokens(allMessages),
outputTokens: lastChunk?.data.usage?.completionTokens ?? estimateOutputTokens(content),
requests: 1,
},
responseTime,
responseModel: lastChunk?.data.model || 'unknown',
data: lastChunk,
};
},
};

function getMessages(config: ProviderConfig, messages: Message[]): Message[] {
if (!config.systemPrompt) {
return messages;
}

const systemMessage: Message = {
role: 'system',
content: config.systemPrompt,
};
return [systemMessage, ...messages];
}

interface MistralResponseStyle {
temperature?: number;
topP?: number;
}

function getMistralResponseStyle(config: ProviderConfig): MistralResponseStyle {
const style = responseStyles[config.responseStyle];

const result: MistralResponseStyle = {};
if ('temperature' in style) {
result.temperature = style.temperature;
}
if ('top_p' in style) {
result.topP = style.top_p;
}

return result;
}

export default Mistral;
30 changes: 8 additions & 22 deletions src/engine/providers/open-ai.ts
@@ -1,10 +1,9 @@
import OpenAI from 'openai';
import { type Message, type ModelResponseUpdate } from '../inference.js';
import { createOpenAI } from '@ai-sdk/openai';
import { VercelChatModelAdapter } from '@callstack/byorg-core';
import { type ProviderConfig } from './config.js';
import type { Provider } from './provider.js';
import { getChatCompletion, getChatCompletionStream } from './utils/open-ai-api.js';
import type { ProviderInfo } from './provider-info.js';

const OpenAi: Provider = {
const OpenAi: ProviderInfo = {
label: 'OpenAI',
name: 'openAi',
apiKeyUrl: 'https://platform.openai.com/api-keys',
@@ -31,24 +30,11 @@ const OpenAi: Provider = {

modelAliases: {},

getChatCompletion: async (config: ProviderConfig, messages: Message[]) => {
const api = new OpenAI({
apiKey: config.apiKey,
getChatModel: (config: ProviderConfig) => {
const client = createOpenAI({ apiKey: config.apiKey, compatibility: 'strict' });
return new VercelChatModelAdapter({
languageModel: client.languageModel(config.model),
});

return await getChatCompletion(api, config, messages);
},

getChatCompletionStream: async function (
config: ProviderConfig,
messages: Message[],
onResponseUpdate: (update: ModelResponseUpdate) => void,
) {
const api = new OpenAI({
apiKey: config.apiKey,
});

return await getChatCompletionStream(api, config, messages, onResponseUpdate);
},
};

32 changes: 9 additions & 23 deletions src/engine/providers/perplexity.ts
@@ -1,10 +1,8 @@
import OpenAI from 'openai';
import { type Message, type ModelResponseUpdate } from '../inference.js';
import { type ProviderConfig } from './config.js';
import type { Provider } from './provider.js';
import { getChatCompletion, getChatCompletionStream } from './utils/open-ai-api.js';
import { createOpenAI } from '@ai-sdk/openai';
import { VercelChatModelAdapter } from '@callstack/byorg-core';
import type { ProviderInfo } from './provider-info.js';

const Perplexity: Provider = {
const Perplexity: ProviderInfo = {
label: 'Perplexity',
name: 'perplexity',
apiKeyUrl: 'https://perplexity.ai/settings/api',
@@ -43,26 +41,14 @@ const Perplexity: Provider = {
huge: 'llama-3.1-sonar-huge-128k-online',
},

getChatCompletion: async (config: ProviderConfig, messages: Message[]) => {
const api = new OpenAI({
getChatModel(config) {
const client = createOpenAI({
apiKey: config.apiKey,
baseURL: 'https://api.perplexity.ai',
baseURL: 'https://api.perplexity.ai/',
});

return await getChatCompletion(api, config, messages);
},

getChatCompletionStream: async function (
config: ProviderConfig,
messages: Message[],
onResponseUpdate: (update: ModelResponseUpdate) => void,
) {
const api = new OpenAI({
apiKey: config.apiKey,
baseURL: 'https://api.perplexity.ai',
return new VercelChatModelAdapter({
languageModel: client.languageModel(config.model),
});

return await getChatCompletionStream(api, config, messages, onResponseUpdate);
},
};

src/engine/providers/provider.ts → src/engine/providers/provider-info.ts
@@ -1,4 +1,4 @@
import type { Message, ModelResponse, ModelResponseUpdate } from '../inference.js';
import type { ChatModel } from '@callstack/byorg-core';
import type { ProviderConfig } from './config.js';
import openAi from './open-ai.js';
import perplexity from './perplexity.js';
@@ -8,7 +8,7 @@ import mistral from './mistral.js';
export const providerNames = ['openAi', 'anthropic', 'perplexity', 'mistral'] as const;
export type ProviderName = (typeof providerNames)[number];

export interface Provider {
export interface ProviderInfo {
name: ProviderName;
label: string;
apiKeyUrl: string;
@@ -17,12 +17,7 @@ export interface Provider {
modelPricing: Record<string, ModelPricing>;
modelAliases: Record<string, string>;

getChatCompletion: (config: ProviderConfig, messages: Message[]) => Promise<ModelResponse>;
getChatCompletionStream?: (
config: ProviderConfig,
messages: Message[],
onStreamUpdate: (update: ModelResponseUpdate) => void,
) => Promise<ModelResponse>;
getChatModel: (config: ProviderConfig) => ChatModel;
}

export interface ModelPricing {
@@ -36,7 +31,7 @@ export interface ModelPricing {
requestsCost?: number;
}

const providersMap: Record<ProviderName, Provider> = {
const providersMap: Record<ProviderName, ProviderInfo> = {
openAi,
anthropic,
perplexity,
@@ -45,7 +40,7 @@ const providersMap: Record<ProviderName, Provider> = {

export const providers = Object.values(providersMap);

export function getProvider(providerName: ProviderName): Provider {
export function getProvider(providerName: ProviderName): ProviderInfo {
const provider = providersMap[providerName];
if (!provider) {
throw new Error(`Provider not found: ${providerName}.`);
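The slimmed-down interface means a provider is fully described by its metadata plus a single ChatModel factory; a short usage sketch under that contract:

```ts
import type { ProviderConfig } from './config.js';
import { getProvider } from './provider-info.js';

declare const config: ProviderConfig;

// getProvider is typed against ProviderName, so typos fail at compile time;
// the returned ProviderInfo has exactly one entry point into inference.
const anthropic = getProvider('anthropic');
const chatModel = anthropic.getChatModel(config); // consumed by createApp()
```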
93 changes: 0 additions & 93 deletions src/engine/providers/utils/open-ai-api.ts

This file was deleted.

4 changes: 2 additions & 2 deletions src/engine/session.ts
@@ -1,5 +1,5 @@
import type { ModelUsage } from './inference.js';
import type { ModelPricing } from './providers/provider.js';
import type { ModelUsage } from '@callstack/byorg-core';
import type { ModelPricing } from './providers/provider-info.js';

export interface SessionUsage {
total: ModelUsage;
2 changes: 1 addition & 1 deletion src/engine/tokenizer.ts
@@ -1,6 +1,6 @@
import { createRequire } from 'module';
import { Tiktoken } from 'tiktoken/lite';
import type { Message } from './inference.js';
import type { Message } from '@callstack/byorg-core';

// Workaround for JSON loading in ESM
// See: https://www.stefanjudis.com/snippets/how-to-import-json-files-in-es-modules-node-js/#option-2%3A-leverage-the-commonjs-%60require%60-function-to-load-json-files
2 changes: 1 addition & 1 deletion src/file-utils.ts
@@ -2,7 +2,7 @@ import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { format } from 'date-fns';
import type { Message } from './engine/inference.js';
import type { Message } from '@callstack/byorg-core';

export const CHATS_SAVE_DIRECTORY = '~/ai-chats';

5 changes: 2 additions & 3 deletions tsconfig.json
@@ -4,9 +4,8 @@
"ai-cli": ["./src/index.js"]
},
"target": "ES2020",
"lib": ["ES2020"],
"module": "Node16",
"moduleResolution": "Node16",
"module": "ESNext",
"moduleResolution": "Bundler",
"outDir": "build",
"jsx": "react",
"esModuleInterop": true,
369 changes: 368 additions & 1 deletion yarn.lock

Large diffs are not rendered by default.