Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions bytesconv.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,23 @@ func appendQuotedPath(dst, src []byte) []byte {
}
return dst
}

// countHexDigits returns the number of hex digits required to represent n when using writeHexInt
func countHexDigits(n int) int {
if n < 0 {
// developer sanity-check
panic("BUG: int must be positive")
}

if n == 0 {
return 1
}

count := 0
for n > 0 {
n = n >> 4
count++
}

return count
}
3 changes: 3 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ type Response struct {
raddr net.Addr
// Local TCPAddr from concurrently net.Conn
laddr net.Addr

headersWritten bool
}

// SetHost sets host for the request.
Expand Down Expand Up @@ -1122,6 +1124,7 @@ func (resp *Response) Reset() {
resp.laddr = nil
resp.ImmediateHeaderFlush = false
resp.StreamBody = false
resp.headersWritten = false
}

func (resp *Response) resetSkipHeader() {
Expand Down
70 changes: 70 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,48 @@ type RequestCtx struct {
hijackHandler HijackHandler
hijackNoResponse bool
formValueFunc FormValueFunc

disableBuffering bool // disables buffered response body
getUnbufferedWriter func(*RequestCtx) UnbufferedWriter // defines how to get unbuffered writer
unbufferedWriter UnbufferedWriter // writes directly to underlying connection
bytesSent int // number of bytes sent to client using unbuffered operations
}

// DisableBuffering modifies fasthttp to disable body buffering for this request.
// This is useful for requests that return large data or stream data.
//
// When buffering is disabled you must:
// 1. Set response status and header values before writing body
// 2. Set ContentLength is optional. If not set, the server will use chunked encoding.
// 3. Write body data using methods like ctx.Write or io.Copy(ctx,src), etc.
// 4. Optionally call CloseResponse to finalize the response.
//
// CLosing the response will finalize the response and send the last chunk.
// If the handler does not finish the response, it will be called automatically after handler returns.
// Closing the response will also set BytesSent with the correct number of total bytes sent.
func (ctx *RequestCtx) DisableBuffering() {
ctx.disableBuffering = true

// We need to create a new unbufferedWriter for each unbuffered request.
// This way we can allow different implementations and be compatible with http2 protocol
if ctx.unbufferedWriter == nil {
if ctx.getUnbufferedWriter != nil {
ctx.unbufferedWriter = ctx.getUnbufferedWriter(ctx)
} else {
ctx.unbufferedWriter = NewUnbufferedWriter(ctx)
}
}
}

// CloseResponse finalizes non-buffered response dispatch.
// This method must be called after performing non-buffered responses
// If the handler does not finish the response, it will be called automatically
// after the handler function returns.
func (ctx *RequestCtx) CloseResponse() error {
if !ctx.disableBuffering || ctx.unbufferedWriter == nil {
return ErrNotUnbuffered
}
return ctx.unbufferedWriter.Close()
}

// HijackHandler must process the hijacked connection c.
Expand Down Expand Up @@ -822,6 +864,11 @@ func (ctx *RequestCtx) reset() {

ctx.hijackHandler = nil
ctx.hijackNoResponse = false

ctx.disableBuffering = false
ctx.unbufferedWriter = nil
ctx.getUnbufferedWriter = nil
ctx.bytesSent = 0
}

type firstByteReader struct {
Expand Down Expand Up @@ -1443,10 +1490,28 @@ func (ctx *RequestCtx) NotFound() {

// Write writes p into response body.
func (ctx *RequestCtx) Write(p []byte) (int, error) {
if ctx.disableBuffering {
return ctx.writeDirect(p)
}

ctx.Response.AppendBody(p)
return len(p), nil
}

// writeDirect writes p to underlying connection bypassing any buffering.
func (ctx *RequestCtx) writeDirect(p []byte) (int, error) {
if ctx.unbufferedWriter == nil {
ctx.unbufferedWriter = NewUnbufferedWriter(ctx)
}
return ctx.unbufferedWriter.Write(p)
}

// BytesSent returns the number of bytes sent to the client after non buffered operation.
// Includes headers and body length.
func (ctx *RequestCtx) BytesSent() int {
return ctx.bytesSent
}

// WriteString appends s to response body.
func (ctx *RequestCtx) WriteString(s string) (int, error) {
ctx.Response.AppendBodyString(s)
Expand Down Expand Up @@ -2359,6 +2424,11 @@ func (s *Server) serveConn(c net.Conn) (err error) {
s.Handler(ctx)
}

if ctx.disableBuffering {
_ = ctx.CloseResponse()
break
}

timeoutResponse = ctx.timeoutResponse
if timeoutResponse != nil {
// Acquire a new ctx because the old one will still be in use by the timeout out handler.
Expand Down
70 changes: 70 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4237,6 +4237,76 @@ func TestServerChunkedResponse(t *testing.T) {
}
}

func TestServerDisableBuffering(t *testing.T) {
t.Parallel()

received := make(chan bool)
done := make(chan bool)

expectedBody := bytes.Repeat([]byte("a"), 4096)

s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.DisableBuffering()
ctx.SetStatusCode(StatusOK)
ctx.SetContentType("text/html; charset=utf-8")
reader := bytes.NewReader(expectedBody)
_, err := io.Copy(ctx, reader)
if err != nil {
t.Fatalf("Unexpected error when copying body: %v", err)
}
ctx.CloseResponse()
if len(ctx.Response.Body()) > 0 {
t.Fatalf("Body was populated when buffer was disabled")
}

// wait until body is received by the consumer or stop after 2 seconds timeout
select {
case <-received:
case <-time.After(2 * time.Second):
t.Fatal("Body not received by consumer after 2 seconds")
}

// The consumer received the body, so we can finish the test
done <- true
},
}

ln := fasthttputil.NewInmemoryListener()

go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
}()

conn, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte("GET /index.html HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)

var resp Response
if err := resp.Read(br); err != nil {
t.Fatalf("Unexpected error when reading response: %v", err)
}
if resp.Header.ContentLength() != -1 {
t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), -1)
}
if !bytes.Equal(resp.Body(), expectedBody) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), "foobar")
}

