diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..4740336 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI + +on: + pull_request: + branches: [ master ] + push: + branches: [ master ] + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.20" + + - name: Build + run: go build ./... + + - name: Test (race) + run: go test ./core/pool/ ./core/baseline/ ./core/ihttp/ ./pkg/ -v -race -count=1 -timeout=120s + + - name: Vet + run: go vet ./core/... ./pkg/... ./cmd/... + continue-on-error: true diff --git a/core/baseline/baseline.go b/core/baseline/baseline.go index eab3710..3b66d1e 100644 --- a/core/baseline/baseline.go +++ b/core/baseline/baseline.go @@ -3,6 +3,7 @@ package baseline import ( "bytes" "github.com/chainreactors/fingers/common" + "github.com/chainreactors/logs" "github.com/chainreactors/parsers" "github.com/chainreactors/spray/core/ihttp" "github.com/chainreactors/spray/pkg" @@ -53,15 +54,14 @@ func NewBaseline(u, host string, resp *ihttp.Response) *Baseline { bl.Raw = append(bl.Header, bl.Body...) bl.Response, err = pkg.ParseRawResponse(bl.Raw) if err != nil { - bl.IsValid = false - bl.Reason = pkg.ErrResponseError.Error() - bl.ErrString = err.Error() - return bl + // raw 重解析失败不影响 baseline 有效性,live response 已提供所有需要的数据 + logs.Log.Debugf("ParseRawResponse failed for %s: %s", u, err.Error()) } - if r := bl.Response.Header.Get("Location"); r != "" { + // 始终从 live response 读取 Location + if r := resp.GetHeader("Location"); r != "" { bl.RedirectURL = r } else { - bl.RedirectURL = bl.Response.Header.Get("location") + bl.RedirectURL = resp.GetHeader("location") } bl.Dir = bl.IsDir() diff --git a/core/pool/brutepool.go b/core/pool/brutepool.go index 5c1d955..552afd8 100644 --- a/core/pool/brutepool.go +++ b/core/pool/brutepool.go @@ -52,10 +52,11 @@ func NewBrutePool(ctx context.Context, config *Config) (*BrutePool, error) { Timeout: config.Timeout, ProxyClient: config.ProxyClient, }), - additionCh: make(chan *Unit, config.Thread*10), - closeCh: make(chan struct{}), - processCh: make(chan *baseline.Baseline, config.Thread*2), - wg: &sync.WaitGroup{}, + additionCh: make(chan *Unit, config.Thread*10), + closeCh: make(chan struct{}), + processCh: make(chan *baseline.Baseline, config.Thread*2), + wg: &sync.WaitGroup{}, + handlerDone: make(chan struct{}), }, base: u.Scheme + "://" + u.Host, isDir: strings.HasSuffix(u.Path, "/"), @@ -105,7 +106,6 @@ type BrutePool struct { urls sync.Map scopeurls map[string]struct{} uniques map[uint16]struct{} - analyzeDone bool limiter *rate.Limiter locker sync.Mutex scopeLocker sync.Mutex @@ -197,6 +197,11 @@ func (pool *BrutePool) Run(offset, limit int) { close(pool.closeCh) return } + select { + case <-pool.ctx.Done(): + return + default: + } time.Sleep(100 * time.Millisecond) } }() @@ -376,7 +381,7 @@ func (pool *BrutePool) Invoke(v interface{}) { case parsers.WordSource: // 异步进行性能消耗较大的深度对比 - pool.processCh <- bl + pool.sendProcess(bl) if int(pool.Statistor.ReqTotal)%pool.CheckPeriod == 0 { // 间歇插入check waf的探针 pool.doCheck() @@ -388,9 +393,9 @@ func (pool *BrutePool) Invoke(v interface{}) { pool.Bar.Done() case parsers.RedirectSource: bl.FrontURL = unit.frontUrl - pool.processCh <- bl + pool.sendProcess(bl) default: - pool.processCh <- bl + pool.sendProcess(bl) } } @@ -432,6 +437,7 @@ func (pool *BrutePool) NoScopeInvoke(v interface{}) { } func (pool *BrutePool) Handler() { + defer close(pool.handlerDone) for bl := range pool.processCh { if bl.IsValid { pool.addFuzzyBaseline(bl) @@ -516,8 +522,6 @@ func (pool *BrutePool) Handler() { } pool.wg.Done() } - - pool.analyzeDone = true } func (pool *BrutePool) checkRedirect(redirectURL string) bool { @@ -674,15 +678,12 @@ func (pool *BrutePool) fallback() { } func (pool *BrutePool) Close() { - for pool.analyzeDone { - // 等待缓存的待处理任务完成 - time.Sleep(time.Duration(100) * time.Millisecond) - } - close(pool.additionCh) // 关闭addition管道 - //close(pool.checkCh) // 关闭check管道 - pool.Statistor.EndTime = time.Now().Unix() + pool.Cancel() pool.reqPool.Release() pool.scopePool.Release() + close(pool.processCh) + <-pool.handlerDone + pool.Statistor.EndTime = time.Now().Unix() } func (pool *BrutePool) safePath(u string) string { @@ -713,9 +714,15 @@ func (pool *BrutePool) doCheck() { } if pool.Mod == HostSpray { - pool.checkCh <- struct{}{} + select { + case pool.checkCh <- struct{}{}: + case <-pool.ctx.Done(): + } } else if pool.Mod == PathSpray { - pool.checkCh <- struct{}{} + select { + case pool.checkCh <- struct{}{}: + case <-pool.ctx.Done(): + } } } @@ -755,6 +762,7 @@ func (pool *BrutePool) doCrawl(bl *baseline.Baseline) { pool.doScopeCrawl(bl) go func() { + defer pool.wg.Done() for _, u := range bl.URLs { if u = pkg.FormatURL(bl.Url.Path, u); u == "" { continue diff --git a/core/pool/checkpool.go b/core/pool/checkpool.go index 8e7ee67..400039a 100644 --- a/core/pool/checkpool.go +++ b/core/pool/checkpool.go @@ -34,6 +34,7 @@ func NewCheckPool(ctx context.Context, config *Config) (*CheckPool, error) { additionCh: make(chan *Unit, config.Thread*10), closeCh: make(chan struct{}), processCh: make(chan *baseline.Baseline, config.Thread*2), + handlerDone: make(chan struct{}), }, } pool.Request.Headers.Set("Connection", "close") @@ -105,6 +106,9 @@ Loop: pool.Close() } func (pool *CheckPool) Close() { + pool.Cancel() + close(pool.processCh) + <-pool.handlerDone pool.Bar.Close() pool.Pool.Release() } @@ -139,7 +143,7 @@ func (pool *CheckPool) Invoke(v interface{}) { ReqDepth: unit.depth, }, } - pool.processCh <- bl + pool.sendProcess(bl) return } start := time.Now() @@ -172,10 +176,11 @@ func (pool *CheckPool) Invoke(v interface{}) { if bl.RedirectURL != "" { pool.doRedirect(bl, bl.ReqDepth) } - pool.processCh <- bl + pool.sendProcess(bl) } func (pool *CheckPool) Handler() { + defer close(pool.handlerDone) for bl := range pool.processCh { if bl.IsValid { params := map[string]interface{}{ diff --git a/core/pool/pool.go b/core/pool/pool.go index baab583..3f73bbb 100644 --- a/core/pool/pool.go +++ b/core/pool/pool.go @@ -26,6 +26,7 @@ type BasePool struct { additionCh chan *Unit closeCh chan struct{} wg *sync.WaitGroup + handlerDone chan struct{} isFallback atomic.Bool } @@ -44,18 +45,21 @@ func (pool *BasePool) doRetry(bl *baseline.Baseline) { } func (pool *BasePool) addAddition(u *Unit) { + if pool.ctx.Err() != nil { + return + } pool.wg.Add(1) select { case pool.additionCh <- u: - default: - // 强行屏蔽报错, 防止goroutine泄露 - go func() { - select { - case pool.additionCh <- u: - case <-pool.ctx.Done(): - pool.wg.Done() - } - }() + case <-pool.ctx.Done(): + pool.wg.Done() + } +} + +func (pool *BasePool) sendProcess(bl *baseline.Baseline) { + select { + case pool.processCh <- bl: + case <-pool.ctx.Done(): } } @@ -64,11 +68,19 @@ func (pool *BasePool) putToOutput(bl *baseline.Baseline) { bl.Collect() } pool.Outwg.Add(1) - pool.OutputCh <- bl + select { + case pool.OutputCh <- bl: + case <-pool.ctx.Done(): + pool.Outwg.Done() + } } func (pool *BasePool) putToFuzzy(bl *baseline.Baseline) { pool.Outwg.Add(1) bl.IsFuzzy = true - pool.FuzzyCh <- bl + select { + case pool.FuzzyCh <- bl: + case <-pool.ctx.Done(): + pool.Outwg.Done() + } } diff --git a/core/pool/pool_test.go b/core/pool/pool_test.go new file mode 100644 index 0000000..fa69142 --- /dev/null +++ b/core/pool/pool_test.go @@ -0,0 +1,526 @@ +package pool + +import ( + "context" + "github.com/chainreactors/parsers" + "github.com/chainreactors/spray/core/baseline" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func newTestBasePool(ctx context.Context, cancel context.CancelFunc) *BasePool { + return &BasePool{ + Config: &Config{ + Thread: 4, + OutputCh: make(chan *baseline.Baseline, 100), + FuzzyCh: make(chan *baseline.Baseline, 100), + Outwg: &sync.WaitGroup{}, + }, + ctx: ctx, + Cancel: cancel, + additionCh: make(chan *Unit, 40), // Thread*10 + closeCh: make(chan struct{}), + processCh: make(chan *baseline.Baseline, 8), // Thread*2 + wg: &sync.WaitGroup{}, + handlerDone: make(chan struct{}), + } +} + +func newTestBaseline() *baseline.Baseline { + return &baseline.Baseline{ + SprayResult: &parsers.SprayResult{ + UrlString: "http://example.com/test", + IsValid: true, + }, + } +} + +// mustFinish fails the test if fn does not return within d. +func mustFinish(t *testing.T, d time.Duration, msg string, fn func()) { + t.Helper() + done := make(chan struct{}) + go func() { fn(); close(done) }() + select { + case <-done: + case <-time.After(d): + t.Fatalf("timeout: %s", msg) + } +} + +// --------------------------------------------------------------------------- +// addAddition +// --------------------------------------------------------------------------- + +func TestAddAddition_Normal(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pool := newTestBasePool(ctx, cancel) + + pool.addAddition(&Unit{path: "/a", source: parsers.WordSource}) + + select { + case u := <-pool.additionCh: + if u.path != "/a" { + t.Fatalf("path = %q, want /a", u.path) + } + case <-time.After(time.Second): + t.Fatal("addAddition: channel receive timed out") + } + pool.wg.Done() // balance the Add(1) + pool.wg.Wait() // must not hang +} + +func TestAddAddition_AfterCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + cancel() + + pool.addAddition(&Unit{path: "/a", source: parsers.WordSource}) + + select { + case <-pool.additionCh: + t.Fatal("should not send after cancel") + default: + } + // wg must be zero — the cancelled path must not leak a counter + pool.wg.Wait() +} + +// Regression: old code had a `default` branch that spawned an async goroutine. +// If the goroutine sent successfully, wg.Done() was never called — wg leak. +// This test verifies wg stays balanced after many addAddition calls. +func TestAddAddition_WgBalance(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pool := newTestBasePool(ctx, cancel) + + const N = 100 + // drain concurrently so addAddition never blocks on a full buffer + var received int32 + go func() { + for range pool.additionCh { + atomic.AddInt32(&received, 1) + pool.wg.Done() + } + }() + + for i := 0; i < N; i++ { + pool.addAddition(&Unit{path: "/x", source: parsers.WordSource}) + } + + mustFinish(t, 2*time.Second, "wg.Wait hung — wg counter leaked", func() { + pool.wg.Wait() + }) + close(pool.additionCh) // stop drain goroutine + if r := atomic.LoadInt32(&received); r != N { + t.Fatalf("received %d items, want %d", r, N) + } +} + +func TestAddAddition_FullBufferUnblocksOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + // fill buffer + for i := 0; i < cap(pool.additionCh); i++ { + pool.additionCh <- &Unit{path: "/fill"} + } + + mustFinish(t, 2*time.Second, "addAddition on full buffer not unblocked by cancel", func() { + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + pool.addAddition(&Unit{path: "/blocked", source: parsers.WordSource}) + }) +} + +func TestAddAddition_ConcurrentShutdown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + var ext sync.WaitGroup + for i := 0; i < 200; i++ { + ext.Add(1) + go func() { + defer ext.Done() + pool.addAddition(&Unit{path: "/c", source: parsers.WordSource}) + }() + } + + time.Sleep(2 * time.Millisecond) + cancel() + + // drain so senders can unblock + go func() { + for range pool.additionCh { + pool.wg.Done() + } + }() + + mustFinish(t, 5*time.Second, "concurrent addAddition+cancel hung", func() { + ext.Wait() + }) + close(pool.additionCh) +} + +// --------------------------------------------------------------------------- +// sendProcess +// --------------------------------------------------------------------------- + +func TestSendProcess_Normal(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pool := newTestBasePool(ctx, cancel) + + bl := newTestBaseline() + pool.sendProcess(bl) + + select { + case got := <-pool.processCh: + if got.UrlString != bl.UrlString { + t.Fatalf("got %q, want %q", got.UrlString, bl.UrlString) + } + case <-time.After(time.Second): + t.Fatal("sendProcess: channel receive timed out") + } +} + +func TestSendProcess_AfterCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + cancel() + + mustFinish(t, 2*time.Second, "sendProcess blocked after cancel", func() { + pool.sendProcess(newTestBaseline()) + }) +} + +func TestSendProcess_FullBufferUnblocksOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + for i := 0; i < cap(pool.processCh); i++ { + pool.processCh <- newTestBaseline() + } + + mustFinish(t, 2*time.Second, "sendProcess on full buffer not unblocked by cancel", func() { + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + pool.sendProcess(newTestBaseline()) + }) +} + +func TestSendProcess_ConcurrentShutdown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + var ext sync.WaitGroup + for i := 0; i < 200; i++ { + ext.Add(1) + go func() { + defer ext.Done() + pool.sendProcess(newTestBaseline()) + }() + } + + time.Sleep(2 * time.Millisecond) + cancel() + + go func() { + for range pool.processCh { + } + }() + + mustFinish(t, 5*time.Second, "concurrent sendProcess+cancel hung", func() { + ext.Wait() + }) + close(pool.processCh) +} + +// --------------------------------------------------------------------------- +// putToOutput +// --------------------------------------------------------------------------- + +func TestPutToOutput_Normal(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pool := newTestBasePool(ctx, cancel) + + pool.putToOutput(newTestBaseline()) + + select { + case <-pool.OutputCh: + case <-time.After(time.Second): + t.Fatal("putToOutput: receive timed out") + } + pool.Outwg.Done() + pool.Outwg.Wait() +} + +func TestPutToOutput_AfterCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + cancel() + + mustFinish(t, 2*time.Second, "putToOutput blocked after cancel", func() { + pool.putToOutput(newTestBaseline()) + }) + // select may have picked OutputCh (buffer available) or ctx.Done(); + // drain if sent so Outwg stays balanced. + select { + case <-pool.OutputCh: + pool.Outwg.Done() + default: + } + pool.Outwg.Wait() +} + +func TestPutToOutput_OutwgBalance(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + // fill OutputCh so next send will block + for i := 0; i < cap(pool.OutputCh); i++ { + pool.OutputCh <- newTestBaseline() + pool.Outwg.Add(1) // manual balance for the fills + } + + mustFinish(t, 2*time.Second, "putToOutput on full ch not unblocked by cancel", func() { + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + pool.putToOutput(newTestBaseline()) + }) + + // drain and done for fills + for i := 0; i < cap(pool.OutputCh); i++ { + <-pool.OutputCh + pool.Outwg.Done() + } + pool.Outwg.Wait() +} + +// --------------------------------------------------------------------------- +// putToFuzzy +// --------------------------------------------------------------------------- + +func TestPutToFuzzy_Normal(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pool := newTestBasePool(ctx, cancel) + + pool.putToFuzzy(newTestBaseline()) + + select { + case bl := <-pool.FuzzyCh: + if !bl.IsFuzzy { + t.Fatal("IsFuzzy should be true") + } + case <-time.After(time.Second): + t.Fatal("putToFuzzy: receive timed out") + } + pool.Outwg.Done() + pool.Outwg.Wait() +} + +func TestPutToFuzzy_AfterCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + cancel() + + mustFinish(t, 2*time.Second, "putToFuzzy blocked after cancel", func() { + pool.putToFuzzy(newTestBaseline()) + }) + select { + case <-pool.FuzzyCh: + pool.Outwg.Done() + default: + } + pool.Outwg.Wait() +} + +// --------------------------------------------------------------------------- +// Handler lifecycle — close(processCh) must exit for-range +// --------------------------------------------------------------------------- + +// Regression: processCh was never closed → Handler goroutine leaked forever. +func TestHandlerDone_SignaledOnProcessChClose(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pool := newTestBasePool(ctx, cancel) + + go func() { + defer close(pool.handlerDone) + for range pool.processCh { + } + }() + + close(pool.processCh) + + select { + case <-pool.handlerDone: + case <-time.After(2 * time.Second): + t.Fatal("handlerDone not signaled after processCh closed") + } +} + +func TestHandlerDone_ProcessesAllBeforeExit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pool := newTestBasePool(ctx, cancel) + + var count int + go func() { + defer close(pool.handlerDone) + for range pool.processCh { + count++ + } + }() + + const N = 20 + for i := 0; i < N; i++ { + pool.processCh <- newTestBaseline() + } + close(pool.processCh) + <-pool.handlerDone + + if count != N { + t.Fatalf("handler processed %d items, want %d", count, N) + } +} + +// --------------------------------------------------------------------------- +// Full shutdown sequence (integration-level) +// --------------------------------------------------------------------------- + +func TestShutdownSequence_NoDeadlock(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + // simulate Handler + go func() { + defer close(pool.handlerDone) + for range pool.processCh { + } + }() + + // simulate some work + for i := 0; i < 10; i++ { + pool.addAddition(&Unit{path: "/w", source: parsers.WordSource}) + } + // simulate consumer + for i := 0; i < 10; i++ { + <-pool.additionCh + pool.sendProcess(newTestBaseline()) + pool.wg.Done() + } + + // shutdown: Cancel → wg.Wait → close(processCh) → <-handlerDone + mustFinish(t, 5*time.Second, "shutdown sequence deadlocked", func() { + cancel() + pool.wg.Wait() + close(pool.processCh) + <-pool.handlerDone + }) +} + +func TestShutdownSequence_CancelMidFlight(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + go func() { + defer close(pool.handlerDone) + for range pool.processCh { + } + }() + + // producers still running when cancel fires + var ext sync.WaitGroup + for i := 0; i < 50; i++ { + ext.Add(1) + go func() { + defer ext.Done() + pool.addAddition(&Unit{path: "/m", source: parsers.WordSource}) + }() + } + for i := 0; i < 50; i++ { + ext.Add(1) + go func() { + defer ext.Done() + pool.sendProcess(newTestBaseline()) + }() + } + + time.Sleep(5 * time.Millisecond) + cancel() + + // drain additionCh + go func() { + for range pool.additionCh { + pool.wg.Done() + } + }() + + mustFinish(t, 5*time.Second, "mid-flight cancel deadlocked", func() { + ext.Wait() + pool.wg.Wait() + close(pool.processCh) + <-pool.handlerDone + close(pool.additionCh) + }) +} + +// --------------------------------------------------------------------------- +// Goroutine leak detection +// --------------------------------------------------------------------------- + +func TestNoGoroutineLeak(t *testing.T) { + // let background goroutines from prior tests settle + runtime.GC() + time.Sleep(100 * time.Millisecond) + before := runtime.NumGoroutine() + + ctx, cancel := context.WithCancel(context.Background()) + pool := newTestBasePool(ctx, cancel) + + go func() { + defer close(pool.handlerDone) + for range pool.processCh { + } + }() + + for i := 0; i < 10; i++ { + pool.addAddition(&Unit{path: "/l", source: parsers.WordSource}) + } + go func() { + for range pool.additionCh { + pool.wg.Done() + } + }() + + cancel() + pool.wg.Wait() + close(pool.processCh) + <-pool.handlerDone + close(pool.additionCh) + + time.Sleep(200 * time.Millisecond) + runtime.GC() + time.Sleep(100 * time.Millisecond) + + after := runtime.NumGoroutine() + if after > before+2 { + t.Errorf("goroutine leak: before=%d after=%d", before, after) + } +} diff --git a/pkg/utils_test.go b/pkg/utils_test.go new file mode 100644 index 0000000..7fa07b0 --- /dev/null +++ b/pkg/utils_test.go @@ -0,0 +1,122 @@ +package pkg + +import ( + "strings" + "testing" +) + +func TestParseRawResponse(t *testing.T) { + t.Run("valid complete response", func(t *testing.T) { + raw := "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: 5\r\n\r\nhello" + resp, err := ParseRawResponse([]byte(raw)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("expected status 200, got %d", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); ct != "text/html" { + t.Fatalf("expected Content-Type text/html, got %s", ct) + } + }) + + t.Run("nil input", func(t *testing.T) { + _, err := ParseRawResponse(nil) + if err == nil { + t.Fatal("expected error for nil input") + } + }) + + t.Run("empty input", func(t *testing.T) { + _, err := ParseRawResponse([]byte{}) + if err == nil { + t.Fatal("expected error for empty input") + } + }) + + t.Run("status line only no headers", func(t *testing.T) { + raw := "HTTP/1.1 200 OK\r\n\r\n" + resp, err := ParseRawResponse([]byte(raw)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("expected status 200, got %d", resp.StatusCode) + } + }) + + t.Run("truncated status line", func(t *testing.T) { + raw := "HTTP/1." + _, err := ParseRawResponse([]byte(raw)) + if err == nil { + t.Fatal("expected error for truncated status line") + } + }) + + t.Run("redirect response no body", func(t *testing.T) { + raw := "HTTP/1.1 302 Found\r\nLocation: https://example.com/new\r\n\r\n" + resp, err := ParseRawResponse([]byte(raw)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 302 { + t.Fatalf("expected status 302, got %d", resp.StatusCode) + } + if loc := resp.Header.Get("Location"); loc != "https://example.com/new" { + t.Fatalf("expected Location https://example.com/new, got %s", loc) + } + }) + + t.Run("chunked transfer encoding header", func(t *testing.T) { + raw := "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n" + resp, err := ParseRawResponse([]byte(raw)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if te := resp.Header.Get("Transfer-Encoding"); te == "" { + // Transfer-Encoding may be consumed by http.ReadResponse, just verify no panic + } + _ = resp + }) + + t.Run("large header value", func(t *testing.T) { + largeValue := strings.Repeat("A", 8192) + raw := "HTTP/1.1 200 OK\r\nX-Large: " + largeValue + "\r\n\r\n" + resp, err := ParseRawResponse([]byte(raw)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := resp.Header.Get("X-Large"); got != largeValue { + t.Fatalf("large header value mismatch, got length %d", len(got)) + } + }) + + t.Run("invalid status code", func(t *testing.T) { + raw := "HTTP/1.1 xyz OK\r\n\r\n" + _, err := ParseRawResponse([]byte(raw)) + if err == nil { + t.Fatal("expected error for invalid status code") + } + }) + + t.Run("incomplete header no terminator", func(t *testing.T) { + raw := "HTTP/1.1 200 OK\r\nContent-Type: text/html" + _, err := ParseRawResponse([]byte(raw)) + // Should return error or at least not panic + // http.ReadResponse may or may not error on missing \r\n\r\n, + // the key requirement is no panic + _ = err + }) + + t.Run("binary body", func(t *testing.T) { + body := []byte{0x00, 0x01, 0x02, 0xff, 0xfe, 0xfd} + raw := append([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\n"), body...) + resp, err := ParseRawResponse(raw) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("expected status 200, got %d", resp.StatusCode) + } + }) +}