Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,33 @@ type App struct {
server *http.Server
}

// New wires the HTTP server: routes, middleware, timeouts.
// New wires the HTTP server with the production usecases. This is the
// constructor production code (cmd/prism-api) uses.
func New(cfg *config.Config, logger *slog.Logger) *App {
return newWithHandlers(cfg, logger, usecase.NewAnalyzer(), usecase.NewPrompter())
}

// newWithHandlers builds the full middleware chain and route table but
// accepts the analyze and prompt usecase dependencies as parameters so
// chain integration tests can substitute fakes without reaching for
// pkg/prism. Production code must go through New.
func newWithHandlers(
cfg *config.Config,
logger *slog.Logger,
analyzeUsecase handler.AnalyzeUsecase,
promptUsecase handler.PromptUsecase,
) *App {
mux := http.NewServeMux()

health := handler.NewHealthHandler()
mux.HandleFunc("GET /healthz", health.Live)
mux.HandleFunc("GET /readyz", health.Ready)
mux.HandleFunc("GET /version", health.Version)

analyzeHandler := handler.NewAnalyzeHandler(usecase.NewAnalyzer())
analyzeHandler := handler.NewAnalyzeHandler(analyzeUsecase)
mux.HandleFunc("POST /v1/analyze", analyzeHandler.Handle)

promptHandler := handler.NewPromptHandler(usecase.NewPrompter())
promptHandler := handler.NewPromptHandler(promptUsecase)
mux.HandleFunc("POST /v1/prompt", promptHandler.Handle)

chain := middleware.Chain(
Expand Down
182 changes: 182 additions & 0 deletions internal/app/app_chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package app

import (
"bytes"
"context"
"encoding/json"
"io"
"log/slog"
Expand All @@ -11,9 +12,57 @@ import (
"testing"
"time"

"github.com/hidetzu/prism/pkg/prism"

"github.com/hidetzu/prism-api/internal/config"
"github.com/hidetzu/prism-api/internal/usecase"
)

// stubAnalyzeUsecase implements handler.AnalyzeUsecase by returning canned
// values. Tests wire it through newWithHandlers so chain integration can
// exercise the real middleware stack without reaching pkg/prism.
//
// When block is non-nil the stub parks on receive, simulating a long-running
// handler so concurrency_limit and timeout tests can orchestrate scenarios.
// When entered is non-nil the stub signals arrival before parking.
type stubAnalyzeUsecase struct {
result prism.Result
err error
block chan struct{}
entered chan struct{}
}

func (s *stubAnalyzeUsecase) Analyze(ctx context.Context, _ usecase.AnalyzeInput) (prism.Result, error) {
if s.entered != nil {
s.entered <- struct{}{}
}
if s.block != nil {
select {
case <-s.block:
case <-ctx.Done():
return prism.Result{}, ctx.Err()
}
}
return s.result, s.err
}

// stubPromptUsecase implements handler.PromptUsecase by returning canned
// values. Mirrors stubAnalyzeUsecase.
type stubPromptUsecase struct {
prompt string
err error
}

func (s *stubPromptUsecase) Prompt(_ context.Context, _ usecase.PromptInput) (string, error) {
return s.prompt, s.err
}

// silentLogger returns a logger whose output goes to io.Discard, for tests
// that do not care about log capture.
func silentLogger() *slog.Logger {
return slog.New(slog.NewJSONHandler(io.Discard, nil))
}

// chainConfig returns a Config suitable for chain integration tests. Callers
// override specific fields to exercise particular defensive middleware.
func chainConfig() *config.Config {
Expand Down Expand Up @@ -175,3 +224,136 @@ func TestChain_RateLimitKeysByXForwardedFor(t *testing.T) {
t.Errorf("B status = %d, want 200 (keyed by XFF)", code)
}
}

func TestChain_AnalyzeEndpointSuccess(t *testing.T) {
stubA := &stubAnalyzeUsecase{
result: prism.Result{
PR: prism.PRInfo{
Provider: "github",
Repository: "owner/repo",
ID: "1",
Title: "Example",
URL: "https://github.com/owner/repo/pull/1",
},
Analysis: prism.AnalysisResult{
ChangeType: "feature",
RiskLevel: "low",
},
},
}
a := newWithHandlers(chainConfig(), silentLogger(), stubA, &stubPromptUsecase{})

body := `{"pull_request_url":"https://github.com/owner/repo/pull/1"}`
req := httptest.NewRequest(http.MethodPost, "/v1/analyze", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := serveRequest(a, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", rec.Code)
}
if rec.Header().Get("X-Request-Id") == "" {
t.Error("X-Request-Id must be set on success")
}

var got struct {
Result struct {
PullRequest struct {
Repository string `json:"repository"`
ID string `json:"id"`
} `json:"pull_request"`
Analysis struct {
ChangeType string `json:"change_type"`
} `json:"analysis"`
} `json:"result"`
}
if err := json.NewDecoder(rec.Body).Decode(&got); err != nil {
t.Fatalf("decode: %v", err)
}
if got.Result.PullRequest.Repository != "owner/repo" {
t.Errorf("pull_request.repository = %q", got.Result.PullRequest.Repository)
}
if got.Result.Analysis.ChangeType != "feature" {
t.Errorf("analysis.change_type = %q", got.Result.Analysis.ChangeType)
}
}

func TestChain_PromptEndpointSuccess(t *testing.T) {
stubP := &stubPromptUsecase{prompt: "Review this PR focusing on error handling."}
a := newWithHandlers(chainConfig(), silentLogger(), &stubAnalyzeUsecase{}, stubP)

body := `{"pull_request_url":"https://github.com/owner/repo/pull/1","mode":"detailed","language":"ja"}`
req := httptest.NewRequest(http.MethodPost, "/v1/prompt", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := serveRequest(a, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want 200", rec.Code)
}
if rec.Header().Get("X-Request-Id") == "" {
t.Error("X-Request-Id must be set on success")
}

var got struct {
Prompt string `json:"prompt"`
}
if err := json.NewDecoder(rec.Body).Decode(&got); err != nil {
t.Fatalf("decode: %v", err)
}
if got.Prompt != "Review this PR focusing on error handling." {
t.Errorf("prompt = %q", got.Prompt)
}
}

func TestChain_ConcurrencyLimitRejectsBeyondCapacity(t *testing.T) {
// Cap the server at one in-flight request so the second request is
// rejected by concurrency_limit without relying on time.Sleep.
cfg := chainConfig()
cfg.MaxConcurrentRequests = 1

entered := make(chan struct{}, 1)
release := make(chan struct{})
stubA := &stubAnalyzeUsecase{
block: release,
entered: entered,
}
a := newWithHandlers(cfg, silentLogger(), stubA, &stubPromptUsecase{})

firstDone := make(chan int, 1)
go func() {
body := `{"pull_request_url":"https://github.com/owner/repo/pull/1"}`
req := httptest.NewRequest(http.MethodPost, "/v1/analyze", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
firstDone <- serveRequest(a, req).Code
}()

// Wait until the first request is actually inside the stub — by this
// point the concurrency_limit middleware has already acquired the only
// slot on its behalf.
<-entered

body := `{"pull_request_url":"https://github.com/owner/repo/pull/2"}`
req := httptest.NewRequest(http.MethodPost, "/v1/analyze", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := serveRequest(a, req)

if rec.Code != http.StatusServiceUnavailable {
t.Errorf("second request status = %d, want 503", rec.Code)
}
var errBody struct {
Error struct {
Code string `json:"code"`
} `json:"error"`
}
if err := json.NewDecoder(rec.Body).Decode(&errBody); err != nil {
t.Fatalf("decode error body: %v", err)
}
if errBody.Error.Code != "service_unavailable" {
t.Errorf("error.code = %q, want service_unavailable", errBody.Error.Code)
}

// Release the first request; it should complete normally.
close(release)
if code := <-firstDone; code != http.StatusOK {
t.Errorf("first request status = %d, want 200", code)
}
}
96 changes: 96 additions & 0 deletions internal/httpapi/middleware/clientip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package middleware

import (
"net/http"
"net/http/httptest"
"testing"
)

// TestClientIP covers every extraction path the helper supports, with
// particular attention to IPv6 edge cases that were flagged as untested
// in the v0.2.0 release review.
func TestClientIP(t *testing.T) {
cases := []struct {
name string
xff string // X-Forwarded-For header; empty means header not set
remoteAddr string
want string
}{
// X-Forwarded-For path (preferred when present)
{
name: "xff single ipv4",
xff: "203.0.113.5",
remoteAddr: "192.0.2.1:80",
want: "203.0.113.5",
},
{
name: "xff single ipv6",
xff: "2001:db8::1",
remoteAddr: "192.0.2.1:80",
want: "2001:db8::1",
},
{
name: "xff multi ipv4",
xff: "203.0.113.5, 10.0.0.1",
remoteAddr: "192.0.2.1:80",
want: "203.0.113.5",
},
{
name: "xff multi ipv6",
xff: "2001:db8::1, 10.0.0.1",
remoteAddr: "192.0.2.1:80",
want: "2001:db8::1",
},
{
name: "xff leading and trailing space",
xff: " 203.0.113.5 ",
remoteAddr: "192.0.2.1:80",
want: "203.0.113.5",
},
{
name: "xff no space after comma",
xff: "203.0.113.5,10.0.0.1",
remoteAddr: "192.0.2.1:80",
want: "203.0.113.5",
},

// RemoteAddr fallback path (no XFF header)
{
name: "remoteaddr ipv4 with port",
remoteAddr: "203.0.113.5:1234",
want: "203.0.113.5",
},
{
name: "remoteaddr ipv6 bracketed with port",
remoteAddr: "[2001:db8::1]:1234",
want: "2001:db8::1",
},
{
name: "remoteaddr bare ipv4 without port falls through",
remoteAddr: "203.0.113.5",
want: "203.0.113.5",
},
{
name: "remoteaddr bare ipv6 without port falls through",
remoteAddr: "2001:db8::1",
want: "2001:db8::1",
},
{
name: "remoteaddr empty",
remoteAddr: "",
want: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.xff != "" {
req.Header.Set("X-Forwarded-For", tc.xff)
}
req.RemoteAddr = tc.remoteAddr
if got := clientIP(req); got != tc.want {
t.Errorf("clientIP() = %q, want %q", got, tc.want)
}
})
}
}
9 changes: 8 additions & 1 deletion internal/httpapi/middleware/requestid.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"encoding/base32"
"io"
"net/http"
)

Expand Down Expand Up @@ -42,11 +43,17 @@ func RequestIDFrom(ctx context.Context) string {

var requestIDEncoding = base32.StdEncoding.WithPadding(base32.NoPadding)

// randSource is the source of randomness used by newRequestID. Production
// code uses crypto/rand.Reader; tests may replace it via a defer-restored
// assignment to exercise the read-failure fallback path. rand.Reader's
// declared type is io.Reader, so type inference is sufficient.
var randSource = rand.Reader

// newRequestID generates a 26-character random identifier.
// 16 random bytes encoded as unpadded base32 yields 26 ASCII characters.
func newRequestID() string {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
if _, err := io.ReadFull(randSource, b[:]); err != nil {
// Extremely unlikely in practice; return a fixed sentinel so the
// request still has a (non-empty) identifier.
return "00000000000000000000000000"
Expand Down
Loading
Loading