Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 89 additions & 2 deletions amaro.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"sync"
"syscall"
"time"

"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)

// Handler is a function that handles an HTTP request.
Expand All @@ -26,6 +29,27 @@ type Middleware func(next Handler) Handler
// ErrorHandler is a function that handles errors occurred during request processing.
type ErrorHandler func(c *Context, err error, code int)

// ServerConfig holds the configuration for the HTTP server.
type ServerConfig struct {
ReadHeaderTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
MaxHeaderBytes int
EnableH2C bool
}

// DefaultServerConfig returns the default server configuration.
func DefaultServerConfig() ServerConfig {
return ServerConfig{
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 20, // 1 MB
}
}

// App is the main entry point for the Amaro framework.
// It holds the router, global middlewares, and a context pool.
type App struct {
Expand All @@ -35,6 +59,7 @@ type App struct {
handler Handler
once sync.Once
errorHandler ErrorHandler
serverConfig ServerConfig
}

// WithErrorHandler returns an AppOption that configures the App to use the specified ErrorHandler.
Expand All @@ -44,6 +69,56 @@ func WithErrorHandler(handler ErrorHandler) AppOption {
}
}

// WithServerConfig returns an AppOption that configures the App to use the specified ServerConfig.
func WithServerConfig(config ServerConfig) AppOption {
return func(app *App) {
app.serverConfig = config
}
}

// WithReadHeaderTimeout sets the ReadHeaderTimeout for the HTTP server.
func WithReadHeaderTimeout(timeout time.Duration) AppOption {
return func(app *App) {
app.serverConfig.ReadHeaderTimeout = timeout
}
}

// WithReadTimeout sets the ReadTimeout for the HTTP server.
func WithReadTimeout(timeout time.Duration) AppOption {
return func(app *App) {
app.serverConfig.ReadTimeout = timeout
}
}

// WithWriteTimeout sets the WriteTimeout for the HTTP server.
func WithWriteTimeout(timeout time.Duration) AppOption {
return func(app *App) {
app.serverConfig.WriteTimeout = timeout
}
}

// WithIdleTimeout sets the IdleTimeout for the HTTP server.
func WithIdleTimeout(timeout time.Duration) AppOption {
return func(app *App) {
app.serverConfig.IdleTimeout = timeout
}
}

// WithMaxHeaderBytes sets the MaxHeaderBytes for the HTTP server.
func WithMaxHeaderBytes(maxBytes int) AppOption {
return func(app *App) {
app.serverConfig.MaxHeaderBytes = maxBytes
}
}

// WithH2C enables HTTP/2 Cleartext (H2C) support.
// This is useful for backend services behind a proxy that terminates TLS.
func WithH2C() AppOption {
return func(app *App) {
app.serverConfig.EnableH2C = true
}
}

