Skip to content

Commit

Permalink
WIP: http proxy mode for consistent hashing
Browse files Browse the repository at this point in the history
This adds a `pget proxy` command that runs pget as an HTTP server that proxies
connections upstream to cache hosts via the consistent hashing strategy.

For now we ONLY support consistent hashing since that is the motivating use
case.

This is WIP. Still to do:

- support Range requests from the client itself
- dynamically respond to SRV record changes
- testing!
- documentation (e.g. longDesc!)
- DRY up the duplicated code around configuration
  • Loading branch information
philandstuff committed Dec 21, 2023
1 parent 296c401 commit 94c4eb4
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 49 deletions.
1 change: 1 addition & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// GetRootCommand assembles the root pget command with every subcommand
// (multifile, proxy, version) attached.
func GetRootCommand() *cobra.Command {
	cmd := root.GetCommand()
	subcommands := []*cobra.Command{
		multifile.GetCommand(),
		GetProxyCommand(),
		version.VersionCMD,
	}
	for _, sub := range subcommands {
		cmd.AddCommand(sub)
	}
	return cmd
}
99 changes: 99 additions & 0 deletions cmd/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package cmd

import (
"fmt"
"os"

"github.com/dustin/go-humanize"
"github.com/spf13/cobra"
"github.com/spf13/viper"

"github.com/replicate/pget/pkg/cli"
"github.com/replicate/pget/pkg/client"
"github.com/replicate/pget/pkg/config"
"github.com/replicate/pget/pkg/download"
"github.com/replicate/pget/pkg/proxy"
)

// longDesc is the extended help text for `pget proxy`.
const longDesc = `
Run pget as an HTTP proxy server.

Incoming requests are forwarded upstream to the cache hosts discovered via
the cache-nodes SRV record, with each URL routed by consistent hashing so
that repeated requests for the same content hit the same cache node.
`

// GetProxyCommand returns the `pget proxy` subcommand, which runs pget as an
// HTTP server that proxies requests upstream to cache hosts selected by
// consistent hashing.
func GetProxyCommand() *cobra.Command {
	cmd := &cobra.Command{
		// proxy takes no positional arguments (Args: ExactArgs(0)), so the
		// usage line must not advertise <url> <dest>.
		Use:     "proxy [flags]",
		Short:   "run as an http proxy server",
		Long:    longDesc,
		PreRunE: proxyPreRunE,
		RunE:    runProxyCMD,
		Args:    cobra.ExactArgs(0),
		Example: ` pget proxy`,
	}
	cmd.Flags().String(config.OptListenAddress, "127.0.0.1:9512", "address to listen on")
	// Bind the local flag set: listen-address is registered on Flags(), so
	// binding PersistentFlags() (which is empty here) would leave the flag and
	// its default invisible to viper.
	err := viper.BindPFlags(cmd.Flags())
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	cmd.SetUsageTemplate(cli.UsageTemplate)
	return cmd
}

// proxyPreRunE rejects flag combinations that make no sense in proxy mode
// (anything that would try to extract the download to local disk).
func proxyPreRunE(cmd *cobra.Command, args []string) error {
	switch {
	case viper.GetBool(config.OptExtract):
		return fmt.Errorf("cannot use --extract with proxy mode")
	case viper.GetString(config.OptOutputConsumer) == config.ConsumerTarExtractor:
		return fmt.Errorf("cannot use --output-consumer tar-extractor with proxy mode")
	default:
		return nil
	}
}

