Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified .gitignore
Binary file not shown.
1 change: 1 addition & 0 deletions .local/hello.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello Local
66 changes: 55 additions & 11 deletions addons/cache/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@ package cache

import (
"bytes"
"encoding/gob"
"fmt"
"net/http"
"time"

"github.com/buildwithgo/amaro"
)

// responseRecorder captures the response status and body for caching.
// CachedResponse stores the response data.
type CachedResponse struct {
StatusCode int
Headers http.Header
Body []byte
}

// responseRecorder captures the response status, headers, and body for caching.
type responseRecorder struct {
http.ResponseWriter
statusCode int
Expand All @@ -25,25 +34,48 @@ func (r *responseRecorder) Write(b []byte) (int, error) {
return r.ResponseWriter.Write(b)
}

// CachePage returns a middleware that caches the response body for a given duration.
// KeyGenerator allows customizing the cache key.
type KeyGenerator func(c *amaro.Context) string

func DefaultKeyGenerator(c *amaro.Context) string {
return "route_cache:" + c.Request.URL.String()
}

// CachePage returns a middleware that caches the response for a given duration.
// It uses the Cache interface.
func CachePage(store Cache, ttl time.Duration) amaro.Middleware {
func CachePage(store Cache, ttl time.Duration, keyGen ...KeyGenerator) amaro.Middleware {
getKey := DefaultKeyGenerator
if len(keyGen) > 0 {
getKey = keyGen[0]
}

return func(next amaro.Handler) amaro.Handler {
return func(c *amaro.Context) error {
// Only cache GET requests
if c.Request.Method != http.MethodGet {
return next(c)
}

key := "route_cache:" + c.Request.URL.String()
key := getKey(c)

// Check cache
if val, ok := store.Get(key); ok {
// Hit - We must assert to []byte
if bodyBytes, ok := val.([]byte); ok {
c.Writer.Header().Set("X-Cache", "HIT")
c.Writer.Write(bodyBytes)
return nil
if cachedBytes, ok := val.([]byte); ok {
var cached CachedResponse
// Use Gob for simple serialization of struct with headers
buf := bytes.NewBuffer(cachedBytes)
if err := gob.NewDecoder(buf).Decode(&cached); err == nil {
// Replay headers
for k, v := range cached.Headers {
for _, h := range v {
c.Writer.Header().Add(k, h)
}
}
c.Writer.Header().Set("X-Cache", "HIT")
c.Writer.WriteHeader(cached.StatusCode)
c.Writer.Write(cached.Body)
return nil
}
}
}

Expand All @@ -59,8 +91,20 @@ func CachePage(store Cache, ttl time.Duration) amaro.Middleware {
err := next(c)

// If successful, cache the result
if err == nil && recorder.statusCode == http.StatusOK {
store.Set(key, recorder.body.Bytes(), ttl)
if err == nil && recorder.statusCode < 400 {
// Create cached response
resp := CachedResponse{
StatusCode: recorder.statusCode,
Headers: recorder.Header().Clone(), // Copy headers
Body: recorder.body.Bytes(),
}

var buf bytes.Buffer
if err := gob.NewEncoder(&buf).Encode(resp); err == nil {
store.Set(key, buf.Bytes(), ttl)
} else {
fmt.Println("Cache encode error:", err)
}
}

return err
Expand Down
68 changes: 68 additions & 0 deletions addons/oauth2/oauth2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package oauth2

import (
"context"
"fmt"
"net/http"

"github.com/buildwithgo/amaro"
"golang.org/x/oauth2"
)

// Config holds OAuth2 configuration.
type Config struct {
oauth2.Config

// SuccessHandler is called after successful token exchange.
// It should handle session creation or token response.
SuccessHandler func(c *amaro.Context, token *oauth2.Token) error

// ErrorHandler handles errors during the flow.
ErrorHandler func(c *amaro.Context, err error) error

// StateGenerator generates the state string.
StateGenerator func(c *amaro.Context) string

// StateValidator validates the state string.
StateValidator func(c *amaro.Context, state string) bool
}

// LoginHandler returns a handler that redirects to the OAuth2 provider.
func LoginHandler(config *Config) amaro.Handler {
return func(c *amaro.Context) error {
state := ""
if config.StateGenerator != nil {
state = config.StateGenerator(c)
}
url := config.AuthCodeURL(state)
return c.Redirect(http.StatusTemporaryRedirect, url)
}
}

// CallbackHandler returns a handler that processes the OAuth2 callback.
func CallbackHandler(config *Config) amaro.Handler {
return func(c *amaro.Context) error {
code := c.QueryParam("code")
state := c.QueryParam("state")

if config.StateValidator != nil {
if !config.StateValidator(c, state) {
return config.ErrorHandler(c, fmt.Errorf("invalid state"))
}
}

token, err := config.Exchange(context.Background(), code)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(c, err)
}
return err
}

if config.SuccessHandler != nil {
return config.SuccessHandler(c, token)
}

return c.JSON(http.StatusOK, token)
}
}
14 changes: 14 additions & 0 deletions amaro.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ func (a *App) StaticFS(pathPrefix string, fs fs.FS) {
a.router.StaticFS(pathPrefix, fs)
}

// Static serves files from the local filesystem.
func (a *App) Static(pathPrefix, root string) {
a.StaticFS(pathPrefix, os.DirFS(root))
}

func (a *App) Find(method, path string) (*Route, error) {
return a.router.Find(method, path, nil)
}
Expand All @@ -110,6 +115,15 @@ func New(options ...AppOption) *App {
},
},
errorHandler: func(c *Context, err error, code int) {
if he, ok := err.(*HTTPError); ok {
code = he.Code
if msg, ok := he.Message.(string); ok {
http.Error(c.Writer, msg, code)
} else {
http.Error(c.Writer, http.StatusText(code), code)
}
return
}
http.Error(c.Writer, err.Error(), code)
},
}
Expand Down
37 changes: 37 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package amaro

import (
"fmt"
"net/http"
)

// HTTPError represents an error with an associated HTTP status code.
type HTTPError struct {
Code int
Message interface{}
Internal error
}

func (e *HTTPError) Error() string {
return fmt.Sprintf("code=%d, message=%v", e.Code, e.Message)
}

// NewHTTPError creates a new HTTPError.
func NewHTTPError(code int, message ...interface{}) *HTTPError {
he := &HTTPError{Code: code, Message: http.StatusText(code)}
if len(message) > 0 {
he.Message = message[0]
}
return he
}

// SetInternal sets the internal error.
func (e *HTTPError) SetInternal(err error) *HTTPError {
e.Internal = err
return e
}

// Unwrap returns the internal error.
func (e *HTTPError) Unwrap() error {
return e.Internal
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ go 1.25
require github.com/golang-jwt/jwt/v5 v5.3.0

require golang.org/x/net v0.48.0

require golang.org/x/oauth2 v0.34.0 // indirect
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9v
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
110 changes: 110 additions & 0 deletions middlewares/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package middlewares

import (
"net/http"
"testing"

"github.com/buildwithgo/amaro"
"github.com/buildwithgo/amaro/routers"
)

func TestBasicAuth(t *testing.T) {
app := amaro.New(amaro.WithRouter(routers.NewTrieRouter()))

// Middleware
mw := BasicAuth(func(username, password string, c *amaro.Context) (bool, error) {
if username == "admin" && password == "secret" {
return true, nil
}
return false, nil
})

app.GET("/protected", func(c *amaro.Context) error {
return c.String(http.StatusOK, "Allowed")
}, mw)

// Case 1: No Auth
req, _ := http.NewRequest("GET", "/protected", nil)
w := &mockWriter{}
app.ServeHTTP(w, req)
if w.code != http.StatusUnauthorized {
t.Errorf("Expected 401, got %d", w.code)
}

// Case 2: Invalid Auth
req, _ = http.NewRequest("GET", "/protected", nil)
req.SetBasicAuth("admin", "wrong")
w = &mockWriter{}
app.ServeHTTP(w, req)
if w.code != http.StatusUnauthorized {
t.Errorf("Expected 401, got %d", w.code)
}

// Case 3: Valid Auth
req, _ = http.NewRequest("GET", "/protected", nil)
req.SetBasicAuth("admin", "secret")
w = &mockWriter{}
app.ServeHTTP(w, req)
if w.code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.code)
}
if w.body != "Allowed" {
t.Errorf("Expected 'Allowed', got '%s'", w.body)
}
}

func TestKeyAuth(t *testing.T) {
app := amaro.New(amaro.WithRouter(routers.NewTrieRouter()))

mw := KeyAuth(func(key string, c *amaro.Context) (bool, error) {
return key == "valid-api-key", nil
})

app.GET("/api", func(c *amaro.Context) error {
return c.String(http.StatusOK, "Success")
}, mw)

// Case 1: Missing Key
req, _ := http.NewRequest("GET", "/api", nil)
w := &mockWriter{}
app.ServeHTTP(w, req)
if w.code != http.StatusUnauthorized {
t.Errorf("Expected 401, got %d", w.code)
}

// Case 2: Invalid Key
req, _ = http.NewRequest("GET", "/api", nil)
req.Header.Set("X-API-Key", "bad-key")
w = &mockWriter{}
app.ServeHTTP(w, req)
if w.code != http.StatusUnauthorized {
t.Errorf("Expected 401, got %d", w.code)
}

// Case 3: Valid Key
req, _ = http.NewRequest("GET", "/api", nil)
req.Header.Set("X-API-Key", "valid-api-key")
w = &mockWriter{}
app.ServeHTTP(w, req)
if w.code != http.StatusOK {
t.Errorf("Expected 200, got %d", w.code)
}
}

// Mock Writer
type mockWriter struct {
code int
body string
header http.Header
}
func (m *mockWriter) Header() http.Header {
if m.header == nil { m.header = make(http.Header) }
return m.header
}
func (m *mockWriter) Write(b []byte) (int, error) {
m.body = string(b)
return len(b), nil
}
func (m *mockWriter) WriteHeader(statusCode int) {
m.code = statusCode
}
Loading