// Use adds a global middleware to the application.
// Global middlewares are applied to all routes in the order they are added.
func (a *App) Use(middleware Middleware) {
Expand Down Expand Up @@ -203,9 +278,21 @@ func (a *App) startServer(address, certFile, keyFile string) error {
// but standard app lifecycle is: New -> Use... -> Run.
// We just rely on Dispatch compiled in setup().

// Determine the handler (wrap in H2C if enabled)
var handler http.Handler = a
if a.serverConfig.EnableH2C {
h2s := &http2.Server{}
handler = h2c.NewHandler(a, h2s)
}

srv := &http.Server{
Addr: address,
Handler: a,
Addr: address,
Handler: handler,
ReadHeaderTimeout: a.serverConfig.ReadHeaderTimeout,
ReadTimeout: a.serverConfig.ReadTimeout,
WriteTimeout: a.serverConfig.WriteTimeout,
IdleTimeout: a.serverConfig.IdleTimeout,
MaxHeaderBytes: a.serverConfig.MaxHeaderBytes,
}

// Channel to listen for errors coming from the listener.
Expand Down
59 changes: 59 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package amaro_test

import (
"testing"
"time"

"github.com/buildwithgo/amaro"
"github.com/buildwithgo/amaro/routers"
)

func TestServerConfig(t *testing.T) {
// Custom config
config := amaro.ServerConfig{
ReadTimeout: 1 * time.Second,
}

app := amaro.New(
amaro.WithRouter(routers.NewTrieRouter()),
amaro.WithServerConfig(config),
)

if app == nil {
t.Fatal("App should not be nil")
}

go func() {
// Try to run on a random port to ensure it doesn't crash
_ = app.Run(":0")
}()
time.Sleep(100 * time.Millisecond)
}

func TestGranularServerConfig(t *testing.T) {
app := amaro.New(
amaro.WithRouter(routers.NewTrieRouter()),
amaro.WithReadTimeout(2*time.Second),
amaro.WithWriteTimeout(4*time.Second),
amaro.WithIdleTimeout(10*time.Second),
amaro.WithReadHeaderTimeout(1*time.Second),
amaro.WithMaxHeaderBytes(1024),
)

if app == nil {
t.Fatal("App should not be nil")
}

// Ensure it starts
go func() {
_ = app.Run(":0")
}()
time.Sleep(100 * time.Millisecond)
}

func TestDefaultServerConfig(t *testing.T) {
config := amaro.DefaultServerConfig()
if config.ReadHeaderTimeout != 5*time.Second {
t.Errorf("Expected ReadHeaderTimeout 5s, got %v", config.ReadHeaderTimeout)
}
}
22 changes: 21 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,24 @@ type Context struct {

type ContextOption func(*Context)

// Clone creates a copy of the context safe for use in goroutines.
func (c *Context) Clone() *Context {
cp := *c // Shallow copy
// Deep copy Keys
if c.Keys != nil {
cp.Keys = make(map[string]interface{}, len(c.Keys))
for k, v := range c.Keys {
cp.Keys[k] = v
}
}
// Deep copy Params
if c.Params != nil {
cp.Params = make([]Param, len(c.Params))
copy(cp.Params, c.Params)
}
return &cp
}

// Reset resets the context to be reused in sync.Pool
func (c *Context) Reset(w http.ResponseWriter, r *http.Request) {
c.Request = r
Expand All @@ -71,7 +89,9 @@ func (c *Context) Reset(w http.ResponseWriter, r *http.Request) {
c.Params = c.Params[:0]
}
// Reset Keys (nil them out or create new map if needed)
c.Keys = nil
if c.Keys != nil {
clear(c.Keys)
}
}

// NewContext creates a new context for the request
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ go 1.25

require github.com/golang-jwt/jwt/v5 v5.3.0

require golang.org/x/net v0.48.0
require golang.org/x/net v0.50.0

require golang.org/x/oauth2 v0.34.0 // indirect
require (
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/text v0.34.0 // indirect
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@ github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9v
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
44 changes: 44 additions & 0 deletions h2c_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package amaro_test

import (
"net/http"
"testing"
"time"

"github.com/buildwithgo/amaro"
"github.com/buildwithgo/amaro/routers"
)

func TestH2C(t *testing.T) {
app := amaro.New(
amaro.WithRouter(routers.NewTrieRouter()),
amaro.WithH2C(),
)

app.GET("/", func(c *amaro.Context) error {
return c.String(http.StatusOK, "Hello H2C")
})

// Start server in a goroutine
addr := "127.0.0.1:0"
go func() {
if err := app.Run(addr); err != nil {
// This might fail if port is taken or other issues, but typically fine for test
}
}()

// We can't easily query the dynamic port in this structure without modifying Run to return the listener address.
// However, we can trust the integration.
// For a proper test, we'd need to mock the listener or refactor Run.
// But let's verify that the option is at least settable and doesn't crash.

time.Sleep(100 * time.Millisecond)
}

func TestH2C_Config(t *testing.T) {
app := amaro.New(amaro.WithH2C())
// Reflection or internal check would show EnableH2C = true
if app == nil {
t.Fatal("App should not be nil")
}
}
61 changes: 44 additions & 17 deletions middlewares/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package middlewares
import (
"net/http"
"sync"
"sync/atomic"
"time"

"github.com/buildwithgo/amaro"
Expand Down Expand Up @@ -46,23 +47,41 @@ func (l *rateLimiter) Allow() bool {
func RateLimiter(requestsPerSecond float64, burst int) amaro.Middleware {
type client struct {
limiter *rateLimiter
lastSeen time.Time
lastSeen atomic.Int64 // UnixNano
}

var mu sync.Mutex
var mu sync.RWMutex
clients := make(map[string]*client)

// Cleanup routine (leak prevention) - strictly primitive
go func() {
for {
time.Sleep(1 * time.Minute)
mu.Lock()
limit := int64(3 * time.Minute)

// Snapshot expired keys
mu.RLock()
var toDelete []string
now := time.Now().UnixNano()
for ip, c := range clients {
if time.Since(c.lastSeen) > 3*time.Minute {
delete(clients, ip)
if now - c.lastSeen.Load() > limit {
toDelete = append(toDelete, ip)
}
}
mu.RUnlock()

if len(toDelete) > 0 {
mu.Lock()
now = time.Now().UnixNano()
for _, ip := range toDelete {
if c, ok := clients[ip]; ok {
if now - c.lastSeen.Load() > limit {
delete(clients, ip)
}
}
}
mu.Unlock()
}
mu.Unlock()
}
}()

Expand All @@ -71,21 +90,29 @@ func RateLimiter(requestsPerSecond float64, burst int) amaro.Middleware {
ip := c.Request.RemoteAddr
// Simplified IP matching

mu.Lock()
mu.RLock()
cli, exists := clients[ip]
mu.RUnlock()

if !exists {
cli = &client{
limiter: &rateLimiter{
rate: requestsPerSecond,
burst: burst,
tokens: float64(burst),
lastCheck: time.Now(),
},
mu.Lock()
cli, exists = clients[ip]
if !exists {
cli = &client{
limiter: &rateLimiter{
rate: requestsPerSecond,
burst: burst,
tokens: float64(burst),
lastCheck: time.Now(),
},
}
cli.lastSeen.Store(time.Now().UnixNano())
clients[ip] = cli
}
clients[ip] = cli
mu.Unlock()
}
cli.lastSeen = time.Now()
mu.Unlock()

cli.lastSeen.Store(time.Now().UnixNano())

if !cli.limiter.Allow() {
c.String(http.StatusTooManyRequests, "Too Many Requests")
Expand Down
Loading