// runProxyCMD starts pget in proxy mode: it builds client and download options
// from viper config, resolves the cache hosts from the configured SRV record,
// constructs the consistent-hashing download strategy, and serves HTTP on the
// configured listen address. It blocks until the server exits.
func runProxyCMD(cmd *cobra.Command, args []string) error {
	minChunkSize, err := humanize.ParseBytes(viper.GetString(config.OptMinimumChunkSize))
	if err != nil {
		return err
	}
	clientOpts := client.Options{
		MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost),
		ForceHTTP2:     viper.GetBool(config.OptForceHTTP2),
		MaxRetries:     viper.GetInt(config.OptRetries),
		ConnectTimeout: viper.GetDuration(config.OptConnTimeout),
	}
	downloadOpts := download.Options{
		MaxConcurrency: viper.GetInt(config.OptConcurrency),
		MinChunkSize:   int64(minChunkSize),
		Client:         clientOpts,
	}

	// TODO DRY this
	srvName := config.GetCacheSRV()
	if srvName == "" {
		// consistent hashing is the only supported proxy strategy, so the SRV
		// name is mandatory here even though it is optional elsewhere.
		return fmt.Errorf("option %s must be specified in proxy mode", config.OptCacheNodesSRVName)
	}

	downloadOpts.SliceSize = 500 * humanize.MiByte
	// FIXME: make this a config option
	downloadOpts.DomainsToCache = []string{"weights.replicate.delivery"}
	// TODO: dynamically respond to SRV updates rather than just looking up
	// once at startup
	downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName)
	if err != nil {
		return err
	}
	chMode, err := download.GetConsistentHashingMode(downloadOpts)
	if err != nil {
		return err
	}

	// named srv rather than proxy so we don't shadow the imported proxy package
	srv, err := proxy.New(
		chMode,
		&proxy.Options{
			Address: viper.GetString(config.OptListenAddress),
		})
	if err != nil {
		return err
	}
	return srv.Start()
}
1 change: 1 addition & 0 deletions pkg/config/optnames.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const (
OptExtract = "extract"
OptForce = "force"
OptForceHTTP2 = "force-http2"
OptListenAddress = "listen-address"
OptLoggingLevel = "log-level"
OptMaxChunks = "max-chunks"
OptMaxConnPerHost = "max-conn-per-host"
Expand Down
23 changes: 14 additions & 9 deletions pkg/download/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,19 @@ type firstReqResult struct {
func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
logger := logging.GetLogger()

baseReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, 0, err
}

br := newBufferedReader(m.minChunkSize())

firstReqResultCh := make(chan firstReqResult)
m.queue.submit(func() {
m.sem.Go(func() error {
defer close(firstReqResultCh)
defer br.done()
firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
firstChunkResp, err := m.DoRequest(baseReq, 0, m.minChunkSize()-1)
if err != nil {
firstReqResultCh <- firstReqResult{err: err}
return err
Expand Down Expand Up @@ -109,7 +114,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
}

fileSize := firstReqResult.fileSize
trueURL := firstReqResult.trueURL
trueURLReq, err := http.NewRequestWithContext(ctx, http.MethodGet, firstReqResult.trueURL, nil)
if err != nil {
return nil, 0, err
}

if fileSize <= m.minChunkSize() {
// we only need a single chunk: just download it and finish
Expand Down Expand Up @@ -157,7 +165,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e

m.sem.Go(func() error {
defer br.done()
resp, err := m.DoRequest(ctx, start, end, trueURL)
resp, err := m.DoRequest(trueURLReq, start, end)
if err != nil {
return err
}
Expand All @@ -170,18 +178,15 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
return newChanMultiReader(readersCh), fileSize, nil
}

func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to download %s: %w", trueURL, err)
}
func (m *BufferMode) DoRequest(origReq *http.Request, start, end int64) (*http.Response, error) {
req := origReq.Clone(origReq.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := m.Client.Do(req)
if err != nil {
return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err)
}
if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status)
return nil, fmt.Errorf("%w %s", ErrUnexpectedHTTPStatus(resp.StatusCode), req.URL.String())
}

return resp, nil
Expand Down
79 changes: 51 additions & 28 deletions pkg/download/consistent_hashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,51 @@ func (m *ConsistentHashingMode) getFileSizeFromContentRange(contentRange string)
return strconv.ParseInt(groups[1], 10, 64)
}

var _ http.Handler = &ConsistentHashingMode{}

func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) {
logger := logging.GetLogger()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlString, nil)
if err != nil {
return nil, 0, err
}
return m.fetch(req)
}

parsed, err := url.Parse(urlString)
func (m *ConsistentHashingMode) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
reader, size, err := m.fetch(req)
if err != nil {
return nil, -1, err
var httpErr HttpStatusError
if errors.As(err, &httpErr) {
resp.WriteHeader(httpErr.StatusCode)
} else {
resp.WriteHeader(http.StatusInternalServerError)
}
return
}
// TODO: http.StatusPartialContent and Content-Range if it was a range request
resp.Header().Set("Content-Length", fmt.Sprint(size))
resp.WriteHeader(http.StatusOK)
// we ignore errors as it's too late to change status code
_, _ = io.Copy(resp, reader)
}

func (m *ConsistentHashingMode) fetch(req *http.Request) (io.Reader, int64, error) {
logger := logging.GetLogger()

shouldContinue := false
for _, host := range m.DomainsToCache {
if host == parsed.Host {
if host == req.Host {
shouldContinue = true
break
}
}
// Use our fallback mode if we're not downloading from a consistent-hashing enabled domain
if !shouldContinue {
logger.Debug().
Str("url", urlString).
Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", parsed.Host)).
Str("url", req.URL.String()).
Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", req.Host)).
Msg("fallback strategy")
return m.FallbackStrategy.Fetch(ctx, urlString)
return m.FallbackStrategy.Fetch(req.Context(), req.URL.String())
}

br := newBufferedReader(m.minChunkSize())
Expand All @@ -107,7 +131,8 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
m.sem.Go(func() error {
defer close(firstReqResultCh)
defer br.done()
firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString)
// TODO: respect Range header in the original request
firstChunkResp, err := m.DoRequest(req, 0, m.minChunkSize()-1)
if err != nil {
firstReqResultCh <- firstReqResult{err: err}
return err
Expand Down Expand Up @@ -135,11 +160,11 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
if errors.Is(firstReqResult.err, client.ErrStrategyFallback) {
// TODO(morgan): we should indicate the fallback strategy we're using in the logs
logger.Info().
Str("url", urlString).
Str("url", req.URL.String()).
Str("type", "file").
Err(err).
Err(firstReqResult.err).
Msg("consistent hash fallback")
return m.FallbackStrategy.Fetch(ctx, urlString)
return m.FallbackStrategy.Fetch(req.Context(), req.URL.String())
}
return nil, -1, firstReqResult.err
}
Expand Down Expand Up @@ -172,7 +197,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
readersCh := make(chan io.Reader, m.maxConcurrency()+1)
readersCh <- br

logger.Debug().Str("url", urlString).
logger.Debug().Str("url", req.URL.String()).
Int64("size", fileSize).
Int("concurrency", m.maxConcurrency()).
Ints64("chunks_per_slice", chunksPerSlice).
Expand Down Expand Up @@ -214,19 +239,19 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
m.sem.Go(func() error {
defer br.done()
logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request")
resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString)
resp, err := m.DoRequest(req, chunkStart, chunkEnd)
if err != nil {
// in the case that an error indicating an issue with the cache server, networking, etc is returned,
// this will use the fallback strategy. This is a case where the whole file will perform the fall-back
// for the specified chunk instead of the whole file.
if errors.Is(err, client.ErrStrategyFallback) {
// TODO(morgan): we should indicate the fallback strategy we're using in the logs
logger.Info().
Str("url", urlString).
Str("url", req.URL.String()).
Str("type", "chunk").
Err(err).
Msg("consistent hash fallback")
resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
resp, err = m.FallbackStrategy.DoRequest(req, chunkStart, chunkEnd)
}
if err != nil {
return err
Expand All @@ -244,36 +269,30 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
return newChanMultiReader(readersCh), fileSize, nil
}

func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) {
func (m *ConsistentHashingMode) DoRequest(origReq *http.Request, start, end int64) (*http.Response, error) {
logger := logging.GetLogger()
chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true)
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
if err != nil {
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
}
chContext := context.WithValue(origReq.Context(), config.ConsistentHashingStrategyKey, true)
req := origReq.Clone(chContext)
cachePodIndex, err := m.rewriteRequestToCacheHost(req, start, end)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))

logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request")
logger.Debug().Str("url", req.URL.String()).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request")

resp, err := m.Client.Do(req)
if err != nil {
if errors.Is(err, client.ErrStrategyFallback) {
origErr := err
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
if err != nil {
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
}
req = origReq.Clone(chContext)
_, err = m.rewriteRequestToCacheHost(req, start, end, cachePodIndex)
if err != nil {
// return origErr so that we can use our regular fallback strategy
return nil, origErr
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request")
logger.Debug().Str("url", origReq.URL.String()).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request")

resp, err = m.Client.Do(req)
if err != nil {
Expand All @@ -285,7 +304,11 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64,
}
}
if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status)
if resp.StatusCode >= 400 {
return nil, HttpStatusError{StatusCode: resp.StatusCode}
}

return nil, fmt.Errorf("%w %s", ErrUnexpectedHTTPStatus(resp.StatusCode), req.URL.String())
}

return resp, nil
Expand Down
8 changes: 2 additions & 6 deletions pkg/download/consistent_hashing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,10 @@ func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64,
return io.NopCloser(strings.NewReader("00")), -1, nil
}

func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) {
func (s *testStrategy) DoRequest(req *http.Request, start, end int64) (*http.Response, error) {
s.mut.Lock()
s.doRequestCalledCount++
s.mut.Unlock()
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp := &http.Response{
Request: req,
Body: io.NopCloser(strings.NewReader("00")),
Expand Down Expand Up @@ -362,7 +358,7 @@ func TestConsistentHashingFileFallback(t *testing.T) {
responseStatus: http.StatusNotFound,
fetchCalledCount: 0,
doRequestCalledCount: 0,
expectedError: download.ErrUnexpectedHTTPStatus,
expectedError: download.ErrUnexpectedHTTPStatus(http.StatusNotFound),
},
}

Expand Down
19 changes: 19 additions & 0 deletions pkg/download/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package download

import (
"fmt"
)

// HttpStatusError is returned when an upstream request completes with an
// unexpected (non-2xx) HTTP status. It carries the status code so callers —
// e.g. the proxy's ServeHTTP handler — can propagate it to their own clients
// via errors.As.
type HttpStatusError struct {
	StatusCode int
}

// ErrUnexpectedHTTPStatus wraps statusCode in an HttpStatusError value.
func ErrUnexpectedHTTPStatus(statusCode int) error {
	return HttpStatusError{StatusCode: statusCode}
}

// Error uses a value receiver, so the value form satisfies error (and the
// pointer form does too); assert the stronger value-form claim.
var _ error = HttpStatusError{}

func (c HttpStatusError) Error() string {
	// error strings are lowercase by Go convention (they get wrapped into
	// longer messages)
	return fmt.Sprintf("status code %d", c.StatusCode)
}
Loading

0 comments on commit 94c4eb4

Please sign in to comment.