Skip to content

Commit

Permalink
split to packages (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
yonidavidson authored Aug 30, 2024
1 parent 4d7e857 commit abb7642
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 122 deletions.
42 changes: 30 additions & 12 deletions cmd/prompteng/cmd.go → cmd/prompt/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ package main

import (
"fmt"
"github.com/yonidavidson/gophercon-israel-2024/prompt"
"github.com/yonidavidson/gophercon-israel-2024/provider"
"html/template"
"os"
"strings"
)

const promptTemplate = `System: {{.SystemPrompt}}
const promptTemplate = `<system>{{.SystemPrompt}}</system>
<user>
Chat History:
{{limitTokens .ChatHistory (multiply .MaxTokens 0.3)}}
Context: {{limitTokens .RAGContext (multiply .MaxTokens 0.1)}}
User: {{limitTokens .UserQuery (multiply .MaxTokens 0.2)}}`
User Query: {{limitTokens .UserQuery (multiply .MaxTokens 0.2)}}</user>`

type PromptData struct {
MaxTokens float64
Expand All @@ -24,7 +26,7 @@ type PromptData struct {
}

func generatePrompt(maxTokens int, ragContext, userQuery, chatHistory, systemPrompt string) (string, error) {
tmpl, err := template.New("prompt").Funcs(template.FuncMap{
tmpl, err := template.New("talk").Funcs(template.FuncMap{
"limitTokens": limitTokens,
"multiply": func(a, b float64) float64 {
return a * b
Expand Down Expand Up @@ -58,21 +60,37 @@ func limitTokens(s string, maxTokens float64) string {
if len(s) <= maxChars {
return s
}
return s[:maxChars] + "..."
return s[:maxChars]
}

func main() {
maxTokens := 100
ragContext := "Paris, the capital of France, is a major European city and a global center for art, fashion, gastronomy, and culture. Its 19th-century cityscape is crisscrossed by wide boulevards and the River Seine. Beyond such landmarks as the Eiffel Tower and the 12th-century, Gothic Notre-Dame cathedral, the city is known for its cafe culture and designer boutiques along the Rue du Faubourg Saint-Honoré."
userQuery := "Can you tell me about the history and main attractions of Paris? Also, what's the best time to visit and are there any local customs I should be aware of?"
chatHistory := "User: I'm planning a trip to Europe.\nAssistant: That's exciting! Europe has many wonderful destinations. Do you have any specific countries or cities in mind?\nUser: I'm thinking about visiting France.\nAssistant: France is a great choice! It offers a rich history, beautiful landscapes, and world-renowned cuisine. Are you interested in visiting Paris or exploring other regions as well?"
userQuery := "Can you tell me about the history and main attractions of Paris? Also, what`s the best time to visit and are there any local customs I should be aware of?"
chatHistory := "User: I`m planning a trip to Europe.\nAssistant: That`s exciting! Europe has many wonderful destinations. Do you have any specific countries or cities in mind?\nUser: I'm thinking about visiting France.\nAssistant: France is a great choice! It offers a rich history, beautiful landscapes, and world-renowned cuisine. Are you interested in visiting Paris or exploring other regions as well?"
systemPrompt := "You are a knowledgeable and helpful travel assistant. Provide accurate and concise information about destinations, attractions, local customs, and travel tips. When appropriate, suggest off-the-beaten-path experiences that tourists might not typically know about. Always prioritize the safety and cultural sensitivity of the traveler."

prompt, err := generatePrompt(maxTokens, ragContext, userQuery, chatHistory, systemPrompt)
prmt, err := generatePrompt(maxTokens, ragContext, userQuery, chatHistory, systemPrompt)
if err != nil {
fmt.Printf("Error generating prompt: %v\n", err)
fmt.Printf("Error generating talk: %v\n", err)
return
}

fmt.Println(prompt)
fmt.Println(prmt)
m, err := prompt.ParseMessages(prmt)
if err != nil {
fmt.Printf("Error parsing messages: %v\n", err)
return
}
apiKey := os.Getenv("PRIVATE_OPENAI_KEY")
if apiKey == "" {
fmt.Println("Error: PRIVATE_OPENAI_KEY environment variable not set")
return
}
p := provider.OpenAIProvider{APIKey: apiKey}
r, err := p.ChatCompletion(m)
if err != nil {
fmt.Printf("Error getting chat completion: %v\n", err)
return
}
fmt.Println("\n\n\n\n" + string(r))
}
108 changes: 0 additions & 108 deletions cmd/prompt/main.go

This file was deleted.

30 changes: 30 additions & 0 deletions cmd/talk/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package main

import (
"fmt"
"github.com/yonidavidson/gophercon-israel-2024/prompt"
"github.com/yonidavidson/gophercon-israel-2024/provider"
"os"
)

func main() {
// Retrieve the API key from the environment variable
apiKey := os.Getenv("PRIVATE_OPENAI_KEY")
if apiKey == "" {
fmt.Println("Error: PRIVATE_OPENAI_KEY environment variable not set")
return
}
p := provider.OpenAIProvider{APIKey: apiKey}
messages, err := prompt.ParseMessages(`<system>You are a helpful assistant that provides concise and accurate information.</system>
<user>Translate the following English text to French: 'Hello, how are you'</user>`)
if err != nil {
fmt.Println("Error:", err)
return
}
r, err := p.ChatCompletion(messages)
if err != nil {
fmt.Println("Error:", err)
return
}
fmt.Println(string(r))
}
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
module github/yonidavidson/gophercon2024

module github.com/yonidavidson/gophercon-israel-2024
go 1.21.0
Empty file added go.sum
Empty file.
63 changes: 63 additions & 0 deletions prompt/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package prompt

import (
"fmt"
"regexp"
"strings"
)

type Role string

const (
RoleSystem Role = "system"
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)

// Message represents a message with a role and content.
type Message struct {
Role Role
Content string
}

// ParseMessages parses the input string into a slice of messages.
func ParseMessages(input string) ([]Message, error) {
// Validate tags before parsing
if err := validate(input); err != nil {
return nil, err
}
var messages []Message

// Regular expression to match tags and their content
re := regexp.MustCompile(`<(system|user|assistant)>([\s\S]*?)</(system|user|assistant)>`)

// Find all matches in the input string
matches := re.FindAllStringSubmatch(input, -1)

for _, match := range matches {
role := Role(match[1])
content := strings.TrimSpace(match[2])

message := Message{
Role: role,
Content: content,
}

messages = append(messages, message)
}

return messages, nil
}

// validate checks if the input string has matching opening and closing tags for each role.
func validate(input string) error {
roles := []string{"system", "user", "assistant"}
for _, role := range roles {
openCount := strings.Count(input, "<"+role+">")
closeCount := strings.Count(input, "</"+role+">")
if openCount != closeCount {
return fmt.Errorf("mismatched tags for role %s: %d opening, %d closing", role, openCount, closeCount)
}
}
return nil
}
68 changes: 68 additions & 0 deletions prompt/prompt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package prompt_test

import (
"reflect"
"testing"

"github.com/yonidavidson/gophercon-israel-2024/prompt"
)

func TestParseMessages(t *testing.T) {
tests := []struct {
name string
input string
expected []prompt.Message
}{
{
name: "System and User messages",
input: `<system>You are a helpful assistant</system>
<user>Hello, how are you?</user>`,
expected: []prompt.Message{
{Role: prompt.RoleSystem, Content: "You are a helpful assistant"},
{Role: prompt.RoleUser, Content: "Hello, how are you?"},
},
},
{
name: "System, User, and Assistant messages",
input: `<system>You are a helpful assistant</system>
<user>What's the weather like?</user>
<assistant>I'm sorry, I don't have real-time weather information. Could you please specify a location and I can provide general climate information?</assistant>
<user>How about in New York?</user>`,
expected: []prompt.Message{
{Role: prompt.RoleSystem, Content: "You are a helpful assistant"},
{Role: prompt.RoleUser, Content: "What's the weather like?"},
{Role: prompt.RoleAssistant, Content: "I'm sorry, I don't have real-time weather information. Could you please specify a location and I can provide general climate information?"},
{Role: prompt.RoleUser, Content: "How about in New York?"},
},
},
{
name: "Empty input",
input: "",
expected: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := prompt.ParseMessages(tt.input)
if err != nil {
t.Errorf("ParseMessages() error = %v", err)
return
}
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("ParseMessages() = %v, want %v", result, tt.expected)
}
})
}
}

func TestParseMessagesError(t *testing.T) {
input := `<system>Incomplete system message
<user>User message without closing tag
<assistant>Assistant message</assistant>`

_, err := prompt.ParseMessages(input)
if err == nil {
t.Error("Expected an error, but got nil")
}
}
Loading

0 comments on commit abb7642

Please sign in to comment.