diff --git a/genkit-tools/common/src/manager/manager.ts b/genkit-tools/common/src/manager/manager.ts index 0f664301fb..0eb86973cb 100644 --- a/genkit-tools/common/src/manager/manager.ts +++ b/genkit-tools/common/src/manager/manager.ts @@ -17,6 +17,7 @@ import axios, { type AxiosError } from 'axios'; import chokidar from 'chokidar'; import EventEmitter from 'events'; +import * as fsSync from 'fs'; import fs from 'fs/promises'; import path from 'path'; import { @@ -390,6 +391,10 @@ export class RuntimeManager { */ private async handleNewDevUi(filePath: string) { try { + if (!fsSync.existsSync(filePath)) { + // file already got deleted, ignore... + return; + } const { content, toolsInfo } = await retriable( async () => { const content = await fs.readFile(filePath, 'utf-8'); @@ -433,6 +438,10 @@ export class RuntimeManager { */ private async handleNewRuntime(filePath: string) { try { + if (!fsSync.existsSync(filePath)) { + // file already got deleted, ignore... + return; + } const { content, runtimeInfo } = await retriable( async () => { const content = await fs.readFile(filePath, 'utf-8'); @@ -448,7 +457,12 @@ export class RuntimeManager { runtimeInfo.name = runtimeInfo.id; } const fileName = path.basename(filePath); - if (await checkServerHealth(runtimeInfo.reflectionServerUrl)) { + if ( + await checkServerHealth( + runtimeInfo.reflectionServerUrl, + runtimeInfo.id + ) + ) { if ( runtimeInfo.reflectionApiSpecVersion != GENKIT_REFLECTION_API_SPEC_VERSION @@ -529,7 +543,9 @@ export class RuntimeManager { private async performHealthChecks() { const healthCheckPromises = Object.entries(this.filenameToRuntimeMap).map( async ([fileName, runtime]) => { - if (!(await checkServerHealth(runtime.reflectionServerUrl))) { + if ( + !(await checkServerHealth(runtime.reflectionServerUrl, runtime.id)) + ) { await this.removeRuntime(fileName); } } @@ -541,19 +557,14 @@ export class RuntimeManager { * Removes the runtime file which will trigger the removal watcher. */ private async removeRuntime(fileName: string) { - const runtime = this.filenameToRuntimeMap[fileName]; - if (runtime) { - try { - const runtimesDir = await findRuntimesDir(this.projectRoot); - const runtimeFilePath = path.join(runtimesDir, fileName); - await fs.unlink(runtimeFilePath); - } catch (error) { - logger.debug(`Failed to delete runtime file: ${error}`); - } - logger.debug( - `Removed unhealthy runtime with ID ${runtime.id} from manager.` - ); + try { + const runtimesDir = await findRuntimesDir(this.projectRoot); + const runtimeFilePath = path.join(runtimesDir, fileName); + await fs.unlink(runtimeFilePath); + } catch (error) { + logger.debug(`Failed to delete runtime file: ${error}`); } + logger.debug(`Removed unhealthy runtime ${fileName} from manager.`); } } diff --git a/genkit-tools/common/src/utils/utils.ts b/genkit-tools/common/src/utils/utils.ts index aeab3fb57e..d2a404a834 100644 --- a/genkit-tools/common/src/utils/utils.ts +++ b/genkit-tools/common/src/utils/utils.ts @@ -139,9 +139,12 @@ export async function detectRuntime(directory: string): Promise { /** * Checks the health of a server with a /api/__health endpoint. */ -export async function checkServerHealth(url: string): Promise { +export async function checkServerHealth( + url: string, + id?: string +): Promise { try { - const response = await fetch(`${url}/api/__health`); + const response = await fetch(`${url}/api/__health${id ? `?id=${id}` : ''}`); return response.status === 200; } catch (error) { if (isConnectionRefusedError(error)) { diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index afa24a3159..bb32c79bff 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -55,6 +55,15 @@ type reflectionServer struct { RuntimeFilePath string // Path to the runtime file that was written at startup. } +func (s *reflectionServer) runtimeID() string { + _, port, err := net.SplitHostPort(s.Addr) + if err != nil { + // This should not happen with a valid address. + return strconv.Itoa(os.Getpid()) + } + return fmt.Sprintf("%d-%s", os.Getpid(), port) +} + // findAvailablePort finds the next available port starting from the given port number. func findAvailablePort(startPort int) (string, error) { for port := startPort; port < startPort+100; port++ { @@ -91,10 +100,10 @@ func startReflectionServer(ctx context.Context, g *Genkit, errCh chan<- error, s s := &reflectionServer{ Server: &http.Server{ - Addr: addr, - Handler: serveMux(g), + Addr: addr, }, } + s.Handler = serveMux(g, s) slog.Debug("starting reflection server", "addr", s.Addr) @@ -159,7 +168,7 @@ func (s *reflectionServer) writeRuntimeFile(url string) error { runtimeID := os.Getenv("GENKIT_RUNTIME_ID") if runtimeID == "" { - runtimeID = strconv.Itoa(os.Getpid()) + runtimeID = s.runtimeID() } timestamp := time.Now().UTC().Format(time.RFC3339) @@ -238,10 +247,14 @@ func findProjectRoot() (string, error) { } // serveMux returns a new ServeMux configured for the required Reflection API endpoints. -func serveMux(g *Genkit) *http.ServeMux { +func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { mux := http.NewServeMux() // Skip wrapHandler here to avoid logging constant polling requests. - mux.HandleFunc("GET /api/__health", func(w http.ResponseWriter, _ *http.Request) { + mux.HandleFunc("GET /api/__health", func(w http.ResponseWriter, r *http.Request) { + if id := r.URL.Query().Get("id"); id != "" && id != s.runtimeID() { + http.Error(w, "Invalid runtime ID", http.StatusServiceUnavailable) + return + } w.WriteHeader(http.StatusOK) }) mux.HandleFunc("GET /api/actions", wrapReflectionHandler(handleListActions(g))) diff --git a/go/genkit/reflection_test.go b/go/genkit/reflection_test.go index 75d2278f5a..d47a10a027 100644 --- a/go/genkit/reflection_test.go +++ b/go/genkit/reflection_test.go @@ -90,7 +90,11 @@ func TestServeMux(t *testing.T) { core.DefineAction(g.reg, "test/inc", api.ActionTypeCustom, nil, nil, inc) core.DefineAction(g.reg, "test/dec", api.ActionTypeCustom, nil, nil, dec) - ts := httptest.NewServer(serveMux(g)) + s := &reflectionServer{ + Server: &http.Server{}, + } + ts := httptest.NewServer(serveMux(g, s)) + s.Addr = strings.TrimPrefix(ts.URL, "http://") defer ts.Close() t.Parallel() @@ -104,6 +108,26 @@ func TestServeMux(t *testing.T) { if res.StatusCode != http.StatusOK { t.Errorf("health check failed: got status %d, want %d", res.StatusCode, http.StatusOK) } + + // Test with correct runtime ID + res, err = http.Get(ts.URL + "/api/__health?id=" + s.runtimeID()) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Errorf("health check with correct id failed: got status %d, want %d", res.StatusCode, http.StatusOK) + } + + // Test with incorrect runtime ID + res, err = http.Get(ts.URL + "/api/__health?id=invalid-id") + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusServiceUnavailable { + t.Errorf("health check with incorrect id failed: got status %d, want %d", res.StatusCode, http.StatusServiceUnavailable) + } }) t.Run("list actions", func(t *testing.T) { diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index 659bd40b30..f5f01ee34b 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -82,6 +82,10 @@ export class ReflectionServer { }; } + get runtimeId() { + return `${process.pid}${this.port !== null ? `-${this.port}` : ''}`; + } + /** * Finds a free port to run the server on based on the original chosen port and environment. */ @@ -112,7 +116,11 @@ export class ReflectionServer { next(); }); - server.get('/api/__health', async (_, response) => { + server.get('/api/__health', async (req, response) => { + if (req.query['id'] && req.query['id'] !== this.runtimeId) { + response.status(503).send('Invalid runtime ID'); + return; + } await this.registry.listActions(); response.status(200).send('OK'); }); @@ -322,16 +330,13 @@ export class ReflectionServer { const date = new Date(); const time = date.getTime(); const timestamp = date.toISOString(); - const runtimeId = `${process.pid}${ - this.port !== null ? `-${this.port}` : '' - }`; this.runtimeFilePath = path.join( runtimesDir, - `${runtimeId}-${time}.json` + `${this.runtimeId}-${time}.json` ); const fileContent = JSON.stringify( { - id: process.env.GENKIT_RUNTIME_ID || runtimeId, + id: process.env.GENKIT_RUNTIME_ID || this.runtimeId, pid: process.pid, name: this.options.name, reflectionServerUrl: `http://localhost:${this.port}`, diff --git a/js/plugins/vertexai/src/rerankers/v2/index.ts b/js/plugins/vertexai/src/rerankers/v2/index.ts index bc27f4db25..8597878e6d 100644 --- a/js/plugins/vertexai/src/rerankers/v2/index.ts +++ b/js/plugins/vertexai/src/rerankers/v2/index.ts @@ -24,8 +24,8 @@ import { ActionType } from 'genkit/registry'; import { RerankerReference, z } from 'genkit'; import { getDerivedOptions } from '../../common/utils.js'; import * as reranker from './reranker.js'; -import { VertexRerankerPluginOptions } from './types.js'; -export { VertexRerankerPluginOptions }; +import { type VertexRerankerPluginOptions } from './types.js'; +export { type VertexRerankerPluginOptions }; async function initializer(pluginOptions: VertexRerankerPluginOptions) { const clientOptions = await getDerivedOptions(