diff --git a/internal/dump/dump.go b/internal/dump/dump.go index 231c4beb..0f71da7e 100644 --- a/internal/dump/dump.go +++ b/internal/dump/dump.go @@ -138,7 +138,6 @@ func NewDumper(opt Options) *Dumper { func (d *Dumper) SetOptions(opt Options) { d.Options = opt - return } func (d *Dumper) Clone() *Dumper { diff --git a/internal/http3/body.go b/internal/http3/body.go index 63ff4366..fa023ce4 100644 --- a/internal/http3/body.go +++ b/internal/http3/body.go @@ -2,68 +2,68 @@ package http3 import ( "context" + "errors" "io" - "net" "github.com/quic-go/quic-go" ) -// The HTTPStreamer allows taking over a HTTP/3 stream. The interface is implemented by: -// * for the server: the http.Request.Body -// * for the client: the http.Response.Body -// On the client side, the stream will be closed for writing, unless the DontCloseRequestStream RoundTripOpt was set. -// When a stream is taken over, it's the caller's responsibility to close the stream. -type HTTPStreamer interface { - HTTPStream() Stream -} - -type StreamCreator interface { - // Context returns a context that is cancelled when the underlying connection is closed. - Context() context.Context - OpenStream() (quic.Stream, error) - OpenStreamSync(context.Context) (quic.Stream, error) - OpenUniStream() (quic.SendStream, error) - OpenUniStreamSync(context.Context) (quic.SendStream, error) - LocalAddr() net.Addr - RemoteAddr() net.Addr - ConnectionState() quic.ConnectionState -} - -var _ StreamCreator = quic.Connection(nil) - // A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body. // It is used by WebTransport to create WebTransport streams after a session has been established. type Hijacker interface { - StreamCreator() StreamCreator + Connection() Connection } -// The body of a http.Request or http.Response. +var errTooMuchData = errors.New("peer sent too much data") + +// The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response). type body struct { - str quic.Stream + str *stream - wasHijacked bool // set when HTTPStream is called + remainingContentLength int64 + violatedContentLength bool + hasContentLength bool } -var ( - _ io.ReadCloser = &body{} - _ HTTPStreamer = &body{} -) - -func newRequestBody(str Stream) *body { - return &body{str: str} +func newBody(str *stream, contentLength int64) *body { + b := &body{str: str} + if contentLength >= 0 { + b.hasContentLength = true + b.remainingContentLength = contentLength + } + return b } -func (r *body) HTTPStream() Stream { - r.wasHijacked = true - return r.str -} +func (r *body) StreamID() quic.StreamID { return r.str.StreamID() } -func (r *body) wasStreamHijacked() bool { - return r.wasHijacked +func (r *body) checkContentLengthViolation() error { + if !r.hasContentLength { + return nil + } + if r.remainingContentLength < 0 || r.remainingContentLength == 0 && r.str.hasMoreData() { + if !r.violatedContentLength { + r.str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + r.str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + r.violatedContentLength = true + } + return errTooMuchData + } + return nil } func (r *body) Read(b []byte) (int, error) { - return r.str.Read(b) + if err := r.checkContentLengthViolation(); err != nil { + return 0, err + } + if r.hasContentLength { + b = b[:min(int64(len(b)), r.remainingContentLength)] + } + n, err := r.str.Read(b) + r.remainingContentLength -= int64(n) + if err := r.checkContentLengthViolation(); err != nil { + return n, err + } + return n, maybeReplaceError(err) } func (r *body) Close() error { @@ -71,9 +71,26 @@ func (r *body) Close() error { return nil } -type hijackableBody struct { +type requestBody struct { body - conn quic.Connection // only needed to implement Hijacker + connCtx context.Context + rcvdSettings <-chan struct{} + getSettings func() *Settings +} + +var _ io.ReadCloser = &requestBody{} + +func newRequestBody(str *stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody { + return &requestBody{ + body: *newBody(str, contentLength), + connCtx: connCtx, + rcvdSettings: rcvdSettings, + getSettings: getSettings, + } +} + +type hijackableBody struct { + body body // only set for the http.Response // The channel is closed when the user is done with this response: @@ -82,31 +99,21 @@ type hijackableBody struct { reqDoneClosed bool } -var ( - _ Hijacker = &hijackableBody{} - _ HTTPStreamer = &hijackableBody{} -) +var _ io.ReadCloser = &hijackableBody{} -func newResponseBody(str Stream, conn quic.Connection, done chan<- struct{}) *hijackableBody { +func newResponseBody(str *stream, contentLength int64, done chan<- struct{}) *hijackableBody { return &hijackableBody{ - body: body{ - str: str, - }, + body: *newBody(str, contentLength), reqDone: done, - conn: conn, } } -func (r *hijackableBody) StreamCreator() StreamCreator { - return r.conn -} - func (r *hijackableBody) Read(b []byte) (int, error) { - n, err := r.str.Read(b) + n, err := r.body.Read(b) if err != nil { r.requestDone() } - return n, err + return n, maybeReplaceError(err) } func (r *hijackableBody) requestDone() { @@ -119,17 +126,9 @@ func (r *hijackableBody) requestDone() { r.reqDoneClosed = true } -func (r *body) StreamID() quic.StreamID { - return r.str.StreamID() -} - func (r *hijackableBody) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. - r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) + r.body.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) return nil } - -func (r *hijackableBody) HTTPStream() Stream { - return r.str -} diff --git a/internal/http3/client.go b/internal/http3/client.go index 8c85cf63..664f0f76 100644 --- a/internal/http3/client.go +++ b/internal/http3/client.go @@ -2,38 +2,34 @@ package http3 import ( "context" - "crypto/tls" "errors" - "fmt" "io" - "net" + "log/slog" "net/http" - "strconv" + "net/http/httptrace" + "net/textproto" "sync" - "sync/atomic" "time" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" - "github.com/imroc/req/v3/internal/compress" "github.com/imroc/req/v3/internal/dump" "github.com/imroc/req/v3/internal/quic-go/quicvarint" "github.com/imroc/req/v3/internal/transport" ) -// MethodGet0RTT allows a GET request to be sent using 0-RTT. -// Note that 0-RTT data doesn't provide replay protection. -const MethodGet0RTT = "GET_0RTT" - const ( - defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB + // MethodGet0RTT allows a GET request to be sent using 0-RTT. + // Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests. + MethodGet0RTT = "GET_0RTT" + // MethodHead0RTT allows a HEAD request to be sent using 0-RTT. + // Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests. + MethodHead0RTT = "HEAD_0RTT" ) const ( - VersionDraft29 quic.Version = 0xff00001d - Version1 quic.Version = 0x1 - Version2 quic.Version = 0x6b3343cf + defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB ) var defaultQuicConfig = &quic.Config{ @@ -41,119 +37,70 @@ var defaultQuicConfig = &quic.Config{ KeepAlivePeriod: 10 * time.Second, } -type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) +// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server. +type SingleDestinationRoundTripper struct { + *transport.Options -var dialAddr dialFunc = quic.DialAddrEarly + Connection quic.Connection -type roundTripperOpts struct { - DisableCompression bool - EnableDatagram bool - MaxHeaderBytes int64 + // Enable support for HTTP/3 datagrams (RFC 9297). + // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. + EnableDatagrams bool + + // Additional HTTP/3 settings. + // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 - StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) - UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) - dump *dump.Dumper -} + StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error) + UniStreamHijacker func(ServerStreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool) -// client is a HTTP3 client doing requests -type client struct { - tlsConf *tls.Config - config *quic.Config - opts *roundTripperOpts + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 - dialOnce sync.Once - dialer dialFunc - handshakeErr error + Logger *slog.Logger + initOnce sync.Once + hconn *connection requestWriter *requestWriter - - decoder *qpack.Decoder - - hostname string - conn atomic.Pointer[quic.EarlyConnection] - - opt *transport.Options + decoder *qpack.Decoder } -var _ roundTripCloser = &client{} +var _ http.RoundTripper = &SingleDestinationRoundTripper{} -func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) { - if conf == nil { - conf = defaultQuicConfig.Clone() - } - if len(conf.Versions) == 0 { - conf = conf.Clone() - conf.Versions = []quic.Version{Version1} - } - if len(conf.Versions) != 1 { - return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") - } - if conf.MaxIncomingStreams == 0 { - conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams - } - conf.EnableDatagrams = opts.EnableDatagram - var debugf func(format string, v ...interface{}) - if opt != nil && opt.Debugf != nil { - debugf = opt.Debugf - } - - if tlsConf == nil { - tlsConf = &tls.Config{} - } else { - tlsConf = tlsConf.Clone() - } - if tlsConf.ServerName == "" { - sni, _, err := net.SplitHostPort(hostname) - if err != nil { - // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. - sni = hostname - } - tlsConf.ServerName = sni - } - // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} - - return &client{ - hostname: authorityAddr("https", hostname), - tlsConf: tlsConf, - requestWriter: newRequestWriter(debugf), - decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: conf, - opts: opts, - dialer: dialer, - opt: opt, - }, nil +func (c *SingleDestinationRoundTripper) Start() Connection { + c.initOnce.Do(func() { c.init() }) + return c.hconn } -func (c *client) dial(ctx context.Context) error { - var err error - var conn quic.EarlyConnection - if c.dialer != nil { - conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) - } else { - conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) - } - if err != nil { - return err - } - c.conn.Store(&conn) - +func (c *SingleDestinationRoundTripper) init() { + c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {}) + c.requestWriter = newRequestWriter() + c.hconn = newConnection( + c.Connection.Context(), + c.Connection, + c.EnableDatagrams, + PerspectiveClient, + c.Logger, + 0, + c.Options, + ) // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupConn(conn); err != nil { - c.opt.Debugf("setting up http3 connection failed: %s", err) - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") + if err := c.setupConn(c.hconn); err != nil { + if c.Logger != nil { + c.Logger.Debug("Setting up connection failed", "error", err) + } + c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") } }() - - if c.opts.StreamHijacker != nil { - go c.handleBidirectionalStreams(conn) + if c.StreamHijacker != nil { + go c.handleBidirectionalStreams() } - go c.handleUnidirectionalStreams(conn) - return nil + go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker) } -func (c *client) setupConn(conn quic.EarlyConnection) error { +func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error { // open the control stream str, err := conn.OpenUniStream() if err != nil { @@ -162,108 +109,54 @@ func (c *client) setupConn(conn quic.EarlyConnection) error { b := make([]byte, 0, 64) b = quicvarint.Append(b, streamTypeControlStream) // send the SETTINGS frame - b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b) + b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b) _, err = str.Write(b) return err } -func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { +func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() { for { - str, err := conn.AcceptStream(context.Background()) + str, err := c.hconn.AcceptStream(context.Background()) if err != nil { - c.opt.Debugf("accepting bidirectional stream failed: %s", err) - return - } - go func(str quic.Stream) { - _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { - return c.opts.StreamHijacker(ft, conn, str, e) - }) - if err == errHijacked { - return + if c.Logger != nil { + c.Logger.Debug("accepting bidirectional stream failed", "error", err) } - if err != nil { - c.opt.Debugf("error handling stream: %s", err) - } - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") - }(str) - } -} - -func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { - for { - str, err := conn.AcceptUniStream(context.Background()) - if err != nil { - c.opt.Debugf("accepting unidirectional stream failed: %s", err) return } - - go func(str quic.ReceiveStream) { - streamType, err := quicvarint.Read(quicvarint.NewReader(str)) - if err != nil { - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) { - return - } - c.opt.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) - return - } - // We're only interested in the control stream here. - switch streamType { - case streamTypeControlStream: - case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: - // Our QPACK implementation doesn't use the dynamic table yet. - // TODO: check that only one stream of each type is opened. - return - case streamTypePushStream: - // We never increased the Push ID, so we don't expect any push streams. - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") - return - default: - if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { - return - } - str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) + fp := &frameParser{ + r: str, + conn: c.hconn, + unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) { + id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) + return c.StreamHijacker(ft, id, str, e) + }, + } + go func() { + if _, err := fp.ParseNext(); err == errHijacked { return } - f, err := parseNextFrame(str, nil) if err != nil { - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") - return - } - sf, ok := f.(*settingsFrame) - if !ok { - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") - return - } - if !sf.Datagram { - return - } - // If datagram support was enabled on our side as well as on the server side, - // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. - // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { - conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") + if c.Logger != nil { + c.Logger.Debug("error handling stream", "error", err) + } } - }(str) - } -} - -func (c *client) Close() error { - conn := c.conn.Load() - if conn == nil { - return nil + c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + }() } - return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") } -func (c *client) maxHeaderBytes() uint64 { - if c.opts.MaxHeaderBytes <= 0 { +func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 { + if c.MaxResponseHeaderBytes <= 0 { return defaultMaxResponseHeaderBytes } - return uint64(c.opts.MaxHeaderBytes) + return uint64(c.MaxResponseHeaderBytes) } -func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { - rsp, err := c.roundTripOpt(req, opt) +// RoundTrip executes a request and returns a response +func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + c.initOnce.Do(func() { c.init() }) + + rsp, err := c.roundTrip(req) if err != nil && req.Context().Err() != nil { // if the context was canceled, return the context cancellation error err = req.Context().Err() @@ -271,35 +164,48 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon return rsp, err } -// RoundTripOpt executes a request and returns a response -func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { - if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { - return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) - } - - c.dialOnce.Do(func() { - c.handshakeErr = c.dial(req.Context()) - }) - if c.handshakeErr != nil { - return nil, c.handshakeErr - } - - // At this point, c.conn is guaranteed to be set. - conn := *c.conn.Load() - +func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) { // Immediately send out this request, if this is a 0-RTT request. - if req.Method == MethodGet0RTT { + switch req.Method { + case MethodGet0RTT: + // don't modify the original request + reqCopy := *req + req = &reqCopy req.Method = http.MethodGet - } else { + case MethodHead0RTT: + // don't modify the original request + reqCopy := *req + req = &reqCopy + req.Method = http.MethodHead + default: // wait for the handshake to complete + earlyConn, ok := c.Connection.(quic.EarlyConnection) + if ok { + select { + case <-earlyConn.HandshakeComplete(): + case <-req.Context().Done(): + return nil, req.Context().Err() + } + } + } + + // It is only possible to send an Extended CONNECT request once the SETTINGS were received. + // See section 3 of RFC 8441. + if isExtendedConnectRequest(req) { + connCtx := c.Connection.Context() + // wait for the server's SETTINGS frame to arrive select { - case <-conn.HandshakeComplete(): - case <-req.Context().Done(): - return nil, req.Context().Err() + case <-c.hconn.ReceivedSettings(): + case <-connCtx.Done(): + return nil, context.Cause(connCtx) + } + if !c.hconn.Settings().EnableExtendedConnect { + return nil, errors.New("http3: server didn't enable Extended CONNECT") } } - str, err := conn.OpenStreamSync(req.Context()) + reqDone := make(chan struct{}) + str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes()) if err != nil { return nil, err } @@ -307,7 +213,6 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon // Request Cancellation: // This go routine keeps running even after RoundTripOpt() returns. // It is shut down when the application is done processing the body. - reqDone := make(chan struct{}) done := make(chan struct{}) go func() { defer close(done) @@ -319,214 +224,110 @@ func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } }() - doneChan := reqDone - if opt.DontCloseRequestStream { - doneChan = nil - } - rsp, rerr := c.doRequest(req, conn, str, opt, doneChan) - if rerr.err != nil { // if any error occurred + rsp, err := c.doRequest(req, str) + if err != nil { // if any error occurred close(reqDone) <-done - if rerr.streamErr != 0 { // if it was a stream error - str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) - } - if rerr.connErr != 0 { // if it was a connection error - var reason string - if rerr.err != nil { - reason = rerr.err.Error() - } - conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) - } - return nil, maybeReplaceError(rerr.err) - + return nil, maybeReplaceError(err) } - if opt.DontCloseRequestStream { - close(reqDone) - <-done + return rsp, maybeReplaceError(err) +} + +func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) { + c.initOnce.Do(func() { c.init() }) + + return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes()) +} + +// cancelingReader reads from the io.Reader. +// It cancels writing on the stream if any error other than io.EOF occurs. +type cancelingReader struct { + r io.Reader + str Stream +} + +func (r *cancelingReader) Read(b []byte) (int, error) { + n, err := r.r.Read(b) + if err != nil && err != io.EOF { + r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) } - return rsp, maybeReplaceError(rerr.err) + return n, err } -func (c *client) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.Dumper) error { +func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, dumps []*dump.Dumper) error { defer body.Close() - b := make([]byte, bodyCopyBufferSize) - writeData := func(data []byte) error { - if _, err := str.Write(data); err != nil { - return err - } - return nil - } + buf := make([]byte, bodyCopyBufferSize) + sr := &cancelingReader{str: str, r: body} + var w io.Writer = str if len(dumps) > 0 { - writeData = func(data []byte) error { - for _, dump := range dumps { - dump.DumpRequestBody(data) - } - if _, err := str.Write(data); err != nil { - return err - } - return nil + for _, d := range dumps { + w = io.MultiWriter(w, d.RequestBodyOutput()) } } - for { - n, rerr := body.Read(b) - if n == 0 { - if rerr == nil { - continue - } - if rerr == io.EOF { - for _, dump := range dumps { - dump.DumpDefault([]byte("\r\n\r\n")) - } - break - } - } - if err := writeData(b[:n]); err != nil { - return err - } - if rerr != nil { - if rerr == io.EOF { - for _, dump := range dumps { - dump.DumpDefault([]byte("\r\n\r\n")) - } - break - } - str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) - return rerr + writeTail := func() { + for _, d := range dumps { + d.Output().Write([]byte("\r\n\r\n")) } } - return nil + written, err := io.CopyBuffer(w, sr, buf) + if len(dumps) > 0 && err == nil && written > 0 { + writeTail() + } + + return err } -func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { - var requestGzip bool - if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { - requestGzip = true - } - dumps := dump.GetDumpers(req.Context(), c.opts.dump) - var headerDumps []*dump.Dumper - for _, dump := range dumps { - if dump.RequestHeader() { - headerDumps = append(headerDumps, dump) - } - } - if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip, headerDumps); err != nil { - return nil, newStreamError(ErrCodeInternalError, err) +func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) { + if err := str.SendRequestHeader(req); err != nil { + return nil, err } - - if req.Body == nil && !opt.DontCloseRequestStream { + if req.Body == nil { str.Close() - } - - hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") }) - if req.Body != nil { + } else { // send the request body asynchronously go func() { - var bodyDumps []*dump.Dumper - for _, dump := range dumps { - if dump.RequestBody() { - bodyDumps = append(bodyDumps, dump) + dumps := dump.GetDumpers(req.Context(), c.Dump) + if err := c.sendRequestBody(str, req.Body, dumps); err != nil { + if c.Logger != nil { + c.Logger.Debug("error writing request", "error", err) } } - if err := c.sendRequestBody(hstr, req.Body, bodyDumps); err != nil { - c.opt.Debugf("error writing request: %s", err) - } - if !opt.DontCloseRequestStream { - hstr.Close() - } + str.Close() }() } - frame, err := parseNextFrame(str, nil) - if err != nil { - return nil, newStreamError(ErrCodeFrameError, err) - } - hf, ok := frame.(*headersFrame) - if !ok { - return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) - } - if hf.Length > c.maxHeaderBytes() { - return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) - } - headerBlock := make([]byte, hf.Length) - if _, err := io.ReadFull(str, headerBlock); err != nil { - return nil, newStreamError(ErrCodeRequestIncomplete, err) - } - var respHeaderDumps []*dump.Dumper - for _, dump := range dumps { - if dump.ResponseHeader() { - respHeaderDumps = append(respHeaderDumps, dump) + // copy from net/http: support 1xx responses + trace := httptrace.ContextClientTrace(req.Context()) + num1xx := 0 // number of informational 1xx headers received + const max1xxResponses = 5 // arbitrary bound on number of informational responses + + var res *http.Response + for { + var err error + res, err = str.ReadResponse() + if err != nil { + return nil, err } - } - hfs, err := c.decoder.DecodeFull(headerBlock) - if len(respHeaderDumps) > 0 { - for _, hf := range hfs { - for _, dump := range respHeaderDumps { - dump.DumpResponseHeader([]byte(fmt.Sprintf("%s: %s\r\n", hf.Name, hf.Value))) + resCode := res.StatusCode + is1xx := 100 <= resCode && resCode <= 199 + // treat 101 as a terminal status, see https://github.com/golang/go/issues/26161 + is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols + if is1xxNonTerminal { + num1xx++ + if num1xx > max1xxResponses { + return nil, errors.New("http: too many 1xx informational responses") } - } - for _, dump := range respHeaderDumps { - dump.DumpResponseHeader([]byte("\r\n")) - } - } - if err != nil { - // TODO: use the right error code - return nil, newConnError(ErrCodeGeneralProtocolError, err) - } - - res, err := responseFromHeaders(hfs) - if err != nil { - return nil, newStreamError(ErrCodeMessageError, err) - } - res.Request = req - connState := conn.ConnectionState().TLS - res.TLS = &connState - respBody := newResponseBody(hstr, conn, reqDone) - - // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. - _, hasTransferEncoding := res.Header["Transfer-Encoding"] - isInformational := res.StatusCode >= 100 && res.StatusCode < 200 - isNoContent := res.StatusCode == http.StatusNoContent - isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 - if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { - res.ContentLength = -1 - if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { - if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { - res.ContentLength = clen64 + if trace != nil && trace.Got1xxResponse != nil { + if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(res.Header)); err != nil { + return nil, err + } } + continue } + break } - - if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Body = compress.NewGzipReader(respBody) - res.Uncompressed = true - } else if c.opt.AutoDecompression { - contentEncoding := res.Header.Get("Content-Encoding") - if contentEncoding != "" { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 - res.Uncompressed = true - res.Body = compress.NewCompressReader(respBody, contentEncoding) - } - } else { - res.Body = respBody - } - - return res, requestError{} -} - -func (c *client) HandshakeComplete() bool { - conn := c.conn.Load() - if conn == nil { - return false - } - select { - case <-(*conn).HandshakeComplete(): - return true - default: - return false - } + connState := c.hconn.ConnectionState().TLS + res.TLS = &connState + res.Request = req + return res, nil } diff --git a/internal/http3/conn.go b/internal/http3/conn.go new file mode 100644 index 00000000..60f0e259 --- /dev/null +++ b/internal/http3/conn.go @@ -0,0 +1,328 @@ +package http3 + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/imroc/req/v3/internal/transport" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/quicvarint" + + "github.com/quic-go/qpack" +) + +// Connection is an HTTP/3 connection. +// It has all methods from the quic.Connection expect for AcceptStream, AcceptUniStream, +// SendDatagram and ReceiveDatagram. +type Connection interface { + OpenStream() (quic.Stream, error) + OpenStreamSync(context.Context) (quic.Stream, error) + OpenUniStream() (quic.SendStream, error) + OpenUniStreamSync(context.Context) (quic.SendStream, error) + LocalAddr() net.Addr + RemoteAddr() net.Addr + CloseWithError(quic.ApplicationErrorCode, string) error + Context() context.Context + ConnectionState() quic.ConnectionState + + // ReceivedSettings returns a channel that is closed once the client's SETTINGS frame was received. + ReceivedSettings() <-chan struct{} + // Settings returns the settings received on this connection. + Settings() *Settings +} + +type connection struct { + quic.Connection + *transport.Options + ctx context.Context + + perspective Perspective + logger *slog.Logger + + enableDatagrams bool + + decoder *qpack.Decoder + + streamMx sync.Mutex + streams map[quic.StreamID]*datagrammer + + settings *Settings + receivedSettings chan struct{} + + idleTimeout time.Duration + idleTimer *time.Timer +} + +func newConnection( + ctx context.Context, + quicConn quic.Connection, + enableDatagrams bool, + perspective Perspective, + logger *slog.Logger, + idleTimeout time.Duration, + options *transport.Options, +) *connection { + c := &connection{ + ctx: ctx, + Connection: quicConn, + Options: options, + perspective: perspective, + logger: logger, + idleTimeout: idleTimeout, + enableDatagrams: enableDatagrams, + decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), + receivedSettings: make(chan struct{}), + streams: make(map[quic.StreamID]*datagrammer), + } + if idleTimeout > 0 { + c.idleTimer = time.AfterFunc(idleTimeout, c.onIdleTimer) + } + return c +} + +func (c *connection) onIdleTimer() { + c.CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "idle timeout") +} + +func (c *connection) clearStream(id quic.StreamID) { + c.streamMx.Lock() + defer c.streamMx.Unlock() + + delete(c.streams, id) + if c.idleTimeout > 0 && len(c.streams) == 0 { + c.idleTimer.Reset(c.idleTimeout) + } +} + +func (c *connection) openRequestStream( + ctx context.Context, + requestWriter *requestWriter, + reqDone chan<- struct{}, + disableCompression bool, + maxHeaderBytes uint64, +) (*requestStream, error) { + str, err := c.Connection.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) + c.streamMx.Lock() + c.streams[str.StreamID()] = datagrams + c.streamMx.Unlock() + qstr := newStateTrackingStream(str, c, datagrams) + rsp := &http.Response{} + hstr := newStream(qstr, c, datagrams, func(r io.Reader, l uint64) error { + hdr, err := c.decodeTrailers(r, l, maxHeaderBytes) + if err != nil { + return err + } + rsp.Trailer = hdr + return nil + }) + return newRequestStream(ctx, c.Options, hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes, rsp), nil +} + +func (c *connection) decodeTrailers(r io.Reader, l, maxHeaderBytes uint64) (http.Header, error) { + if l > maxHeaderBytes { + return nil, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", l, maxHeaderBytes) + } + + b := make([]byte, l) + if _, err := io.ReadFull(r, b); err != nil { + return nil, err + } + fields, err := c.decoder.DecodeFull(b) + if err != nil { + return nil, err + } + return parseTrailers(fields) +} + +func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) { + str, err := c.AcceptStream(ctx) + if err != nil { + return nil, nil, err + } + datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) }) + if c.perspective == PerspectiveServer { + strID := str.StreamID() + c.streamMx.Lock() + c.streams[strID] = datagrams + if c.idleTimeout > 0 { + if len(c.streams) == 1 { + c.idleTimer.Stop() + } + } + c.streamMx.Unlock() + str = newStateTrackingStream(str, c, datagrams) + } + return str, datagrams, nil +} + +func (c *connection) CloseWithError(code quic.ApplicationErrorCode, msg string) error { + if c.idleTimer != nil { + c.idleTimer.Stop() + } + return c.Connection.CloseWithError(code, msg) +} + +func (c *connection) HandleUnidirectionalStreams(hijack func(ServerStreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) { + var ( + rcvdControlStr atomic.Bool + rcvdQPACKEncoderStr atomic.Bool + rcvdQPACKDecoderStr atomic.Bool + ) + + for { + str, err := c.Connection.AcceptUniStream(context.Background()) + if err != nil { + if c.logger != nil { + c.logger.Debug("accepting unidirectional stream failed", "error", err) + } + return + } + + go func(str quic.ReceiveStream) { + streamType, err := quicvarint.Read(quicvarint.NewReader(str)) + if err != nil { + id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) + if hijack != nil && hijack(ServerStreamType(streamType), id, str, err) { + return + } + if c.logger != nil { + c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err) + } + return + } + // We're only interested in the control stream here. + switch streamType { + case streamTypeControlStream: + case streamTypeQPACKEncoderStream: + if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") + } + // Our QPACK implementation doesn't use the dynamic table yet. + return + case streamTypeQPACKDecoderStream: + if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") + } + // Our QPACK implementation doesn't use the dynamic table yet. + return + case streamTypePushStream: + switch c.perspective { + case PerspectiveClient: + // we never increased the Push ID, so we don't expect any push streams + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") + case PerspectiveServer: + // only the server can push + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "") + } + return + default: + if hijack != nil { + if hijack( + ServerStreamType(streamType), + c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID), + str, + nil, + ) { + return + } + } + str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) + return + } + // Only a single control stream is allowed. + if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") + return + } + fp := &frameParser{conn: c.Connection, r: str} + f, err := fp.ParseNext() + if err != nil { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") + return + } + sf, ok := f.(*settingsFrame) + if !ok { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") + return + } + c.settings = &Settings{ + EnableDatagrams: sf.Datagram, + EnableExtendedConnect: sf.ExtendedConnect, + Other: sf.Other, + } + close(c.receivedSettings) + if !sf.Datagram { + return + } + // If datagram support was enabled on our side as well as on the server side, + // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. + // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). + if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") + return + } + go func() { + if err := c.receiveDatagrams(); err != nil { + if c.logger != nil { + c.logger.Debug("receiving datagrams failed", "error", err) + } + } + }() + }(str) + } +} + +func (c *connection) sendDatagram(streamID quic.StreamID, b []byte) error { + // TODO: this creates a lot of garbage and an additional copy + data := make([]byte, 0, len(b)+8) + data = quicvarint.Append(data, uint64(streamID/4)) + data = append(data, b...) + return c.Connection.SendDatagram(data) +} + +func (c *connection) receiveDatagrams() error { + for { + b, err := c.Connection.ReceiveDatagram(context.Background()) + if err != nil { + return err + } + quarterStreamID, n, err := quicvarint.Parse(b) + if err != nil { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") + return fmt.Errorf("could not read quarter stream id: %w", err) + } + if quarterStreamID > maxQuarterStreamID { + c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "") + return fmt.Errorf("invalid quarter stream id: %w", err) + } + streamID := quic.StreamID(4 * quarterStreamID) + c.streamMx.Lock() + dg, ok := c.streams[streamID] + if !ok { + c.streamMx.Unlock() + return nil + } + c.streamMx.Unlock() + dg.enqueue(b[n:]) + } +} + +// ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received. +func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings } + +// Settings returns the settings received on this connection. +// It is only valid to call this function after the channel returned by ReceivedSettings was closed. +func (c *connection) Settings() *Settings { return c.settings } + +func (c *connection) Context() context.Context { return c.ctx } diff --git a/internal/http3/datagram.go b/internal/http3/datagram.go new file mode 100644 index 00000000..6d570e6b --- /dev/null +++ b/internal/http3/datagram.go @@ -0,0 +1,98 @@ +package http3 + +import ( + "context" + "sync" +) + +const maxQuarterStreamID = 1<<60 - 1 + +const streamDatagramQueueLen = 32 + +type datagrammer struct { + sendDatagram func([]byte) error + + hasData chan struct{} + queue [][]byte // TODO: use a ring buffer + + mx sync.Mutex + sendErr error + receiveErr error +} + +func newDatagrammer(sendDatagram func([]byte) error) *datagrammer { + return &datagrammer{ + sendDatagram: sendDatagram, + hasData: make(chan struct{}, 1), + } +} + +func (d *datagrammer) SetReceiveError(err error) { + d.mx.Lock() + defer d.mx.Unlock() + + d.receiveErr = err + d.signalHasData() +} + +func (d *datagrammer) SetSendError(err error) { + d.mx.Lock() + defer d.mx.Unlock() + + d.sendErr = err +} + +func (d *datagrammer) Send(b []byte) error { + d.mx.Lock() + sendErr := d.sendErr + d.mx.Unlock() + if sendErr != nil { + return sendErr + } + + return d.sendDatagram(b) +} + +func (d *datagrammer) signalHasData() { + select { + case d.hasData <- struct{}{}: + default: + } +} + +func (d *datagrammer) enqueue(data []byte) { + d.mx.Lock() + defer d.mx.Unlock() + + if d.receiveErr != nil { + return + } + if len(d.queue) >= streamDatagramQueueLen { + return + } + d.queue = append(d.queue, data) + d.signalHasData() +} + +func (d *datagrammer) Receive(ctx context.Context) ([]byte, error) { +start: + d.mx.Lock() + if len(d.queue) >= 1 { + data := d.queue[0] + d.queue = d.queue[1:] + d.mx.Unlock() + return data, nil + } + if receiveErr := d.receiveErr; receiveErr != nil { + d.mx.Unlock() + return nil, receiveErr + } + d.mx.Unlock() + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-d.hasData: + } + goto start +} diff --git a/internal/http3/frames.go b/internal/http3/frames.go index a3cd88ad..b2d59a52 100644 --- a/internal/http3/frames.go +++ b/internal/http3/frames.go @@ -7,6 +7,7 @@ import ( "io" "github.com/imroc/req/v3/internal/quic-go/quicvarint" + "github.com/quic-go/quic-go" ) // FrameType is the frame type of a HTTP/3 frame @@ -18,13 +19,19 @@ type frame interface{} var errHijacked = errors.New("hijacked") -func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) { - qr := quicvarint.NewReader(r) +type frameParser struct { + r io.Reader + conn quic.Connection + unknownFrameHandler unknownFrameHandlerFunc +} + +func (p *frameParser) ParseNext() (frame, error) { + qr := quicvarint.NewReader(p.r) for { t, err := quicvarint.Read(qr) if err != nil { - if unknownFrameHandler != nil { - hijacked, err := unknownFrameHandler(0, err) + if p.unknownFrameHandler != nil { + hijacked, err := p.unknownFrameHandler(0, err) if err != nil { return nil, err } @@ -35,8 +42,8 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f return nil, err } // Call the unknownFrameHandler for frames not defined in the HTTP/3 spec - if t > 0xd && unknownFrameHandler != nil { - hijacked, err := unknownFrameHandler(FrameType(t), nil) + if t > 0xd && p.unknownFrameHandler != nil { + hijacked, err := p.unknownFrameHandler(FrameType(t), nil) if err != nil { return nil, err } @@ -56,11 +63,14 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f case 0x1: return &headersFrame{Length: l}, nil case 0x4: - return parseSettingsFrame(r, l) + return parseSettingsFrame(p.r, l) case 0x3: // CANCEL_PUSH case 0x5: // PUSH_PROMISE case 0x7: // GOAWAY case 0xd: // MAX_PUSH_ID + case 0x2, 0x6, 0x8, 0x9: + p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") + return nil, fmt.Errorf("http3: reserved frame type: %d", t) } // skip over unknown frames if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil { @@ -87,11 +97,18 @@ func (f *headersFrame) Append(b []byte) []byte { return quicvarint.Append(b, f.Length) } -const settingDatagram = 0x33 +const ( + // Extended CONNECT, RFC 9220 + settingExtendedConnect = 0x8 + // HTTP Datagrams, RFC 9297 + settingDatagram = 0x33 +) type settingsFrame struct { - Datagram bool - Other map[uint64]uint64 // all settings that we don't explicitly recognize + Datagram bool // HTTP Datagrams, RFC 9297 + ExtendedConnect bool // Extended CONNECT, RFC 9220 + + Other map[uint64]uint64 // all settings that we don't explicitly recognize } func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { @@ -107,7 +124,7 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { } frame := &settingsFrame{} b := bytes.NewReader(buf) - var readDatagram bool + var readDatagram, readExtendedConnect bool for b.Len() > 0 { id, err := quicvarint.Read(b) if err != nil { // should not happen. We allocated the whole frame already. @@ -119,13 +136,22 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { } switch id { + case settingExtendedConnect: + if readExtendedConnect { + return nil, fmt.Errorf("duplicate setting: %d", id) + } + readExtendedConnect = true + if val != 0 && val != 1 { + return nil, fmt.Errorf("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: %d", val) + } + frame.ExtendedConnect = val == 1 case settingDatagram: if readDatagram { return nil, fmt.Errorf("duplicate setting: %d", id) } readDatagram = true if val != 0 && val != 1 { - return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val) + return nil, fmt.Errorf("invalid value for SETTINGS_H3_DATAGRAM: %d", val) } frame.Datagram = val == 1 default: @@ -143,18 +169,25 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { func (f *settingsFrame) Append(b []byte) []byte { b = quicvarint.Append(b, 0x4) - var l int64 + var l int for id, val := range f.Other { l += quicvarint.Len(id) + quicvarint.Len(val) } if f.Datagram { l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) } + if f.ExtendedConnect { + l += quicvarint.Len(settingExtendedConnect) + quicvarint.Len(1) + } b = quicvarint.Append(b, uint64(l)) if f.Datagram { b = quicvarint.Append(b, settingDatagram) b = quicvarint.Append(b, 1) } + if f.ExtendedConnect { + b = quicvarint.Append(b, settingExtendedConnect) + b = quicvarint.Append(b, 1) + } for id, val := range f.Other { b = quicvarint.Append(b, id) b = quicvarint.Append(b, val) diff --git a/internal/http3/headers.go b/internal/http3/headers.go index 2eb5ca29..cbd79ecd 100644 --- a/internal/http3/headers.go +++ b/internal/http3/headers.go @@ -3,14 +3,17 @@ package http3 import ( "errors" "fmt" - "github.com/quic-go/qpack" - "golang.org/x/net/http/httpguts" "net/http" + "net/textproto" + "net/url" "strconv" "strings" + + "github.com/quic-go/qpack" + "golang.org/x/net/http/httpguts" ) -type Header struct { +type header struct { // Pseudo header fields defined in RFC 9114 Path string Method string @@ -19,28 +22,37 @@ type Header struct { Status string // for Extended connect Protocol string - // parsed and deduplicated + // parsed and deduplicated. -1 if no Content-Length header is sent ContentLength int64 // all non-pseudo headers Headers http.Header } -func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { - hdr := Header{Headers: make(http.Header, len(headers))} +// connection-specific header fields must not be sent on HTTP/3 +var invalidHeaderFields = [...]string{ + "connection", + "keep-alive", + "proxy-connection", + "transfer-encoding", + "upgrade", +} + +func parseHeaders(headers []qpack.HeaderField, isRequest bool) (header, error) { + hdr := header{Headers: make(http.Header, len(headers))} var readFirstRegularHeader, readContentLength bool var contentLengthStr string for _, h := range headers { // field names need to be lowercase, see section 4.2 of RFC 9114 if strings.ToLower(h.Name) != h.Name { - return Header{}, fmt.Errorf("header field is not lower-case: %s", h.Name) + return header{}, fmt.Errorf("header field is not lower-case: %s", h.Name) } if !httpguts.ValidHeaderFieldValue(h.Value) { - return Header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) + return header{}, fmt.Errorf("invalid header field value for %s: %q", h.Name, h.Value) } if h.IsPseudo() { if readFirstRegularHeader { // all pseudo headers must appear before regular header fields, see section 4.3 of RFC 9114 - return Header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) + return header{}, fmt.Errorf("received pseudo header %s after a regular header field", h.Name) } var isResponsePseudoHeader bool // pseudo headers are either valid for requests or for responses switch h.Name { @@ -58,17 +70,25 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { hdr.Status = h.Value isResponsePseudoHeader = true default: - return Header{}, fmt.Errorf("unknown pseudo header: %s", h.Name) + return header{}, fmt.Errorf("unknown pseudo header: %s", h.Name) } if isRequest && isResponsePseudoHeader { - return Header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name) + return header{}, fmt.Errorf("invalid request pseudo header: %s", h.Name) } if !isRequest && !isResponsePseudoHeader { - return Header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name) + return header{}, fmt.Errorf("invalid response pseudo header: %s", h.Name) } } else { if !httpguts.ValidHeaderFieldName(h.Name) { - return Header{}, fmt.Errorf("invalid header field name: %q", h.Name) + return header{}, fmt.Errorf("invalid header field name: %q", h.Name) + } + for _, invalidField := range invalidHeaderFields { + if h.Name == invalidField { + return header{}, fmt.Errorf("invalid header field name: %q", h.Name) + } + } + if h.Name == "te" && h.Value != "trailers" { + return header{}, fmt.Errorf("invalid TE header field value: %q", h.Value) } readFirstRegularHeader = true switch h.Name { @@ -79,18 +99,19 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { readContentLength = true contentLengthStr = h.Value } else if contentLengthStr != h.Value { - return Header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value) + return header{}, fmt.Errorf("contradicting content lengths (%s and %s)", contentLengthStr, h.Value) } default: hdr.Headers.Add(h.Name, h.Value) } } } + hdr.ContentLength = -1 if len(contentLengthStr) > 0 { // use ParseUint instead of ParseInt, so that parsing fails on negative values cl, err := strconv.ParseUint(contentLengthStr, 10, 63) if err != nil { - return Header{}, fmt.Errorf("invalid content length: %w", err) + return header{}, fmt.Errorf("invalid content length: %w", err) } hdr.Headers.Set("Content-Length", contentLengthStr) hdr.ContentLength = int64(cl) @@ -98,32 +119,141 @@ func parseHeaders(headers []qpack.HeaderField, isRequest bool) (Header, error) { return hdr, nil } -func hostnameFromRequest(req *http.Request) string { - if req.URL != nil { - return req.URL.Host +func parseTrailers(headers []qpack.HeaderField) (http.Header, error) { + h := make(http.Header, len(headers)) + for _, field := range headers { + if field.IsPseudo() { + return nil, fmt.Errorf("http3: received pseudo header in trailer: %s", field.Name) + } + h.Add(field.Name, field.Value) } - return "" + return h, nil } -func responseFromHeaders(headerFields []qpack.HeaderField) (*http.Response, error) { - hdr, err := parseHeaders(headerFields, false) +func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) { + hdr, err := parseHeaders(headerFields, true) if err != nil { return nil, err } - if hdr.Status == "" { - return nil, errors.New("missing status field") + // concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 + if len(hdr.Headers["Cookie"]) > 0 { + hdr.Headers.Set("Cookie", strings.Join(hdr.Headers["Cookie"], "; ")) + } + + isConnect := hdr.Method == http.MethodConnect + // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 + isExtendedConnected := isConnect && hdr.Protocol != "" + if isExtendedConnected { + if hdr.Scheme == "" || hdr.Path == "" || hdr.Authority == "" { + return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") + } + } else if isConnect { + if hdr.Path != "" || hdr.Authority == "" { // normal CONNECT + return nil, errors.New(":path must be empty and :authority must not be empty") + } + } else if len(hdr.Path) == 0 || len(hdr.Authority) == 0 || len(hdr.Method) == 0 { + return nil, errors.New(":path, :authority and :method must not be empty") } - rsp := &http.Response{ - Proto: "HTTP/3.0", + + if !isExtendedConnected && len(hdr.Protocol) > 0 { + return nil, errors.New(":protocol must be empty") + } + + var u *url.URL + var requestURI string + + protocol := "HTTP/3.0" + + if isConnect { + u = &url.URL{} + if isExtendedConnected { + u, err = url.ParseRequestURI(hdr.Path) + if err != nil { + return nil, err + } + protocol = hdr.Protocol + } else { + u.Path = hdr.Path + } + u.Scheme = hdr.Scheme + u.Host = hdr.Authority + requestURI = hdr.Authority + } else { + u, err = url.ParseRequestURI(hdr.Path) + if err != nil { + return nil, fmt.Errorf("invalid content length: %w", err) + } + requestURI = hdr.Path + } + + return &http.Request{ + Method: hdr.Method, + URL: u, + Proto: protocol, ProtoMajor: 3, + ProtoMinor: 0, Header: hdr.Headers, + Body: nil, ContentLength: hdr.ContentLength, + Host: hdr.Authority, + RequestURI: requestURI, + }, nil +} + +func hostnameFromURL(url *url.URL) string { + if url != nil { + return url.Host } + return "" +} + +// updateResponseFromHeaders sets up http.Response as an HTTP/3 response, +// using the decoded qpack header filed. +// It is only called for the HTTP header (and not the HTTP trailer). +// It takes an http.Response as an argument to allow the caller to set the trailer later on. +func updateResponseFromHeaders(rsp *http.Response, headerFields []qpack.HeaderField) error { + hdr, err := parseHeaders(headerFields, false) + if err != nil { + return err + } + if hdr.Status == "" { + return errors.New("missing status field") + } + rsp.Proto = "HTTP/3.0" + rsp.ProtoMajor = 3 + rsp.Header = hdr.Headers + processTrailers(rsp) + rsp.ContentLength = hdr.ContentLength + status, err := strconv.Atoi(hdr.Status) if err != nil { - return nil, fmt.Errorf("invalid status code: %w", err) + return fmt.Errorf("invalid status code: %w", err) } rsp.StatusCode = status rsp.Status = hdr.Status + " " + http.StatusText(status) - return rsp, nil + return nil +} + +// processTrailers initializes the rsp.Trailer map, and adds keys for every announced header value. +// The Trailer header is removed from the http.Response.Header map. +// It handles both duplicate as well as comma-separated values for the Trailer header. +// For example: +// +// Trailer: Trailer1, Trailer2 +// Trailer: Trailer3 +// +// Will result in a http.Response.Trailer map containing the keys "Trailer1", "Trailer2", "Trailer3". +func processTrailers(rsp *http.Response) { + rawTrailers, ok := rsp.Header["Trailer"] + if !ok { + return + } + + rsp.Trailer = make(http.Header) + for _, rawVal := range rawTrailers { + for _, val := range strings.Split(rawVal, ",") { + rsp.Trailer[http.CanonicalHeaderKey(textproto.TrimString(val))] = nil + } + } + delete(rsp.Header, "Trailer") } diff --git a/internal/http3/http_stream.go b/internal/http3/http_stream.go index bfaf4214..7c969090 100644 --- a/internal/http3/http_stream.go +++ b/internal/http3/http_stream.go @@ -1,54 +1,103 @@ package http3 import ( + "context" "errors" "fmt" + "io" + "net/http" + + "github.com/imroc/req/v3/internal/compress" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/transport" "github.com/quic-go/quic-go" + + "github.com/quic-go/qpack" ) -// A Stream is a HTTP/3 stream. +// A Stream is an HTTP/3 request stream. // When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. -type Stream quic.Stream +type Stream interface { + quic.Stream + + SendDatagram([]byte) error + ReceiveDatagram(context.Context) ([]byte, error) +} + +// A RequestStream is an HTTP/3 request stream. +// When writing to and reading from the stream, data is framed in HTTP/3 DATA frames. +type RequestStream interface { + Stream + + // SendRequestHeader sends the HTTP request. + // It is invalid to call it more than once. + // It is invalid to call it after Write has been called. + SendRequestHeader(req *http.Request) error + + // ReadResponse reads the HTTP response from the stream. + // It is invalid to call it more than once. + // It doesn't set Response.Request and Response.TLS. + // It is invalid to call it after Read has been called. + ReadResponse() (*http.Response, error) +} -// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly -// from the QUIC stream, it writes to and reads from the HTTP stream. type stream struct { quic.Stream + conn *connection - buf []byte + buf []byte // used as a temporary buffer when writing the HTTP/3 frame headers - onFrameError func() bytesRemainingInFrame uint64 + + datagrams *datagrammer + + parseTrailer func(io.Reader, uint64) error + parsedTrailer bool } var _ Stream = &stream{} -func newStream(str quic.Stream, onFrameError func()) *stream { +func newStream(str quic.Stream, conn *connection, datagrams *datagrammer, parseTrailer func(io.Reader, uint64) error) *stream { return &stream{ Stream: str, - onFrameError: onFrameError, - buf: make([]byte, 0, 16), + conn: conn, + buf: make([]byte, 16), + datagrams: datagrams, + parseTrailer: parseTrailer, } } func (s *stream) Read(b []byte) (int, error) { + fp := &frameParser{ + r: s.Stream, + conn: s.conn, + } if s.bytesRemainingInFrame == 0 { parseLoop: for { - frame, err := parseNextFrame(s.Stream, nil) + frame, err := fp.ParseNext() if err != nil { return 0, err } switch f := frame.(type) { - case *headersFrame: - // skip HEADERS frames - continue case *dataFrame: + if s.parsedTrailer { + return 0, errors.New("DATA frame received after trailers") + } s.bytesRemainingInFrame = f.Length break parseLoop + case *headersFrame: + if s.conn.perspective == PerspectiveServer { + continue + } + if s.parsedTrailer { + return 0, errors.New("additional HEADERS frame received after trailers") + } + s.parsedTrailer = true + return 0, s.parseTrailer(s.Stream, f.Length) default: - s.onFrameError() + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") // parseNextFrame skips over unknown frame types // Therefore, this condition is only entered when we parsed another known frame type. return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) @@ -80,44 +129,175 @@ func (s *stream) Write(b []byte) (int, error) { return s.Stream.Write(b) } -var errTooMuchData = errors.New("peer sent too much data") +func (s *stream) writeUnframed(b []byte) (int, error) { + return s.Stream.Write(b) +} -type lengthLimitedStream struct { +func (s *stream) StreamID() quic.StreamID { + return s.Stream.StreamID() +} + +// The stream conforms to the quic.Stream interface, but instead of writing to and reading directly +// from the QUIC stream, it writes to and reads from the HTTP stream. +type requestStream struct { + ctx context.Context *stream - contentLength int64 - read int64 - resetStream bool + *transport.Options + + responseBody io.ReadCloser // set by ReadResponse + + decoder *qpack.Decoder + requestWriter *requestWriter + maxHeaderBytes uint64 + reqDone chan<- struct{} + disableCompression bool + response *http.Response + + sentRequest bool + requestedGzip bool + isConnect bool } -var _ Stream = &lengthLimitedStream{} +var _ RequestStream = &requestStream{} -func newLengthLimitedStream(str *stream, contentLength int64) *lengthLimitedStream { - return &lengthLimitedStream{ - stream: str, - contentLength: contentLength, +func newRequestStream( + ctx context.Context, + options *transport.Options, + str *stream, + requestWriter *requestWriter, + reqDone chan<- struct{}, + decoder *qpack.Decoder, + disableCompression bool, + maxHeaderBytes uint64, + rsp *http.Response, +) *requestStream { + return &requestStream{ + ctx: ctx, + Options: options, + stream: str, + requestWriter: requestWriter, + reqDone: reqDone, + decoder: decoder, + disableCompression: disableCompression, + maxHeaderBytes: maxHeaderBytes, + response: rsp, } } -func (s *lengthLimitedStream) checkContentLengthViolation() error { - if s.read > s.contentLength || s.read == s.contentLength && s.hasMoreData() { - if !s.resetStream { - s.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) - s.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) - s.resetStream = true +func (s *requestStream) Read(b []byte) (int, error) { + if s.responseBody == nil { + return 0, errors.New("http3: invalid use of RequestStream.Read: need to call ReadResponse first") + } + return s.responseBody.Read(b) +} + +func (s *requestStream) SendRequestHeader(req *http.Request) error { + if s.sentRequest { + return errors.New("http3: invalid duplicate use of SendRequestHeader") + } + if !s.DisableCompression && !s.disableCompression && req.Method != http.MethodHead && + req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { + s.requestedGzip = true + } + dumps := dump.GetDumpers(req.Context(), s.Dump) + var headerDumps []*dump.Dumper + for _, dump := range dumps { + if dump.RequestHeader() { + headerDumps = append(headerDumps, dump) } - return errTooMuchData } - return nil + + s.isConnect = req.Method == http.MethodConnect + s.sentRequest = true + return s.requestWriter.WriteRequestHeader(s.Stream, req, s.requestedGzip, headerDumps) } -func (s *lengthLimitedStream) Read(b []byte) (int, error) { - if err := s.checkContentLengthViolation(); err != nil { - return 0, err +func (s *requestStream) ReadResponse() (*http.Response, error) { + fp := &frameParser{ + r: s.Stream, + conn: s.conn, } - n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)]) - s.read += int64(n) - if err := s.checkContentLengthViolation(); err != nil { - return n, err + frame, err := fp.ParseNext() + if err != nil { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) + return nil, fmt.Errorf("http3: parsing frame failed: %w", err) } - return n, err + hf, ok := frame.(*headersFrame) + if !ok { + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "expected first frame to be a HEADERS frame") + return nil, errors.New("http3: expected first frame to be a HEADERS frame") + } + if hf.Length > s.maxHeaderBytes { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) + return nil, fmt.Errorf("http3: HEADERS frame too large: %d bytes (max: %d)", hf.Length, s.maxHeaderBytes) + } + headerBlock := make([]byte, hf.Length) + if _, err := io.ReadFull(s.Stream, headerBlock); err != nil { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) + return nil, fmt.Errorf("http3: failed to read response headers: %w", err) + } + hfs, err := s.decoder.DecodeFull(headerBlock) + if err != nil { + // TODO: use the right error code + s.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeGeneralProtocolError), "") + return nil, fmt.Errorf("http3: failed to decode response headers: %w", err) + } + ds := dump.GetResponseHeaderDumpers(s.ctx, s.Dump) + if ds.ShouldDump() { + for _, h := range hfs { + ds.DumpResponseHeader([]byte(fmt.Sprintf("%s: %s\r\n", h.Name, h.Value))) + } + ds.DumpResponseHeader([]byte("\r\n")) + } + res := s.response + if err := updateResponseFromHeaders(res, hfs); err != nil { + s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) + s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) + return nil, fmt.Errorf("http3: invalid response: %w", err) + } + + // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set). + // See section 4.1.2 of RFC 9114. + respBody := newResponseBody(s.stream, res.ContentLength, s.reqDone) + + // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. + isInformational := res.StatusCode >= 100 && res.StatusCode < 200 + isNoContent := res.StatusCode == http.StatusNoContent + isSuccessfulConnect := s.isConnect && res.StatusCode >= 200 && res.StatusCode < 300 + if (isInformational || isNoContent || isSuccessfulConnect) && res.ContentLength == -1 { + res.ContentLength = 0 + } + if s.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + s.responseBody = compress.NewGzipReader(respBody) + res.Uncompressed = true + } else if s.AutoDecompression { + contentEncoding := res.Header.Get("Content-Encoding") + if contentEncoding != "" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Uncompressed = true + res.Body = compress.NewCompressReader(respBody, contentEncoding) + } + } else { + s.responseBody = respBody + } + res.Body = s.responseBody + return res, nil +} + +func (s *stream) SendDatagram(b []byte) error { + // TODO: reject if datagrams are not negotiated (yet) + return s.datagrams.Send(b) +} + +func (s *stream) ReceiveDatagram(ctx context.Context) ([]byte, error) { + // TODO: reject if datagrams are not negotiated (yet) + return s.datagrams.Receive(ctx) } diff --git a/internal/http3/protocol.go b/internal/http3/protocol.go new file mode 100644 index 00000000..d5ba5bb6 --- /dev/null +++ b/internal/http3/protocol.go @@ -0,0 +1,119 @@ +package http3 + +import ( + "math" + + "github.com/quic-go/quic-go" +) + +// Perspective determines if we're acting as a server or a client +type Perspective int + +// the perspectives +const ( + PerspectiveServer Perspective = 1 + PerspectiveClient Perspective = 2 +) + +// Opposite returns the perspective of the peer +func (p Perspective) Opposite() Perspective { + return 3 - p +} + +func (p Perspective) String() string { + switch p { + case PerspectiveServer: + return "server" + case PerspectiveClient: + return "client" + default: + return "invalid perspective" + } +} + +// The version numbers, making grepping easier +const ( + VersionUnknown quic.Version = math.MaxUint32 + versionDraft29 quic.Version = 0xff00001d // draft-29 used to be a widely deployed version + Version1 quic.Version = 0x1 + Version2 quic.Version = 0x6b3343cf +) + +// SupportedVersions lists the versions that the server supports +// must be in sorted descending order +var SupportedVersions = []quic.Version{Version1, Version2} + +// StreamType encodes if this is a unidirectional or bidirectional stream +type StreamType uint8 + +const ( + // StreamTypeUni is a unidirectional stream + StreamTypeUni StreamType = iota + // StreamTypeBidi is a bidirectional stream + StreamTypeBidi +) + +// A StreamID in QUIC +type StreamID int64 + +// InitiatedBy says if the stream was initiated by the client or by the server +func (s StreamID) InitiatedBy() Perspective { + if s%2 == 0 { + return PerspectiveClient + } + return PerspectiveServer +} + +// Type says if this is a unidirectional or bidirectional stream +func (s StreamID) Type() StreamType { + if s%4 >= 2 { + return StreamTypeUni + } + return StreamTypeBidi +} + +// StreamNum returns how many streams in total are below this +// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9) +func (s StreamID) StreamNum() StreamNum { + return StreamNum(s/4) + 1 +} + +// InvalidPacketNumber is a stream ID that is invalid. +// The first valid stream ID in QUIC is 0. +const InvalidStreamID StreamID = -1 + +// StreamNum is the stream number +type StreamNum int64 + +const ( + // InvalidStreamNum is an invalid stream number. + InvalidStreamNum = -1 + // MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames + // and as the stream count in the transport parameters + MaxStreamCount StreamNum = 1 << 60 +) + +// StreamID calculates the stream ID. +func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID { + if s == 0 { + return InvalidStreamID + } + var first StreamID + switch stype { + case StreamTypeBidi: + switch pers { + case PerspectiveClient: + first = 0 + case PerspectiveServer: + first = 1 + } + case StreamTypeUni: + switch pers { + case PerspectiveClient: + first = 2 + case PerspectiveServer: + first = 3 + } + } + return first + 4*StreamID(s-1) +} diff --git a/internal/http3/request_writer.go b/internal/http3/request_writer.go index 443b34c2..2af1ce3f 100644 --- a/internal/http3/request_writer.go +++ b/internal/http3/request_writer.go @@ -13,7 +13,7 @@ import ( "sync" "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/header" + reqheader "github.com/imroc/req/v3/internal/header" "github.com/quic-go/qpack" "github.com/quic-go/quic-go" @@ -28,17 +28,14 @@ type requestWriter struct { mutex sync.Mutex encoder *qpack.Encoder headerBuf *bytes.Buffer - - debugf func(format string, v ...interface{}) } -func newRequestWriter(debugf func(format string, v ...interface{})) *requestWriter { +func newRequestWriter() *requestWriter { headerBuf := &bytes.Buffer{} encoder := qpack.NewEncoder(headerBuf) return &requestWriter{ encoder: encoder, headerBuf: headerBuf, - debugf: debugf, } } @@ -71,6 +68,10 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool, return err } +func isExtendedConnectRequest(req *http.Request) bool { + return req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" +} + // copied from net/transport.go // Modified to support Extended CONNECT: // Contrary to what the godoc for the http.Request says, @@ -89,7 +90,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } // http.NewRequest sets this field to HTTP/1.1 - isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" + isExtendedConnect := isExtendedConnectRequest(req) var path string if req.Method != http.MethodConnect || isExtendedConnect { @@ -123,11 +124,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra enumerateHeaders := func(f func(name, value string)) { var writeHeader func(name string, value ...string) - var kvs []header.KeyValues + var kvs []reqheader.KeyValues sort := false - if req.Header != nil && len(req.Header[header.PseudoHeaderOderKey]) > 0 { + if req.Header != nil && len(req.Header[reqheader.PseudoHeaderOderKey]) > 0 { writeHeader = func(name string, value ...string) { - kvs = append(kvs, header.KeyValues{ + kvs = append(kvs, reqheader.KeyValues{ Key: name, Values: value, }) @@ -156,7 +157,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } if sort { - header.SortKeyValues(kvs, req.Header[header.PseudoHeaderOderKey]) + reqheader.SortKeyValues(kvs, req.Header[reqheader.PseudoHeaderOderKey]) for _, kv := range kvs { for _, v := range kv.Values { f(kv.Key, v) @@ -164,11 +165,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra } } - if req.Header != nil && len(req.Header[header.HeaderOderKey]) > 0 { + if req.Header != nil && len(req.Header[reqheader.HeaderOderKey]) > 0 { sort = true kvs = nil writeHeader = func(name string, value ...string) { - kvs = append(kvs, header.KeyValues{ + kvs = append(kvs, reqheader.KeyValues{ Key: name, Values: value, }) @@ -188,7 +189,7 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra var didUA bool for k, vv := range req.Header { - if header.IsExcluded(k) { + if reqheader.IsExcluded(k) { continue } else if strings.EqualFold(k, "user-agent") { // Match Go's http1 behavior: at most one @@ -217,11 +218,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra writeHeader("accept-encoding", "gzip") } if !didUA { - writeHeader("user-agent", header.DefaultUserAgent) + writeHeader("user-agent", reqheader.DefaultUserAgent) } if sort { - header.SortKeyValues(kvs, req.Header[header.HeaderOderKey]) + reqheader.SortKeyValues(kvs, req.Header[reqheader.HeaderOderKey]) for _, kv := range kvs { for _, v := range kv.Values { f(kv.Key, v) @@ -269,13 +270,10 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. -func authorityAddr(scheme string, authority string) (addr string) { +func authorityAddr(authority string) (addr string) { host, port, err := net.SplitHostPort(authority) if err != nil { // authority didn't have a port port = "443" - if scheme == "http" { - port = "80" - } host = authority } if a, err := idna.ToASCII(host); err == nil { diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index 89624094..afef7807 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -11,7 +11,6 @@ import ( "strings" "sync" "sync/atomic" - "time" "github.com/imroc/req/v3/internal/transport" @@ -20,71 +19,88 @@ import ( "golang.org/x/net/http/httpguts" ) -type roundTripCloser interface { - RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) - HandshakeComplete() bool - io.Closer +// Settings are HTTP/3 settings that apply to the underlying connection. +type Settings struct { + // Support for HTTP/3 datagrams (RFC 9297) + EnableDatagrams bool + // Extended CONNECT, RFC 9220 + EnableExtendedConnect bool + // Other settings, defined by the application + Other map[uint64]uint64 } -type roundTripCloserWithCount struct { - roundTripCloser +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type RoundTripOpt struct { + // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. + // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. + OnlyCachedConn bool +} + +type singleRoundTripper interface { + OpenRequestStream(context.Context) (RequestStream, error) + RoundTrip(*http.Request) (*http.Response, error) +} + +type roundTripperWithCount struct { + cancel context.CancelFunc + dialing chan struct{} // closed as soon as quic.Dial(Early) returned + dialErr error + conn quic.EarlyConnection + rt singleRoundTripper + useCount atomic.Int64 } +func (r *roundTripperWithCount) Close() error { + r.cancel() + <-r.dialing + if r.conn != nil { + return r.conn.CloseWithError(0, "") + } + return nil +} + // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { *transport.Options mutex sync.Mutex - // QuicConfig is the quic.Config used for dialing new connections. + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // QUICConfig is the quic.Config used for dialing new connections. // If nil, reasonable default values will be used. - QuicConfig *quic.Config + QUICConfig *quic.Config + + // Dial specifies an optional dial function for creating QUIC + // connections for requests. + // If Dial is nil, a UDPConn will be created at the first request + // and will be reused for subsequent connections to other servers. + Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) - // Enable support for HTTP/3 datagrams. - // If set to true, QuicConfig.EnableDatagram will be set. - // - // See https://datatracker.ietf.org/doc/html/rfc9297. + // Enable support for HTTP/3 datagrams (RFC 9297). + // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. EnableDatagrams bool // Additional HTTP/3 settings. - // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). AdditionalSettings map[uint64]uint64 - // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. - // It is called right after parsing the frame type. - // If parsing the frame type fails, the error is passed to the callback. - // In that case, the frame type will not be set. - // Callers can either ignore the frame and return control of the stream back to HTTP/3 - // (by returning hijacked false). - // Alternatively, callers can take over the QUIC stream (by returning hijacked true). - StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) + // MaxResponseHeaderBytes specifies a limit on how many response bytes are + // allowed in the server's response header. + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 - // When set, this callback is called for unknown unidirectional stream of unknown stream type. - // If parsing the stream type fails, the error is passed to the callback. - // In that case, the stream type will not be set. - UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + initOnce sync.Once + initErr error - // Dial specifies an optional dial function for creating QUIC - // connections for requests. - // If Dial is nil, a UDPConn will be created at the first request - // and will be reused for subsequent connections to other servers. - Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + newClient func(quic.EarlyConnection) singleRoundTripper - newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc, opt *transport.Options) (roundTripCloser, error) // so we can mock it in tests - clients map[string]*roundTripCloserWithCount + clients map[string]*roundTripperWithCount transport *quic.Transport } -// RoundTripOpt are options for the Transport.RoundTripOpt method. -type RoundTripOpt struct { - // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. - // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. - OnlyCachedConn bool - // DontCloseRequestStream controls whether the request stream is closed after sending the request. - // If set, context cancellations have no effect after the response headers are received. - DontCloseRequestStream bool -} - var ( _ http.RoundTripper = &RoundTripper{} _ io.Closer = &RoundTripper{} @@ -95,6 +111,11 @@ var ErrNoCachedConn = errors.New("http3: no cached connection was available") // RoundTripOpt is like RoundTrip, but takes options. func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + r.initOnce.Do(func() { r.initErr = r.init() }) + if r.initErr != nil { + return nil, r.initErr + } + if req.URL == nil { closeRequestBody(req) return nil, errors.New("http3: nil Request.URL") @@ -111,21 +132,15 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. closeRequestBody(req) return nil, errors.New("http3: nil Request.Header") } - - if req.URL.Scheme == "https" { - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("http3: invalid http header field name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) - } + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("http3: invalid http header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) } } - } else { - closeRequestBody(req) - return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) } if req.Method != "" && !validMethod(req.Method) { @@ -133,8 +148,8 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. return nil, fmt.Errorf("http3: invalid method %q", req.Method) } - hostname := authorityAddr("https", hostnameFromRequest(req)) - cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) + hostname := authorityAddr(hostnameFromURL(req.URL)) + cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn) if err != ErrNoCachedConn { if debugf := r.Debugf; debugf != nil { debugf("HTTP/3 %s %s", req.Method, req.URL.String()) @@ -143,10 +158,27 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if err != nil { return nil, err } + + select { + case <-cl.dialing: + case <-req.Context().Done(): + return nil, context.Cause(req.Context()) + } + + if cl.dialErr != nil { + r.removeClient(hostname) + return nil, cl.dialErr + } defer cl.useCount.Add(-1) - rsp, err := cl.RoundTripOpt(req, opt) + rsp, err := cl.rt.RoundTrip(req) if err != nil { - r.removeClient(hostname) + // non-nil errors on roundtrip are likely due to a problem with the connection + // so we remove the client from the cache so that subsequent trips reconnect + // context cancelation is excluded as is does not signify a connection error + if !errors.Is(err, context.Canceled) { + r.removeClient(hostname) + } + if isReused { if nerr, ok := err.(net.Error); ok && nerr.Timeout() { return r.RoundTripOpt(req, opt) @@ -161,82 +193,142 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } +func (r *RoundTripper) init() error { + if r.newClient == nil { + r.newClient = func(conn quic.EarlyConnection) singleRoundTripper { + return &SingleDestinationRoundTripper{ + Options: r.Options, + Connection: conn, + EnableDatagrams: r.EnableDatagrams, + AdditionalSettings: r.AdditionalSettings, + MaxResponseHeaderBytes: r.MaxResponseHeaderBytes, + } + } + } + if r.QUICConfig == nil { + r.QUICConfig = defaultQuicConfig.Clone() + r.QUICConfig.EnableDatagrams = r.EnableDatagrams + } + if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams { + return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") + } + if len(r.QUICConfig.Versions) == 0 { + r.QUICConfig = r.QUICConfig.Clone() + r.QUICConfig.Versions = []quic.Version{SupportedVersions[0]} + } + if len(r.QUICConfig.Versions) != 1 { + return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") + } + if r.QUICConfig.MaxIncomingStreams == 0 { + r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams + } + return nil +} + // RoundTripOnlyCachedConn round trip only cached conn. func (r *RoundTripper) RoundTripOnlyCachedConn(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) } // AddConn add a http3 connection, dial new conn if not exists. -func (r *RoundTripper) AddConn(addr string) error { - addr = authorityAddr("https", addr) - c, _, err := r.getClient(addr, false) - if err != nil { - return err - } - client, ok := c.roundTripCloser.(*client) - if !ok { - return errors.New("bad client type") +func (r *RoundTripper) AddConn(ctx context.Context, addr string) error { + addr = authorityAddr(addr) + cl, _, err := r.getClient(ctx, addr, false) + if err == nil { + cl.useCount.Add(-1) } - client.dialOnce.Do(func() { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) - client.handshakeErr = client.dial(ctx) - }) - return client.handshakeErr + return err } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { +func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() if r.clients == nil { - r.clients = make(map[string]*roundTripCloserWithCount) + r.clients = make(map[string]*roundTripperWithCount) } - client, ok := r.clients[hostname] + cl, ok := r.clients[hostname] if !ok { if onlyCached { return nil, false, ErrNoCachedConn } - var err error - newCl := newClient - if r.newClient != nil { - newCl = r.newClient + ctx, cancel := context.WithCancel(ctx) + cl = &roundTripperWithCount{ + dialing: make(chan struct{}), + cancel: cancel, } - dial := r.Dial - if dial == nil { - if r.transport == nil { - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, false, err - } - r.transport = &quic.Transport{Conn: udpConn} + go func() { + defer close(cl.dialing) + defer cancel() + conn, rt, err := r.dial(ctx, hostname) + if err != nil { + cl.dialErr = err + return } - dial = r.makeDialer() + cl.conn = conn + cl.rt = rt + }() + r.clients[hostname] = cl + } + select { + case <-cl.dialing: + if cl.dialErr != nil { + delete(r.clients, hostname) + return nil, false, cl.dialErr + } + select { + case <-cl.conn.HandshakeComplete(): + isReused = true + default: } - c, err := newCl( - hostname, - r.TLSClientConfig, - &roundTripperOpts{ - EnableDatagram: r.EnableDatagrams, - DisableCompression: r.DisableCompression, - MaxHeaderBytes: r.MaxResponseHeaderBytes, - StreamHijacker: r.StreamHijacker, - UniStreamHijacker: r.UniStreamHijacker, - dump: r.Dump, - AdditionalSettings: r.AdditionalSettings, - }, - r.QuicConfig, - dial, - r.Options, - ) + default: + } + cl.useCount.Add(1) + return cl, isReused, nil +} + +func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) { + var tlsConf *tls.Config + if r.TLSClientConfig == nil { + tlsConf = &tls.Config{} + } else { + tlsConf = r.TLSClientConfig.Clone() + } + if tlsConf.ServerName == "" { + sni, _, err := net.SplitHostPort(hostname) if err != nil { - return nil, false, err + // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. + sni = hostname } - client = &roundTripCloserWithCount{roundTripCloser: c} - r.clients[hostname] = client + tlsConf.ServerName = sni } - client.useCount.Add(1) - return client, isReused, nil + // Replace existing ALPNs by H3 + tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])} + + dial := r.Dial + if dial == nil { + if r.transport == nil { + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, nil, err + } + r.transport = &quic.Transport{Conn: udpConn} + } + dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } + } + + conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig) + if err != nil { + return nil, nil, err + } + return conn, r.newClient(conn), nil } func (r *RoundTripper) removeClient(hostname string) { @@ -253,8 +345,8 @@ func (r *RoundTripper) removeClient(hostname string) { func (r *RoundTripper) Close() error { r.mutex.Lock() defer r.mutex.Unlock() - for _, client := range r.clients { - if err := client.Close(); err != nil { + for _, cl := range r.clients { + if err := cl.Close(); err != nil { return err } } @@ -299,23 +391,12 @@ func isNotToken(r rune) bool { return !httpguts.IsTokenRune(r) } -// makeDialer makes a QUIC dialer using r.udpConn. -func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) - } -} - func (r *RoundTripper) CloseIdleConnections() { r.mutex.Lock() defer r.mutex.Unlock() - for hostname, client := range r.clients { - if client.useCount.Load() == 0 { - client.Close() + for hostname, cl := range r.clients { + if cl.useCount.Load() == 0 { + cl.Close() delete(r.clients, hostname) } } diff --git a/internal/http3/server.go b/internal/http3/server.go index 9f94e7b5..4c36f0f0 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -1,16 +1,12 @@ package http3 -import ( - "github.com/quic-go/quic-go" -) +import "github.com/quic-go/quic-go" -const ( - nextProtoH3Draft29 = "h3-29" - nextProtoH3 = "h3" -) +// NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. +const NextProtoH3 = "h3" // StreamType is the stream type of a unidirectional stream. -type StreamType uint64 +type ServerStreamType uint64 const ( streamTypeControlStream = 0 @@ -20,25 +16,11 @@ const ( ) func versionToALPN(v quic.Version) string { + //nolint:exhaustive // These are all the versions we care about. switch v { case Version1, Version2: - return nextProtoH3 - case VersionDraft29: - return nextProtoH3Draft29 + return NextProtoH3 + default: + return "" } - return "" -} - -type requestError struct { - err error - streamErr ErrCode - connErr ErrCode -} - -func newStreamError(code ErrCode, err error) requestError { - return requestError{err: err, streamErr: code} -} - -func newConnError(code ErrCode, err error) requestError { - return requestError{err: err, connErr: code} } diff --git a/internal/http3/state_tracking_stream.go b/internal/http3/state_tracking_stream.go new file mode 100644 index 00000000..9cf17f5e --- /dev/null +++ b/internal/http3/state_tracking_stream.go @@ -0,0 +1,116 @@ +package http3 + +import ( + "context" + "errors" + "os" + "sync" + + "github.com/quic-go/quic-go" +) + +var _ quic.Stream = &stateTrackingStream{} + +// stateTrackingStream is an implementation of quic.Stream that delegates +// to an underlying stream +// it takes care of proxying send and receive errors onto an implementation of +// the errorSetter interface (intended to be occupied by a datagrammer) +// it is also responsible for clearing the stream based on its ID from its +// parent connection, this is done through the streamClearer interface when +// both the send and receive sides are closed +type stateTrackingStream struct { + quic.Stream + + mx sync.Mutex + sendErr error + recvErr error + + clearer streamClearer + setter errorSetter +} + +type streamClearer interface { + clearStream(quic.StreamID) +} + +type errorSetter interface { + SetSendError(error) + SetReceiveError(error) +} + +func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream { + t := &stateTrackingStream{ + Stream: s, + clearer: clearer, + setter: setter, + } + + context.AfterFunc(s.Context(), func() { + t.closeSend(context.Cause(s.Context())) + }) + + return t +} + +func (s *stateTrackingStream) closeSend(e error) { + s.mx.Lock() + defer s.mx.Unlock() + + // clear the stream the first time both the send + // and receive are finished + if s.sendErr == nil { + if s.recvErr != nil { + s.clearer.clearStream(s.StreamID()) + } + + s.setter.SetSendError(e) + s.sendErr = e + } +} + +func (s *stateTrackingStream) closeReceive(e error) { + s.mx.Lock() + defer s.mx.Unlock() + + // clear the stream the first time both the send + // and receive are finished + if s.recvErr == nil { + if s.sendErr != nil { + s.clearer.clearStream(s.StreamID()) + } + + s.setter.SetReceiveError(e) + s.recvErr = e + } +} + +func (s *stateTrackingStream) Close() error { + s.closeSend(errors.New("write on closed stream")) + return s.Stream.Close() +} + +func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) { + s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) + s.Stream.CancelWrite(e) +} + +func (s *stateTrackingStream) Write(b []byte) (int, error) { + n, err := s.Stream.Write(b) + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { + s.closeSend(err) + } + return n, err +} + +func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) { + s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}) + s.Stream.CancelRead(e) +} + +func (s *stateTrackingStream) Read(b []byte) (int, error) { + n, err := s.Stream.Read(b) + if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { + s.closeReceive(err) + } + return n, err +} diff --git a/internal/quic-go/quicvarint/varint.go b/internal/quic-go/quicvarint/varint.go index 60d17e3d..9a22e334 100644 --- a/internal/quic-go/quicvarint/varint.go +++ b/internal/quic-go/quicvarint/varint.go @@ -26,16 +26,16 @@ func Read(r io.ByteReader) (uint64, error) { return 0, err } // the first two bits of the first byte encode the length - len := 1 << ((firstByte & 0xc0) >> 6) + l := 1 << ((firstByte & 0xc0) >> 6) b1 := firstByte & (0xff - 0xc0) - if len == 1 { + if l == 1 { return uint64(b1), nil } b2, err := r.ReadByte() if err != nil { return 0, err } - if len == 2 { + if l == 2 { return uint64(b2) + uint64(b1)<<8, nil } b3, err := r.ReadByte() @@ -46,7 +46,7 @@ func Read(r io.ByteReader) (uint64, error) { if err != nil { return 0, err } - if len == 4 { + if l == 4 { return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil } b5, err := r.ReadByte() @@ -68,6 +68,31 @@ func Read(r io.ByteReader) (uint64, error) { return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil } +// Parse reads a number in the QUIC varint format. +// It returns the number of bytes consumed. +func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) { + if len(b) == 0 { + return 0, 0, io.EOF + } + firstByte := b[0] + // the first two bits of the first byte encode the length + l := 1 << ((firstByte & 0xc0) >> 6) + if len(b) < l { + return 0, 0, io.ErrUnexpectedEOF + } + b0 := firstByte & (0xff - 0xc0) + if l == 1 { + return uint64(b0), 1, nil + } + if l == 2 { + return uint64(b[1]) + uint64(b0)<<8, 2, nil + } + if l == 4 { + return uint64(b[3]) + uint64(b[2])<<8 + uint64(b[1])<<16 + uint64(b0)<<24, 4, nil + } + return uint64(b[7]) + uint64(b[6])<<8 + uint64(b[5])<<16 + uint64(b[4])<<24 + uint64(b[3])<<32 + uint64(b[2])<<40 + uint64(b[1])<<48 + uint64(b0)<<56, 8, nil +} + // Append appends i in the QUIC varint format. func Append(b []byte, i uint64) []byte { if i <= maxVarInt1 { @@ -89,7 +114,7 @@ func Append(b []byte, i uint64) []byte { } // AppendWithLen append i in the QUIC varint format with the desired length. -func AppendWithLen(b []byte, i uint64, length int64) []byte { +func AppendWithLen(b []byte, i uint64, length int) []byte { if length != 1 && length != 2 && length != 4 && length != 8 { panic("invalid varint length") } @@ -107,17 +132,17 @@ func AppendWithLen(b []byte, i uint64, length int64) []byte { } else if length == 8 { b = append(b, 0b11000000) } - for j := int64(1); j < length-l; j++ { + for j := 1; j < length-l; j++ { b = append(b, 0) } - for j := int64(0); j < l; j++ { + for j := 0; j < l; j++ { b = append(b, uint8(i>>(8*(l-1-j)))) } return b } // Len determines the number of bytes that will be needed to write the number i. -func Len(i uint64) int64 { +func Len(i uint64) int { if i <= maxVarInt1 { return 1 } diff --git a/transport.go b/transport.go index 69c8932f..4dd2bceb 100644 --- a/transport.go +++ b/transport.go @@ -661,7 +661,7 @@ func (t *Transport) handlePendingAltSvc(u *url.URL, pas *pendingAltSvc) { case "h3": // only support h3 in alt-svc for now u2 := altsvcutil.ConvertURL(pas.Entries[i], u) hostname := u2.Host - err := t.t3.AddConn(hostname) + err := t.t3.AddConn(context.Background(), hostname) if err != nil { if t.Debugf != nil { t.Debugf("failed to get http3 connection: %s", err.Error())