// Signal that the body was received correctly
received <- true

// Wait until the server has finished
<-done
}

func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response {
var resp Response
if err := resp.Read(r); err != nil {
Expand Down
116 changes: 116 additions & 0 deletions unbuffered.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package fasthttp

import (
"bufio"
"errors"
"fmt"
)

type UnbufferedWriter interface {
Write(p []byte) (int, error)
WriteHeaders() (int, error)
Close() error
}

type UnbufferedWriterHttp1 struct {
writer *bufio.Writer
ctx *RequestCtx
bodyChunkStarted bool
bodyLastChunkSent bool
}

var ErrNotUnbuffered = errors.New("not unbuffered")
var ErrClosedUnbufferedWriter = errors.New("closed unbuffered writer")

// Ensure UnbufferedWriterHttp1 implements UnbufferedWriter.
var _ UnbufferedWriter = &UnbufferedWriterHttp1{}

// NewUnbufferedWriter
//
// Object must be discarded when request is finished
func NewUnbufferedWriter(ctx *RequestCtx) *UnbufferedWriterHttp1 {
writer := acquireWriter(ctx)
return &UnbufferedWriterHttp1{ctx: ctx, writer: writer}
}

func (uw *UnbufferedWriterHttp1) Write(p []byte) (int, error) {
if uw.writer == nil || uw.ctx == nil {
return 0, ErrClosedUnbufferedWriter
}

// Write headers if not already sent
if !uw.ctx.Response.headersWritten {
_, err := uw.WriteHeaders()
if err != nil {
return 0, fmt.Errorf("error writing headers: %w", err)
}
}

// Write body. In chunks if content length is not set.
if uw.ctx.Response.Header.contentLength == -1 && uw.ctx.Response.Header.IsHTTP11() {
uw.bodyChunkStarted = true
err := writeChunk(uw.writer, p)
if err != nil {
return 0, err
}
uw.ctx.bytesSent += len(p) + 4 + countHexDigits(len(p))
return len(p), nil
}

n, err := uw.writer.Write(p)
uw.ctx.bytesSent += n

return n, err
}

func (uw *UnbufferedWriterHttp1) WriteHeaders() (int, error) {
if uw.writer == nil || uw.ctx == nil {
return 0, ErrClosedUnbufferedWriter
}

if !uw.ctx.Response.headersWritten {
if uw.ctx.Response.Header.contentLength == 0 && uw.ctx.Response.Header.IsHTTP11() {
if uw.ctx.Response.SkipBody {
uw.ctx.Response.Header.SetContentLength(0)
} else {
uw.ctx.Response.Header.SetContentLength(-1) // means Transfer-Encoding = chunked
}
}
h := uw.ctx.Response.Header.Header()
n, err := uw.writer.Write(h)
if err != nil {
return 0, err
}
uw.ctx.bytesSent += n
uw.ctx.Response.headersWritten = true
}
return 0, nil
}

func (uw *UnbufferedWriterHttp1) Close() error {
if uw.writer == nil || uw.ctx == nil {
return ErrClosedUnbufferedWriter
}

// write headers if not already sent (e.g. if there is no body written)
if !uw.ctx.Response.headersWritten {
// skip body, as we are closing without writing body
uw.ctx.Response.SkipBody = true
_, err := uw.WriteHeaders()
if err != nil {
return fmt.Errorf("error writing headers: %w", err)
}
}

// finalize chunks
if uw.bodyChunkStarted && uw.ctx.Response.Header.IsHTTP11() && !uw.bodyLastChunkSent {
_, _ = uw.writer.Write([]byte("0\r\n\r\n"))
uw.ctx.bytesSent += 5
}
_ = uw.writer.Flush()
uw.bodyLastChunkSent = true
releaseWriter(uw.ctx.s, uw.writer)
uw.writer = nil
uw.ctx = nil
return nil
}