Skip to content

Commit

Permalink
refactor auto decompression
Browse files Browse the repository at this point in the history
  • Loading branch information
imroc committed Jun 11, 2024
1 parent 74b6ad7 commit 2069ef9
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 85 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package http3
package compress

import (
"github.com/andybalholm/brotli"
"io"

"github.com/andybalholm/brotli"
)

type BrotliReader struct {
Expand All @@ -11,7 +12,7 @@ type BrotliReader struct {
berr error // sticky error
}

func newBrotliReader(body io.ReadCloser) io.ReadCloser {
func NewBrotliReader(body io.ReadCloser) *BrotliReader {
return &BrotliReader{Body: body}
}

Expand All @@ -28,3 +29,11 @@ func (br *BrotliReader) Read(p []byte) (n int, err error) {
func (br *BrotliReader) Close() error {
return br.Body.Close()
}

func (br *BrotliReader) GetUnderlyingBody() io.ReadCloser {
return br.Body
}

func (br *BrotliReader) SetUnderlyingBody(body io.ReadCloser) {
br.Body = body
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package http3
package compress

import (
"compress/flate"
Expand All @@ -11,7 +11,7 @@ type DeflateReader struct {
derr error // sticky error
}

func newDeflateReader(body io.ReadCloser) io.ReadCloser {
func NewDeflateReader(body io.ReadCloser) *DeflateReader {
return &DeflateReader{Body: body}
}

Expand All @@ -21,20 +21,21 @@ func (df *DeflateReader) Read(p []byte) (n int, err error) {
}
if df.dr == nil {
df.dr = flate.NewReader(df.Body)
if df.dr == nil {
df.derr = io.ErrUnexpectedEOF
return 0, df.derr
}
}
return df.dr.Read(p)
}

func (df *DeflateReader) Close() error {
if df.dr != nil {
err := df.dr.Close()
if err != nil {
return err
}
return df.dr.Close()
}
return df.Body.Close()
}

func (df *DeflateReader) GetUnderlyingBody() io.ReadCloser {
return df.Body
}

func (df *DeflateReader) SetUnderlyingBody(body io.ReadCloser) {
df.Body = body
}
24 changes: 17 additions & 7 deletions internal/http3/gzip_reader.go → internal/compress/gzip_reader.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
package http3
package compress

// copied from net/transport.go

// GzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
"io/fs"
)

// GzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
type GzipReader struct {
Body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}

func newGzipReader(body io.ReadCloser) io.ReadCloser {
func NewGzipReader(body io.ReadCloser) *GzipReader {
return &GzipReader{Body: body}
}

Expand All @@ -35,5 +33,17 @@ func (gz *GzipReader) Read(p []byte) (n int, err error) {
}

func (gz *GzipReader) Close() error {
return gz.Body.Close()
if err := gz.Body.Close(); err != nil {
return err
}
gz.zerr = fs.ErrClosed
return nil
}

func (gz *GzipReader) GetUnderlyingBody() io.ReadCloser {
return gz.Body
}

func (gz *GzipReader) SetUnderlyingBody(body io.ReadCloser) {
gz.Body = body
}
23 changes: 23 additions & 0 deletions internal/compress/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package compress

import "io"

type CompressReader interface {
io.ReadCloser
GetUnderlyingBody() io.ReadCloser
SetUnderlyingBody(body io.ReadCloser)
}

func NewCompressReader(body io.ReadCloser, contentEncoding string) CompressReader {
switch contentEncoding {
case "gzip":
return NewGzipReader(body)
case "deflate":
return NewDeflateReader(body)
case "br":
return NewBrotliReader(body)
case "zstd":
return NewZstdReader(body)
}
return nil
}
15 changes: 12 additions & 3 deletions internal/http3/zstd_reader.go → internal/compress/zstd_reader.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package http3
package compress

import (
"github.com/klauspost/compress/zstd"
"io"

"github.com/klauspost/compress/zstd"
)

type ZstdReader struct {
Expand All @@ -11,7 +12,7 @@ type ZstdReader struct {
zerr error // sticky error
}

func newZstdReader(body io.ReadCloser) io.ReadCloser {
func NewZstdReader(body io.ReadCloser) *ZstdReader {
return &ZstdReader{Body: body}
}

Expand All @@ -35,3 +36,11 @@ func (zr *ZstdReader) Close() error {
}
return zr.Body.Close()
}

func (zr *ZstdReader) GetUnderlyingBody() io.ReadCloser {
return zr.Body
}

func (zr *ZstdReader) SetUnderlyingBody(body io.ReadCloser) {
zr.Body = body
}
45 changes: 11 additions & 34 deletions internal/http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ package http2
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"io"
"io/fs"
"log"
"math"
"math/bits"
Expand All @@ -40,6 +38,7 @@ import (
"github.com/imroc/req/v3/http2"
"github.com/imroc/req/v3/internal/ascii"
"github.com/imroc/req/v3/internal/common"
"github.com/imroc/req/v3/internal/compress"
"github.com/imroc/req/v3/internal/dump"
"github.com/imroc/req/v3/internal/header"
"github.com/imroc/req/v3/internal/netutil"
Expand Down Expand Up @@ -2568,8 +2567,17 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &GzipReader{Body: res.Body}
res.Body = compress.NewGzipReader(res.Body)
res.Uncompressed = true
} else if cs.cc.t.AutoDecompression {
contentEncoding := res.Header.Get("Content-Encoding")
if contentEncoding != "" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Uncompressed = true
res.Body = compress.NewCompressReader(res.Body, contentEncoding)
}
}

return res, nil
Expand Down Expand Up @@ -3145,37 +3153,6 @@ func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, rt.err
}

// GzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
type GzipReader struct {
_ incomparable
Body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}

func (gz *GzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.Body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}

func (gz *GzipReader) Close() error {
if err := gz.Body.Close(); err != nil {
return err
}
gz.zerr = fs.ErrClosed
return nil
}

// isConnectionCloseRequest reports whether req should use its own
// connection for a single request and then close the connection.
func isConnectionCloseRequest(req *http.Request) bool {
Expand Down
29 changes: 5 additions & 24 deletions internal/http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"

"github.com/imroc/req/v3/internal/compress"
"github.com/imroc/req/v3/internal/dump"
"github.com/imroc/req/v3/internal/quic-go/quicvarint"
"github.com/imroc/req/v3/internal/transport"
Expand Down Expand Up @@ -498,36 +499,16 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = newGzipReader(respBody)
res.Body = compress.NewGzipReader(respBody)
res.Uncompressed = true
} else if c.opt.AutoDecompression {
switch res.Header.Get("Content-Encoding") {
case "gzip":
contentEncoding := res.Header.Get("Content-Encoding")
if contentEncoding != "" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = newGzipReader(respBody)
res.Uncompressed = true
case "deflate":
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = newDeflateReader(respBody)
res.Uncompressed = true
case "br":
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = newBrotliReader(respBody)
res.Uncompressed = true
case "zstd":
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = newZstdReader(respBody)
res.Uncompressed = true
default:
res.Uncompressed = false
res.Body = compress.NewCompressReader(respBody, contentEncoding)
}
} else {
res.Body = respBody
Expand Down
16 changes: 12 additions & 4 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/imroc/req/v3/internal/altsvcutil"
"github.com/imroc/req/v3/internal/ascii"
"github.com/imroc/req/v3/internal/common"
"github.com/imroc/req/v3/internal/compress"
"github.com/imroc/req/v3/internal/dump"
"github.com/imroc/req/v3/internal/header"
h2internal "github.com/imroc/req/v3/internal/http2"
Expand Down Expand Up @@ -681,10 +682,8 @@ func (t *Transport) wrapResponseBody(res *http.Response, wrap wrapResponseBodyFu
switch b := res.Body.(type) {
case *gzipReader:
b.body.body = wrap(b.body.body)
case *h2internal.GzipReader:
b.Body = wrap(b.Body)
case *http3.GzipReader:
b.Body = wrap(b.Body)
case compress.CompressReader:
b.SetUnderlyingBody(wrap(b.GetUnderlyingBody()))
default:
res.Body = wrap(res.Body)
}
Expand Down Expand Up @@ -2731,6 +2730,15 @@ func (pc *persistConn) readLoop() {
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
} else if pc.t.AutoDecompression {
contentEncoding := resp.Header.Get("Content-Encoding")
if contentEncoding != "" {
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
resp.Body = compress.NewCompressReader(resp.Body, contentEncoding)
}
}

select {
Expand Down

0 comments on commit 2069ef9

Please sign in to comment.