Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
Expand Down Expand Up @@ -70,6 +71,11 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) {
return AgentTypeCustom, nil
}

// isStdinPiped checks if stdin is piped (not a terminal)
func isStdinPiped(stat os.FileInfo) bool {
return (stat.Mode() & os.ModeCharDevice) == 0
}

func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) error {
agent := argsToPass[0]
agentTypeValue := viper.GetString(FlagType)
Expand All @@ -88,6 +94,19 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
return xerrors.Errorf("term height must be at least 10")
}

// Read stdin if it's piped, to be used as initial prompt
initialPrompt := viper.GetString(FlagInitialPrompt)
if initialPrompt == "" {
if stat, err := os.Stdin.Stat(); err == nil && isStdinPiped(stat) {
if stdinData, err := io.ReadAll(os.Stdin); err != nil {
return xerrors.Errorf("failed to read stdin: %w", err)
} else if len(stdinData) > 0 {
initialPrompt = string(stdinData)
logger.Info("Read initial prompt from stdin", "bytes", len(stdinData))
}
}
}

printOpenAPI := viper.GetBool(FlagPrintOpenAPI)
var process *termexec.Process
if printOpenAPI {
Expand All @@ -112,7 +131,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
ChatBasePath: viper.GetString(FlagChatBasePath),
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
InitialPrompt: viper.GetString(FlagInitialPrompt),
InitialPrompt: initialPrompt,
})
if err != nil {
return xerrors.Errorf("failed to create server: %w", err)
Expand Down Expand Up @@ -213,7 +232,7 @@ func CreateServerCmd() *cobra.Command {
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
{FlagInitialPrompt, "I", "", "Initial prompt for the agent (recommended only if the agent doesn't support initial prompt in interaction mode)", "string"},
{FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"},
}

for _, spec := range flagSpecs {
Expand Down
44 changes: 44 additions & 0 deletions cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"strings"
"testing"
"time"

"github.com/spf13/cobra"
"github.com/spf13/viper"
Expand Down Expand Up @@ -571,3 +572,46 @@ func TestServerCmd_AllowedOrigins(t *testing.T) {
})
}
}

func TestIsStdinPiped(t *testing.T) {
tests := []struct {
name string
fileInfo os.FileInfo
expected bool
}{
{
name: "regular file (piped)",
fileInfo: &mockFileInfo{mode: 0},
expected: true,
},
{
name: "character device (terminal)",
fileInfo: &mockFileInfo{mode: os.ModeCharDevice},
expected: false,
},
{
name: "named pipe",
fileInfo: &mockFileInfo{mode: os.ModeNamedPipe},
expected: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isStdinPiped(tt.fileInfo)
assert.Equal(t, tt.expected, result)
})
}
}

// mockFileInfo implements os.FileInfo for testing
type mockFileInfo struct {
mode os.FileMode
}

func (m *mockFileInfo) Name() string { return "stdin" }
func (m *mockFileInfo) Size() int64 { return 0 }
func (m *mockFileInfo) Mode() os.FileMode { return m.mode }
func (m *mockFileInfo) ModTime() time.Time { return time.Time{} }
func (m *mockFileInfo) IsDir() bool { return false }
func (m *mockFileInfo) Sys() interface{} { return nil }