Skip to content

Commit 2b015cd

Browse files
authored
feat(go): added support for embedding prompts into the binary (#3973)
1 parent a209ac1 commit 2b015cd

File tree

5 files changed

+508
-134
lines changed

5 files changed

+508
-134
lines changed

go/ai/prompt.go

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ import (
1919
"encoding/json"
2020
"errors"
2121
"fmt"
22+
"io/fs"
2223
"iter"
2324
"log/slog"
2425
"maps"
2526
"os"
26-
"path/filepath"
27+
"path"
2728
"reflect"
2829
"slices"
2930
"strings"
@@ -601,87 +602,84 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) {
601602
return result, nil
602603
}
603604

604-
// LoadPromptDir loads prompts and partials from the input directory for the given namespace.
605-
func LoadPromptDir(r api.Registry, dir string, namespace string) {
606-
useDefaultDir := false
607-
if dir == "" {
608-
dir = "./prompts"
609-
useDefaultDir = true
605+
// LoadPromptDirFromFS loads prompts and partials from a filesystem for the given namespace.
606+
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
607+
// The dir parameter specifies the directory within the filesystem where prompts are located.
608+
func LoadPromptDirFromFS(r api.Registry, fsys fs.FS, dir, namespace string) {
609+
if fsys == nil {
610+
panic(errors.New("no prompt filesystem provided"))
610611
}
611612

612-
path, err := filepath.Abs(dir)
613-
if err != nil {
614-
if !useDefaultDir {
615-
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
616-
}
617-
slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir)
618-
return
613+
if _, err := fs.Stat(fsys, dir); err != nil {
614+
panic(fmt.Errorf("failed to access prompt directory %q in filesystem: %w", dir, err))
619615
}
620616

621-
if _, err := os.Stat(path); os.IsNotExist(err) {
622-
if !useDefaultDir {
623-
panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err))
624-
}
625-
slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir)
626-
return
627-
}
628-
629-
loadPromptDir(r, path, namespace)
630-
}
631-
632-
// loadPromptDir recursively loads prompts and partials from the directory.
633-
func loadPromptDir(r api.Registry, dir string, namespace string) {
634-
entries, err := os.ReadDir(dir)
617+
entries, err := fs.ReadDir(fsys, dir)
635618
if err != nil {
636619
panic(fmt.Errorf("failed to read prompt directory structure: %w", err))
637620
}
638621

639622
for _, entry := range entries {
640623
filename := entry.Name()
641-
path := filepath.Join(dir, filename)
624+
filePath := path.Join(dir, filename)
642625
if entry.IsDir() {
643-
loadPromptDir(r, path, namespace)
626+
LoadPromptDirFromFS(r, fsys, filePath, namespace)
644627
} else if strings.HasSuffix(filename, ".prompt") {
645628
if strings.HasPrefix(filename, "_") {
646629
partialName := strings.TrimSuffix(filename[1:], ".prompt")
647-
source, err := os.ReadFile(path)
630+
source, err := fs.ReadFile(fsys, filePath)
648631
if err != nil {
649632
slog.Error("Failed to read partial file", "error", err)
650633
continue
651634
}
652635
r.RegisterPartial(partialName, string(source))
653-
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path)
636+
slog.Debug("Registered Dotprompt partial", "name", partialName, "file", filePath)
654637
} else {
655-
LoadPrompt(r, dir, filename, namespace)
638+
LoadPromptFromFS(r, fsys, dir, filename, namespace)
656639
}
657640
}
658641
}
659642
}
660643

661-
// LoadPrompt loads a single prompt into the registry.
662-
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
644+
// LoadPromptFromFS loads a single prompt from a filesystem into the registry.
645+
// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS).
646+
// The dir parameter specifies the directory within the filesystem where the prompt is located.
647+
func LoadPromptFromFS(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt {
663648
name := strings.TrimSuffix(filename, ".prompt")
664-
name, variant, _ := strings.Cut(name, ".")
665649

666-
sourceFile := filepath.Join(dir, filename)
667-
source, err := os.ReadFile(sourceFile)
650+
sourceFile := path.Join(dir, filename)
651+
source, err := fs.ReadFile(fsys, sourceFile)
668652
if err != nil {
669653
slog.Error("Failed to read prompt file", "file", sourceFile, "error", err)
670654
return nil
671655
}
672656

657+
p, err := LoadPromptFromSource(r, string(source), name, namespace)
658+
if err != nil {
659+
slog.Error("Failed to load prompt", "file", sourceFile, "error", err)
660+
return nil
661+
}
662+
663+
slog.Debug("Registered Dotprompt", "name", p.Name(), "file", sourceFile)
664+
return p
665+
}
666+
667+
// LoadPromptFromSource loads a prompt from raw .prompt file content.
668+
// The source parameter should contain the complete .prompt file text (frontmatter + template).
669+
// The name parameter is the prompt name (may include variant suffix like "myPrompt.variant").
670+
func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Prompt, error) {
671+
name, variant, _ := strings.Cut(name, ".")
672+
673673
dp := r.Dotprompt()
674674

675-
parsedPrompt, err := dp.Parse(string(source))
675+
parsedPrompt, err := dp.Parse(source)
676676
if err != nil {
677-
slog.Error("Failed to parse file as dotprompt", "file", sourceFile, "error", err)
678-
return nil
677+
return nil, fmt.Errorf("failed to parse dotprompt: %w", err)
679678
}
680679

681-
metadata, err := dp.RenderMetadata(string(source), &parsedPrompt.PromptMetadata)
680+
metadata, err := dp.RenderMetadata(source, &parsedPrompt.PromptMetadata)
682681
if err != nil {
683-
slog.Error("Failed to render dotprompt metadata", "file", sourceFile, "error", err)
684-
return nil
682+
return nil, fmt.Errorf("failed to render dotprompt metadata: %w", err)
685683
}
686684

687685
toolRefs := make([]ToolRef, len(metadata.Tools))
@@ -765,17 +763,15 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
765763

766764
dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{})
767765
if err != nil {
768-
slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err)
769-
return nil
766+
return nil, fmt.Errorf("failed to convert prompt template to messages: %w", err)
770767
}
771768

772769
var systemText string
773770
var nonSystemMessages []*Message
774771
for _, dpMsg := range dpMessages {
775772
parts, err := convertToPartPointers(dpMsg.Content)
776773
if err != nil {
777-
slog.Error("Failed to convert message parts", "file", sourceFile, "error", err)
778-
return nil
774+
return nil, fmt.Errorf("failed to convert message parts: %w", err)
779775
}
780776

781777
role := Role(dpMsg.Role)
@@ -809,9 +805,17 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
809805

810806
prompt := DefinePrompt(r, key, promptOpts...)
811807

812-
slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile)
808+
return prompt, nil
809+
}
813810

814-
return prompt
811+
// LoadPromptDir loads prompts and partials from a directory on the local filesystem.
812+
func LoadPromptDir(r api.Registry, dir string, namespace string) {
813+
LoadPromptDirFromFS(r, os.DirFS(dir), ".", namespace)
814+
}
815+
816+
// LoadPrompt loads a single prompt from a directory on the local filesystem into the registry.
817+
func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt {
818+
return LoadPromptFromFS(r, os.DirFS(dir), ".", filename, namespace)
815819
}
816820

817821
// promptKey generates a unique key for the prompt in the registry.

0 commit comments

Comments
 (0)