diff --git a/genkit-tools/common/src/utils/utils.ts b/genkit-tools/common/src/utils/utils.ts index b36a3b58e..795715207 100644 --- a/genkit-tools/common/src/utils/utils.ts +++ b/genkit-tools/common/src/utils/utils.ts @@ -25,14 +25,27 @@ export async function findProjectRoot(): Promise { let currentDir = process.cwd(); while (currentDir !== path.parse(currentDir).root) { const packageJsonPath = path.join(currentDir, 'package.json'); + const goModPath = path.join(currentDir, 'go.mod'); try { - await fs.access(packageJsonPath); - return currentDir; + const [packageJsonExists, goModExists] = await Promise.all([ + fs + .access(packageJsonPath) + .then(() => true) + .catch(() => false), + fs + .access(goModPath) + .then(() => true) + .catch(() => false), + ]); + if (packageJsonExists || goModExists) { + return currentDir; + } } catch { - currentDir = path.dirname(currentDir); + // Continue searching if any errors occur } + currentDir = path.dirname(currentDir); } - throw new Error('Could not find project root (package.json not found)'); + throw new Error('Could not find project root'); } /** diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index c4bc48889..d2d75cd3c 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -70,7 +70,7 @@ func Init(ctx context.Context, opts *Options) error { wg.Add(1) go func() { defer wg.Done() - s := startReflectionServer(errCh) + s := startReflectionServer(ctx, errCh) mu.Lock() servers = append(servers, s) mu.Unlock() diff --git a/go/genkit/servers.go b/go/genkit/servers.go index 161e3d191..64f71d2b7 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -31,6 +31,7 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "strconv" "sync" "sync/atomic" @@ -44,14 +45,110 @@ import ( "go.opentelemetry.io/otel/trace" ) +type runtimeFileData struct { + ID string `json:"id"` + PID int `json:"pid"` + ReflectionServerURL string `json:"reflectionServerUrl"` + Timestamp string `json:"timestamp"` +} + +type devServer struct { + reg *registry.Registry + runtimeFilePath string +} + // startReflectionServer starts the Reflection API server listening at the // value of the environment variable GENKIT_REFLECTION_PORT for the port, // or ":3100" if it is empty. -func startReflectionServer(errCh chan<- error) *http.Server { - slog.Info("starting reflection server") +func startReflectionServer(ctx context.Context, errCh chan<- error) *http.Server { + slog.Debug("starting reflection server") addr := serverAddress("", "GENKIT_REFLECTION_PORT", "127.0.0.1:3100") - mux := newDevServeMux(registry.Global) - return startServer(addr, mux, errCh) + s := &devServer{reg: registry.Global} + if err := s.writeRuntimeFile(addr); err != nil { + slog.Error("failed to write runtime file", "error", err) + } + mux := newDevServeMux(s) + server := startServer(addr, mux, errCh) + go func() { + <-ctx.Done() + if err := s.cleanupRuntimeFile(); err != nil { + slog.Error("failed to cleanup runtime file", "error", err) + } + }() + return server +} + +// writeRuntimeFile writes a file describing the runtime to the project root. +func (s *devServer) writeRuntimeFile(url string) error { + projectRoot, err := findProjectRoot() + if err != nil { + return fmt.Errorf("failed to find project root: %w", err) + } + runtimesDir := filepath.Join(projectRoot, ".genkit", "runtimes") + if err := os.MkdirAll(runtimesDir, 0755); err != nil { + return fmt.Errorf("failed to create runtimes directory: %w", err) + } + runtimeID := os.Getenv("GENKIT_RUNTIME_ID") + if runtimeID == "" { + runtimeID = strconv.Itoa(os.Getpid()) + } + timestamp := time.Now().UTC().Format(time.RFC3339) + s.runtimeFilePath = filepath.Join(runtimesDir, fmt.Sprintf("%d-%s.json", os.Getpid(), timestamp)) + data := runtimeFileData{ + ID: runtimeID, + PID: os.Getpid(), + ReflectionServerURL: fmt.Sprintf("http://%s", url), + Timestamp: timestamp, + } + fileContent, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal runtime data: %w", err) + } + if err := os.WriteFile(s.runtimeFilePath, fileContent, 0644); err != nil { + return fmt.Errorf("failed to write runtime file: %w", err) + } + slog.Debug("runtime file written", "path", s.runtimeFilePath) + return nil +} + +// cleanupRuntimeFile removes the runtime file associated with the dev server. +func (s *devServer) cleanupRuntimeFile() error { + if s.runtimeFilePath == "" { + return nil + } + content, err := os.ReadFile(s.runtimeFilePath) + if err != nil { + return fmt.Errorf("failed to read runtime file: %w", err) + } + var data runtimeFileData + if err := json.Unmarshal(content, &data); err != nil { + return fmt.Errorf("failed to unmarshal runtime data: %w", err) + } + if data.PID == os.Getpid() { + if err := os.Remove(s.runtimeFilePath); err != nil { + return fmt.Errorf("failed to remove runtime file: %w", err) + } + slog.Debug("runtime file cleaned up", "path", s.runtimeFilePath) + } + return nil +} + +// findProjectRoot finds the project root by looking for a go.mod file. +func findProjectRoot() (string, error) { + dir, err := os.Getwd() + if err != nil { + return "", err + } + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir, nil + } + parent := filepath.Dir(dir) + if parent == dir { + return "", fmt.Errorf("could not find project root (go.mod not found)") + } + dir = parent + } } // startFlowServer starts a production server listening at the given address. @@ -129,13 +226,8 @@ func shutdownServers(servers []*http.Server) error { return nil } -type devServer struct { - reg *registry.Registry -} - -func newDevServeMux(r *registry.Registry) *http.ServeMux { +func newDevServeMux(s *devServer) *http.ServeMux { mux := http.NewServeMux() - s := &devServer{r} handle(mux, "GET /api/__health", func(w http.ResponseWriter, _ *http.Request) error { return nil }) diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 61a3fb1db..f590e2986 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -53,7 +53,7 @@ func TestDevServer(t *testing.T) { core.DefineActionInRegistry(r, "devServer", "dec", atype.Custom, map[string]any{ "bar": "baz", }, nil, dec) - srv := httptest.NewServer(newDevServeMux(r)) + srv := httptest.NewServer(newDevServeMux(&devServer{reg: r})) defer srv.Close() t.Run("runAction", func(t *testing.T) {