diff --git a/internal/app/app.go b/internal/app/app.go index e12b1c1..6f5ed1a 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -24,8 +24,22 @@ type App struct { server *http.Server } -// New wires the HTTP server: routes, middleware, timeouts. +// New wires the HTTP server with the production usecases. This is the +// constructor production code (cmd/prism-api) uses. func New(cfg *config.Config, logger *slog.Logger) *App { + return newWithHandlers(cfg, logger, usecase.NewAnalyzer(), usecase.NewPrompter()) +} + +// newWithHandlers builds the full middleware chain and route table but +// accepts the analyze and prompt usecase dependencies as parameters so +// chain integration tests can substitute fakes without reaching for +// pkg/prism. Production code must go through New. +func newWithHandlers( + cfg *config.Config, + logger *slog.Logger, + analyzeUsecase handler.AnalyzeUsecase, + promptUsecase handler.PromptUsecase, +) *App { mux := http.NewServeMux() health := handler.NewHealthHandler() @@ -33,10 +47,10 @@ func New(cfg *config.Config, logger *slog.Logger) *App { mux.HandleFunc("GET /readyz", health.Ready) mux.HandleFunc("GET /version", health.Version) - analyzeHandler := handler.NewAnalyzeHandler(usecase.NewAnalyzer()) + analyzeHandler := handler.NewAnalyzeHandler(analyzeUsecase) mux.HandleFunc("POST /v1/analyze", analyzeHandler.Handle) - promptHandler := handler.NewPromptHandler(usecase.NewPrompter()) + promptHandler := handler.NewPromptHandler(promptUsecase) mux.HandleFunc("POST /v1/prompt", promptHandler.Handle) chain := middleware.Chain( diff --git a/internal/app/app_chain_test.go b/internal/app/app_chain_test.go index 2947101..d14d5fb 100644 --- a/internal/app/app_chain_test.go +++ b/internal/app/app_chain_test.go @@ -2,6 +2,7 @@ package app import ( "bytes" + "context" "encoding/json" "io" "log/slog" @@ -11,9 +12,57 @@ import ( "testing" "time" + "github.com/hidetzu/prism/pkg/prism" + "github.com/hidetzu/prism-api/internal/config" + "github.com/hidetzu/prism-api/internal/usecase" ) +// stubAnalyzeUsecase implements handler.AnalyzeUsecase by returning canned +// values. Tests wire it through newWithHandlers so chain integration can +// exercise the real middleware stack without reaching pkg/prism. +// +// When block is non-nil the stub parks on receive, simulating a long-running +// handler so concurrency_limit and timeout tests can orchestrate scenarios. +// When entered is non-nil the stub signals arrival before parking. +type stubAnalyzeUsecase struct { + result prism.Result + err error + block chan struct{} + entered chan struct{} +} + +func (s *stubAnalyzeUsecase) Analyze(ctx context.Context, _ usecase.AnalyzeInput) (prism.Result, error) { + if s.entered != nil { + s.entered <- struct{}{} + } + if s.block != nil { + select { + case <-s.block: + case <-ctx.Done(): + return prism.Result{}, ctx.Err() + } + } + return s.result, s.err +} + +// stubPromptUsecase implements handler.PromptUsecase by returning canned +// values. Mirrors stubAnalyzeUsecase. +type stubPromptUsecase struct { + prompt string + err error +} + +func (s *stubPromptUsecase) Prompt(_ context.Context, _ usecase.PromptInput) (string, error) { + return s.prompt, s.err +} + +// silentLogger returns a logger whose output goes to io.Discard, for tests +// that do not care about log capture. +func silentLogger() *slog.Logger { + return slog.New(slog.NewJSONHandler(io.Discard, nil)) +} + // chainConfig returns a Config suitable for chain integration tests. Callers // override specific fields to exercise particular defensive middleware. func chainConfig() *config.Config { @@ -175,3 +224,136 @@ func TestChain_RateLimitKeysByXForwardedFor(t *testing.T) { t.Errorf("B status = %d, want 200 (keyed by XFF)", code) } } + +func TestChain_AnalyzeEndpointSuccess(t *testing.T) { + stubA := &stubAnalyzeUsecase{ + result: prism.Result{ + PR: prism.PRInfo{ + Provider: "github", + Repository: "owner/repo", + ID: "1", + Title: "Example", + URL: "https://github.com/owner/repo/pull/1", + }, + Analysis: prism.AnalysisResult{ + ChangeType: "feature", + RiskLevel: "low", + }, + }, + } + a := newWithHandlers(chainConfig(), silentLogger(), stubA, &stubPromptUsecase{}) + + body := `{"pull_request_url":"https://github.com/owner/repo/pull/1"}` + req := httptest.NewRequest(http.MethodPost, "/v1/analyze", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := serveRequest(a, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if rec.Header().Get("X-Request-Id") == "" { + t.Error("X-Request-Id must be set on success") + } + + var got struct { + Result struct { + PullRequest struct { + Repository string `json:"repository"` + ID string `json:"id"` + } `json:"pull_request"` + Analysis struct { + ChangeType string `json:"change_type"` + } `json:"analysis"` + } `json:"result"` + } + if err := json.NewDecoder(rec.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Result.PullRequest.Repository != "owner/repo" { + t.Errorf("pull_request.repository = %q", got.Result.PullRequest.Repository) + } + if got.Result.Analysis.ChangeType != "feature" { + t.Errorf("analysis.change_type = %q", got.Result.Analysis.ChangeType) + } +} + +func TestChain_PromptEndpointSuccess(t *testing.T) { + stubP := &stubPromptUsecase{prompt: "Review this PR focusing on error handling."} + a := newWithHandlers(chainConfig(), silentLogger(), &stubAnalyzeUsecase{}, stubP) + + body := `{"pull_request_url":"https://github.com/owner/repo/pull/1","mode":"detailed","language":"ja"}` + req := httptest.NewRequest(http.MethodPost, "/v1/prompt", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := serveRequest(a, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + if rec.Header().Get("X-Request-Id") == "" { + t.Error("X-Request-Id must be set on success") + } + + var got struct { + Prompt string `json:"prompt"` + } + if err := json.NewDecoder(rec.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got.Prompt != "Review this PR focusing on error handling." { + t.Errorf("prompt = %q", got.Prompt) + } +} + +func TestChain_ConcurrencyLimitRejectsBeyondCapacity(t *testing.T) { + // Cap the server at one in-flight request so the second request is + // rejected by concurrency_limit without relying on time.Sleep. + cfg := chainConfig() + cfg.MaxConcurrentRequests = 1 + + entered := make(chan struct{}, 1) + release := make(chan struct{}) + stubA := &stubAnalyzeUsecase{ + block: release, + entered: entered, + } + a := newWithHandlers(cfg, silentLogger(), stubA, &stubPromptUsecase{}) + + firstDone := make(chan int, 1) + go func() { + body := `{"pull_request_url":"https://github.com/owner/repo/pull/1"}` + req := httptest.NewRequest(http.MethodPost, "/v1/analyze", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + firstDone <- serveRequest(a, req).Code + }() + + // Wait until the first request is actually inside the stub — by this + // point the concurrency_limit middleware has already acquired the only + // slot on its behalf. + <-entered + + body := `{"pull_request_url":"https://github.com/owner/repo/pull/2"}` + req := httptest.NewRequest(http.MethodPost, "/v1/analyze", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := serveRequest(a, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Errorf("second request status = %d, want 503", rec.Code) + } + var errBody struct { + Error struct { + Code string `json:"code"` + } `json:"error"` + } + if err := json.NewDecoder(rec.Body).Decode(&errBody); err != nil { + t.Fatalf("decode error body: %v", err) + } + if errBody.Error.Code != "service_unavailable" { + t.Errorf("error.code = %q, want service_unavailable", errBody.Error.Code) + } + + // Release the first request; it should complete normally. + close(release) + if code := <-firstDone; code != http.StatusOK { + t.Errorf("first request status = %d, want 200", code) + } +} diff --git a/internal/httpapi/middleware/clientip_test.go b/internal/httpapi/middleware/clientip_test.go new file mode 100644 index 0000000..c5426c6 --- /dev/null +++ b/internal/httpapi/middleware/clientip_test.go @@ -0,0 +1,96 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// TestClientIP covers every extraction path the helper supports, with +// particular attention to IPv6 edge cases that were flagged as untested +// in the v0.2.0 release review. +func TestClientIP(t *testing.T) { + cases := []struct { + name string + xff string // X-Forwarded-For header; empty means header not set + remoteAddr string + want string + }{ + // X-Forwarded-For path (preferred when present) + { + name: "xff single ipv4", + xff: "203.0.113.5", + remoteAddr: "192.0.2.1:80", + want: "203.0.113.5", + }, + { + name: "xff single ipv6", + xff: "2001:db8::1", + remoteAddr: "192.0.2.1:80", + want: "2001:db8::1", + }, + { + name: "xff multi ipv4", + xff: "203.0.113.5, 10.0.0.1", + remoteAddr: "192.0.2.1:80", + want: "203.0.113.5", + }, + { + name: "xff multi ipv6", + xff: "2001:db8::1, 10.0.0.1", + remoteAddr: "192.0.2.1:80", + want: "2001:db8::1", + }, + { + name: "xff leading and trailing space", + xff: " 203.0.113.5 ", + remoteAddr: "192.0.2.1:80", + want: "203.0.113.5", + }, + { + name: "xff no space after comma", + xff: "203.0.113.5,10.0.0.1", + remoteAddr: "192.0.2.1:80", + want: "203.0.113.5", + }, + + // RemoteAddr fallback path (no XFF header) + { + name: "remoteaddr ipv4 with port", + remoteAddr: "203.0.113.5:1234", + want: "203.0.113.5", + }, + { + name: "remoteaddr ipv6 bracketed with port", + remoteAddr: "[2001:db8::1]:1234", + want: "2001:db8::1", + }, + { + name: "remoteaddr bare ipv4 without port falls through", + remoteAddr: "203.0.113.5", + want: "203.0.113.5", + }, + { + name: "remoteaddr bare ipv6 without port falls through", + remoteAddr: "2001:db8::1", + want: "2001:db8::1", + }, + { + name: "remoteaddr empty", + remoteAddr: "", + want: "", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.xff != "" { + req.Header.Set("X-Forwarded-For", tc.xff) + } + req.RemoteAddr = tc.remoteAddr + if got := clientIP(req); got != tc.want { + t.Errorf("clientIP() = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/internal/httpapi/middleware/requestid.go b/internal/httpapi/middleware/requestid.go index 780e7ce..7c33c24 100644 --- a/internal/httpapi/middleware/requestid.go +++ b/internal/httpapi/middleware/requestid.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/base32" + "io" "net/http" ) @@ -42,11 +43,17 @@ func RequestIDFrom(ctx context.Context) string { var requestIDEncoding = base32.StdEncoding.WithPadding(base32.NoPadding) +// randSource is the source of randomness used by newRequestID. Production +// code uses crypto/rand.Reader; tests may replace it via a defer-restored +// assignment to exercise the read-failure fallback path. rand.Reader's +// declared type is io.Reader, so type inference is sufficient. +var randSource = rand.Reader + // newRequestID generates a 26-character random identifier. // 16 random bytes encoded as unpadded base32 yields 26 ASCII characters. func newRequestID() string { var b [16]byte - if _, err := rand.Read(b[:]); err != nil { + if _, err := io.ReadFull(randSource, b[:]); err != nil { // Extremely unlikely in practice; return a fixed sentinel so the // request still has a (non-empty) identifier. return "00000000000000000000000000" diff --git a/internal/httpapi/middleware/requestid_test.go b/internal/httpapi/middleware/requestid_test.go index 4f069c4..6a39b7c 100644 --- a/internal/httpapi/middleware/requestid_test.go +++ b/internal/httpapi/middleware/requestid_test.go @@ -2,11 +2,20 @@ package middleware import ( "context" + "errors" "net/http" "net/http/httptest" "testing" ) +// errorReader always fails Read. Used to exercise the randSource failure +// path in newRequestID without actually exhausting /dev/urandom. +type errorReader struct{} + +func (errorReader) Read(_ []byte) (int, error) { + return 0, errors.New("simulated rand failure") +} + func TestRequestID_SetsHeaderAndContext(t *testing.T) { var contextID string h := RequestID()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -53,3 +62,19 @@ func TestRequestIDFrom_EmptyContext(t *testing.T) { t.Errorf("RequestIDFrom(empty) = %q, want empty", id) } } + +func TestNewRequestID_FallsBackToSentinelOnReadFailure(t *testing.T) { + // Swap the package-level randSource for one that always errors so the + // fallback path in newRequestID runs. t.Cleanup restores the original + // so other tests in the package are not affected. + orig := randSource + randSource = errorReader{} + t.Cleanup(func() { randSource = orig }) + + const want = "00000000000000000000000000" + for i := 0; i < 3; i++ { + if got := newRequestID(); got != want { + t.Errorf("newRequestID() = %q, want sentinel %q", got, want) + } + } +} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go new file mode 100644 index 0000000..d23b67e --- /dev/null +++ b/internal/logging/logging_test.go @@ -0,0 +1,77 @@ +package logging + +import ( + "context" + "log/slog" + "testing" + + "github.com/hidetzu/prism-api/internal/config" +) + +func TestParseLevel(t *testing.T) { + cases := []struct { + name string + input string + want slog.Level + }{ + {"debug", "debug", slog.LevelDebug}, + {"debug uppercase", "DEBUG", slog.LevelDebug}, + {"info", "info", slog.LevelInfo}, + {"info mixed case", "Info", slog.LevelInfo}, + {"warn", "warn", slog.LevelWarn}, + {"warning alias", "warning", slog.LevelWarn}, + {"warning uppercase", "WARNING", slog.LevelWarn}, + {"error", "error", slog.LevelError}, + // Unknown input returns LevelInfo as defense in depth. In + // production, config.Validate rejects unknown log levels before + // this function is called, so this branch is unreachable via the + // normal startup path. Keep the test anyway so a future refactor + // that bypasses Load() still gets a sane default instead of a + // panic or a bogus level. + {"unknown falls back to info", "bogus", slog.LevelInfo}, + {"empty string falls back to info", "", slog.LevelInfo}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := parseLevel(tc.input); got != tc.want { + t.Errorf("parseLevel(%q) = %v, want %v", tc.input, got, tc.want) + } + }) + } +} + +func TestNew_RespectsConfiguredLevel(t *testing.T) { + cases := []struct { + level string + wantDebug bool + wantInfo bool + wantWarn bool + wantError bool + }{ + {"debug", true, true, true, true}, + {"info", false, true, true, true}, + {"warn", false, false, true, true}, + {"error", false, false, false, true}, + } + for _, tc := range cases { + t.Run(tc.level, func(t *testing.T) { + logger := New(&config.Config{LogLevel: tc.level}) + if logger == nil { + t.Fatal("New() returned nil") + } + ctx := context.Background() + if got := logger.Enabled(ctx, slog.LevelDebug); got != tc.wantDebug { + t.Errorf("Enabled(Debug) = %v, want %v", got, tc.wantDebug) + } + if got := logger.Enabled(ctx, slog.LevelInfo); got != tc.wantInfo { + t.Errorf("Enabled(Info) = %v, want %v", got, tc.wantInfo) + } + if got := logger.Enabled(ctx, slog.LevelWarn); got != tc.wantWarn { + t.Errorf("Enabled(Warn) = %v, want %v", got, tc.wantWarn) + } + if got := logger.Enabled(ctx, slog.LevelError); got != tc.wantError { + t.Errorf("Enabled(Error) = %v, want %v", got, tc.wantError) + } + }) + } +}