-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4d7e857
commit abb7642
Showing
8 changed files
with
300 additions
and
122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"github.com/yonidavidson/gophercon-israel-2024/prompt" | ||
"github.com/yonidavidson/gophercon-israel-2024/provider" | ||
"os" | ||
) | ||
|
||
func main() { | ||
// Retrieve the API key from the environment variable | ||
apiKey := os.Getenv("PRIVATE_OPENAI_KEY") | ||
if apiKey == "" { | ||
fmt.Println("Error: PRIVATE_OPENAI_KEY environment variable not set") | ||
return | ||
} | ||
p := provider.OpenAIProvider{APIKey: apiKey} | ||
messages, err := prompt.ParseMessages(`<system>You are a helpful assistant that provides concise and accurate information.</system> | ||
<user>Translate the following English text to French: 'Hello, how are you'</user>`) | ||
if err != nil { | ||
fmt.Println("Error:", err) | ||
return | ||
} | ||
r, err := p.ChatCompletion(messages) | ||
if err != nil { | ||
fmt.Println("Error:", err) | ||
return | ||
} | ||
fmt.Println(string(r)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
module github/yonidavidson/gophercon2024 | ||
|
||
module github.com/yonidavidson/gophercon-israel-2024 | ||
go 1.21.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
package prompt | ||
|
||
import ( | ||
"fmt" | ||
"regexp" | ||
"strings" | ||
) | ||
|
||
type Role string | ||
|
||
const ( | ||
RoleSystem Role = "system" | ||
RoleUser Role = "user" | ||
RoleAssistant Role = "assistant" | ||
) | ||
|
||
// Message represents a message with a role and content. | ||
type Message struct { | ||
Role Role | ||
Content string | ||
} | ||
|
||
// ParseMessages parses the input string into a slice of messages. | ||
func ParseMessages(input string) ([]Message, error) { | ||
// Validate tags before parsing | ||
if err := validate(input); err != nil { | ||
return nil, err | ||
} | ||
var messages []Message | ||
|
||
// Regular expression to match tags and their content | ||
re := regexp.MustCompile(`<(system|user|assistant)>([\s\S]*?)</(system|user|assistant)>`) | ||
|
||
// Find all matches in the input string | ||
matches := re.FindAllStringSubmatch(input, -1) | ||
|
||
for _, match := range matches { | ||
role := Role(match[1]) | ||
content := strings.TrimSpace(match[2]) | ||
|
||
message := Message{ | ||
Role: role, | ||
Content: content, | ||
} | ||
|
||
messages = append(messages, message) | ||
} | ||
|
||
return messages, nil | ||
} | ||
|
||
// validate checks if the input string has matching opening and closing tags for each role. | ||
func validate(input string) error { | ||
roles := []string{"system", "user", "assistant"} | ||
for _, role := range roles { | ||
openCount := strings.Count(input, "<"+role+">") | ||
closeCount := strings.Count(input, "</"+role+">") | ||
if openCount != closeCount { | ||
return fmt.Errorf("mismatched tags for role %s: %d opening, %d closing", role, openCount, closeCount) | ||
} | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package prompt_test | ||
|
||
import ( | ||
"reflect" | ||
"testing" | ||
|
||
"github.com/yonidavidson/gophercon-israel-2024/prompt" | ||
) | ||
|
||
func TestParseMessages(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
input string | ||
expected []prompt.Message | ||
}{ | ||
{ | ||
name: "System and User messages", | ||
input: `<system>You are a helpful assistant</system> | ||
<user>Hello, how are you?</user>`, | ||
expected: []prompt.Message{ | ||
{Role: prompt.RoleSystem, Content: "You are a helpful assistant"}, | ||
{Role: prompt.RoleUser, Content: "Hello, how are you?"}, | ||
}, | ||
}, | ||
{ | ||
name: "System, User, and Assistant messages", | ||
input: `<system>You are a helpful assistant</system> | ||
<user>What's the weather like?</user> | ||
<assistant>I'm sorry, I don't have real-time weather information. Could you please specify a location and I can provide general climate information?</assistant> | ||
<user>How about in New York?</user>`, | ||
expected: []prompt.Message{ | ||
{Role: prompt.RoleSystem, Content: "You are a helpful assistant"}, | ||
{Role: prompt.RoleUser, Content: "What's the weather like?"}, | ||
{Role: prompt.RoleAssistant, Content: "I'm sorry, I don't have real-time weather information. Could you please specify a location and I can provide general climate information?"}, | ||
{Role: prompt.RoleUser, Content: "How about in New York?"}, | ||
}, | ||
}, | ||
{ | ||
name: "Empty input", | ||
input: "", | ||
expected: nil, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
result, err := prompt.ParseMessages(tt.input) | ||
if err != nil { | ||
t.Errorf("ParseMessages() error = %v", err) | ||
return | ||
} | ||
if !reflect.DeepEqual(result, tt.expected) { | ||
t.Errorf("ParseMessages() = %v, want %v", result, tt.expected) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestParseMessagesError(t *testing.T) { | ||
input := `<system>Incomplete system message | ||
<user>User message without closing tag | ||
<assistant>Assistant message</assistant>` | ||
|
||
_, err := prompt.ParseMessages(input) | ||
if err == nil { | ||
t.Error("Expected an error, but got nil") | ||
} | ||
} |
Oops, something went wrong.