From fb627230caf01834f1554d8c243dcf9223a76f6b Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 3 Jan 2026 12:18:19 +0000 Subject: [PATCH 1/4] feat: Add generic request binding and cool colored logger - Add BindJSON, BindQuery, BindForm to Context with comprehensive reflection support (floats, complex, slices, pointers). - Enhance Logger middleware with ANSI colors and structured output (Method, Path, Status, Latency). - Add tests for request binding and logger middleware. --- binding_test.go | 147 ++++++++++++++++++++++++++++++++++++++++++ context.go | 119 ++++++++++++++++++++++++++++++++++ middlewares/logger.go | 100 ++++++++++++++++++++++++++-- 3 files changed, 361 insertions(+), 5 deletions(-) create mode 100644 binding_test.go diff --git a/binding_test.go b/binding_test.go new file mode 100644 index 0000000..fbc55af --- /dev/null +++ b/binding_test.go @@ -0,0 +1,147 @@ +package amaro + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type TestUser struct { + Name string `json:"name" query:"name" form:"name"` + Age int `json:"age" query:"age" form:"age"` + Admin bool `json:"admin" query:"admin" form:"admin"` + Score float64 `json:"score" query:"score" form:"score"` + Tags []string `json:"tags" query:"tags" form:"tags"` + Ratings []int `json:"ratings" query:"ratings" form:"ratings"` + PtrField *int `json:"ptr_field" query:"ptr_field" form:"ptr_field"` + ComplexVal complex128 `json:"complex" query:"complex" form:"complex"` +} + +func TestBindJSON(t *testing.T) { + ptrVal := 123 + user := TestUser{ + Name: "Alice", + Age: 30, + Admin: true, + Score: 99.5, + Tags: []string{"go", "rust"}, + Ratings: []int{5, 4}, + PtrField: &ptrVal, + ComplexVal: 1 + 2i, + } + // JSON marshaling of complex numbers is not supported by standard library + // So we omit it for JSON test or handle it specially if we wanted. + // We'll skip complex for JSON test as standard json doesn't support it without custom marshaller. + // But let's test the others. + user.ComplexVal = 0 + + body, _ := json.Marshal(user) + + req := httptest.NewRequest("POST", "/", bytes.NewReader(body)) + w := httptest.NewRecorder() + c := NewContext(w, req) + + var boundUser TestUser + if err := c.BindJSON(&boundUser); err != nil { + t.Fatalf("BindJSON failed: %v", err) + } + + if boundUser.Name != user.Name { + t.Errorf("Expected Name %v, got %v", user.Name, boundUser.Name) + } + if boundUser.Score != user.Score { + t.Errorf("Expected Score %v, got %v", user.Score, boundUser.Score) + } + if len(boundUser.Tags) != 2 { + t.Errorf("Expected 2 tags, got %d", len(boundUser.Tags)) + } +} + +func TestBindQuery(t *testing.T) { + // query parameters + // arrays in query usually: ?tags=go&tags=rust + q := url.Values{} + q.Set("name", "Bob") + q.Set("age", "25") + q.Set("admin", "false") + q.Set("score", "88.8") + q.Add("tags", "one") + q.Add("tags", "two") + q.Add("ratings", "10") + q.Add("ratings", "20") + q.Set("ptr_field", "456") + q.Set("complex", "1+2i") + + req := httptest.NewRequest("GET", "/?"+q.Encode(), nil) + w := httptest.NewRecorder() + c := NewContext(w, req) + + var boundUser TestUser + if err := c.BindQuery(&boundUser); err != nil { + t.Fatalf("BindQuery failed: %v", err) + } + + if boundUser.Name != "Bob" { + t.Errorf("Expected Name Bob, got %s", boundUser.Name) + } + if boundUser.Score != 88.8 { + t.Errorf("Expected Score 88.8, got %f", boundUser.Score) + } + if len(boundUser.Tags) != 2 || boundUser.Tags[0] != "one" { + t.Errorf("Expected Tags [one, two], got %v", boundUser.Tags) + } + if len(boundUser.Ratings) != 2 || boundUser.Ratings[0] != 10 { + t.Errorf("Expected Ratings [10, 20], got %v", boundUser.Ratings) + } + if boundUser.PtrField == nil || *boundUser.PtrField != 456 { + t.Errorf("Expected PtrField 456, got %v", boundUser.PtrField) + } + if boundUser.ComplexVal != 1+2i { + t.Errorf("Expected Complex 1+2i, got %v", boundUser.ComplexVal) + } +} + +func TestBindForm(t *testing.T) { + form := url.Values{} + form.Set("name", "Charlie") + form.Set("age", "40") + form.Set("admin", "true") + form.Set("score", "12.34") + form.Add("tags", "alpha") + form.Add("tags", "beta") + form.Add("ratings", "100") + form.Set("ptr_field", "789") + form.Set("complex", "3+4i") + + req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + c := NewContext(w, req) + + var boundUser TestUser + if err := c.BindForm(&boundUser); err != nil { + t.Fatalf("BindForm failed: %v", err) + } + + if boundUser.Name != "Charlie" { + t.Errorf("Expected Name Charlie, got %s", boundUser.Name) + } + if boundUser.Score != 12.34 { + t.Errorf("Expected Score 12.34, got %f", boundUser.Score) + } + if len(boundUser.Tags) != 2 { + t.Errorf("Expected 2 tags, got %v", boundUser.Tags) + } + if len(boundUser.Ratings) != 1 || boundUser.Ratings[0] != 100 { + t.Errorf("Expected Ratings [100], got %v", boundUser.Ratings) + } + if boundUser.PtrField == nil || *boundUser.PtrField != 789 { + t.Errorf("Expected PtrField 789, got %v", boundUser.PtrField) + } + if boundUser.ComplexVal != 3+4i { + t.Errorf("Expected Complex 3+4i, got %v", boundUser.ComplexVal) + } +} diff --git a/context.go b/context.go index 1bbf400..1f94948 100644 --- a/context.go +++ b/context.go @@ -2,11 +2,14 @@ package amaro import ( "encoding/json" + "errors" "io" "mime/multipart" "net/http" "os" "path/filepath" + "reflect" + "strconv" ) // FormFile returns the first file for the provided form key. @@ -181,3 +184,119 @@ func (c *Context) Get(key string) (value interface{}, exists bool) { } return } + +// BindJSON binds the request body to the provided struct. +func (c *Context) BindJSON(v interface{}) error { + if c.Request.Body == nil { + return errors.New("request body is empty") + } + return json.NewDecoder(c.Request.Body).Decode(v) +} + +// BindQuery binds the query parameters to the provided struct. +func (c *Context) BindQuery(v interface{}) error { + return bindData(v, c.Request.URL.Query(), "query") +} + +// BindForm binds the form parameters to the provided struct. +func (c *Context) BindForm(v interface{}) error { + if err := c.Request.ParseForm(); err != nil { + return err + } + return bindData(v, c.Request.Form, "form") +} + +func bindData(ptr interface{}, data map[string][]string, tag string) error { + typ := reflect.TypeOf(ptr).Elem() + val := reflect.ValueOf(ptr).Elem() + + if typ.Kind() != reflect.Struct { + return errors.New("binding element must be a struct") + } + + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + structField := val.Field(i) + + if !structField.CanSet() { + continue + } + + inputFieldName := typeField.Tag.Get(tag) + if inputFieldName == "" { + continue + } + + inputValue, exists := data[inputFieldName] + if !exists || len(inputValue) == 0 { + continue + } + + if err := setField(structField, inputValue); err != nil { + return err + } + } + return nil +} + +func setField(val reflect.Value, inputs []string) error { + if len(inputs) == 0 { + return nil + } + input := inputs[0] + + switch val.Kind() { + case reflect.Ptr: + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + return setField(val.Elem(), inputs) + + case reflect.Slice: + slice := reflect.MakeSlice(val.Type(), len(inputs), len(inputs)) + for i, v := range inputs { + if err := setField(slice.Index(i), []string{v}); err != nil { + return err + } + } + val.Set(slice) + + case reflect.String: + val.SetString(input) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if num, err := strconv.ParseInt(input, 10, 64); err == nil { + val.SetInt(num) + } else { + return err + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if num, err := strconv.ParseUint(input, 10, 64); err == nil { + val.SetUint(num) + } else { + return err + } + + case reflect.Float32, reflect.Float64: + if num, err := strconv.ParseFloat(input, 64); err == nil { + val.SetFloat(num) + } else { + return err + } + + case reflect.Bool: + if b, err := strconv.ParseBool(input); err == nil { + val.SetBool(b) + } else { + return err + } + case reflect.Complex64, reflect.Complex128: + if c, err := strconv.ParseComplex(input, 128); err == nil { + val.SetComplex(c) + } else { + return err + } + } + return nil +} diff --git a/middlewares/logger.go b/middlewares/logger.go index 894d758..314f892 100644 --- a/middlewares/logger.go +++ b/middlewares/logger.go @@ -1,14 +1,28 @@ package middlewares import ( + "fmt" "log" + "net/http" "time" "github.com/buildwithgo/amaro" ) +// ANSI color codes +const ( + green = "\033[97;42m" + white = "\033[90;47m" + yellow = "\033[90;43m" + red = "\033[97;41m" + blue = "\033[97;44m" + magenta = "\033[97;45m" + cyan = "\033[97;46m" + reset = "\033[0m" +) + type LoggerOption func(*loggerConfig) -type LoggerPrintFunc func(logger *log.Logger, duration time.Duration, c *amaro.Context) +type LoggerPrintFunc func(logger *log.Logger, duration time.Duration, c *amaro.Context, statusCode int) type loggerConfig struct { logger *log.Logger @@ -27,15 +41,59 @@ func WithLoggerLogFunc(logFunc LoggerPrintFunc) LoggerOption { } } +// statusColor returns the ANSI color code for a given HTTP status code. +func statusColor(code int) string { + switch { + case code >= http.StatusOK && code < http.StatusMultipleChoices: + return green + case code >= http.StatusMultipleChoices && code < http.StatusBadRequest: + return white + case code >= http.StatusBadRequest && code < http.StatusInternalServerError: + return yellow + default: + return red + } +} + +// methodColor returns the ANSI color code for a given HTTP method. +func methodColor(method string) string { + switch method { + case http.MethodGet: + return blue + case http.MethodPost: + return cyan + case http.MethodPut: + return yellow + case http.MethodDelete: + return red + case http.MethodPatch: + return green + case http.MethodHead: + return magenta + case http.MethodOptions: + return white + default: + return reset + } +} + func Logger(opts ...LoggerOption) amaro.Middleware { cfg := &loggerConfig{ logger: log.Default(), - printFunc: func(logger *log.Logger, duration time.Duration, c *amaro.Context) { - logger.Printf("%s %s - %s \n", - c.Request.Method, + printFunc: func(logger *log.Logger, duration time.Duration, c *amaro.Context, statusCode int) { + statusColor := statusColor(statusCode) + methodColor := methodColor(c.Request.Method) + resetColor := reset + + // Format: [STATUS] METHOD PATH - LATENCY + // Example: [200] GET /users - 12ms + logMsg := fmt.Sprintf("%s %3d %s %s %s %s %s %s", + statusColor, statusCode, resetColor, + methodColor, c.Request.Method, resetColor, c.Request.URL.Path, duration, ) + logger.Println(logMsg) }, } for _, opt := range opts { @@ -44,10 +102,42 @@ func Logger(opts ...LoggerOption) amaro.Middleware { return func(next amaro.Handler) amaro.Handler { return func(c *amaro.Context) error { start := time.Now() + // Default status code to 200, if not written it stays 200 usually. + // But we need to capture the status code. + // amaro.Context doesn't expose the status code written unless we wrap the writer. + // Since amaro is zero-allocation focused, wrapping might be tricky without allocating. + // However, typically a logger wraps the writer. + // Let's check if amaro.Context has a way to get the status. + // It doesn't seem so from my read of context.go. + // So I need to wrap the response writer. + + // Simple ResponseWriter wrapper to capture status code + lrw := &loggingResponseWriter{ResponseWriter: c.Writer, statusCode: http.StatusOK} + c.Writer = lrw + err := next(c) duration := time.Since(start) - cfg.printFunc(cfg.logger, duration, c) + + // If error returned and not handled, status might be different? + // Typically middleware is outer layer. + // If next(c) returns error, the error handler eventually handles it. + // But the logger middleware is executed *around* next(c). + // If an error bubbles up, we might miss the final status code set by the error handler + // UNLESS the error handler is inside the chain or we handle it here. + // But usually Logger is the first middleware. + + cfg.printFunc(cfg.logger, duration, c, lrw.statusCode) return err } } } + +type loggingResponseWriter struct { + http.ResponseWriter + statusCode int +} + +func (lrw *loggingResponseWriter) WriteHeader(code int) { + lrw.statusCode = code + lrw.ResponseWriter.WriteHeader(code) +} From de8665054a6340afabe7005d178a80d108deb094 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 3 Jan 2026 12:29:50 +0000 Subject: [PATCH 2/4] feat: Add route introspection, binding validation and robustness - Add `Routes()` to `Router` interface and implement in `TrieRouter` for introspection. - Add validation tags (`required`, `min`, `max`) to `Bind*` methods. - Ensure `Bind*` methods safely handle non-pointer or nil inputs by returning errors. - Update `loggingResponseWriter` to implement `http.Flusher`. - Add comprehensive tests for validation, introspection, and binding safety. --- binding_test.go | 117 ++++++++++++++++++++++++++----- context.go | 125 +++++++++++++++++++++++++++++++++- middlewares/logger.go | 26 +++---- router.go | 1 + routers/introspection_test.go | 56 +++++++++++++++ routers/trie.go | 42 ++++++++++++ 6 files changed, 330 insertions(+), 37 deletions(-) create mode 100644 routers/introspection_test.go diff --git a/binding_test.go b/binding_test.go index fbc55af..6999698 100644 --- a/binding_test.go +++ b/binding_test.go @@ -10,14 +10,15 @@ import ( ) type TestUser struct { - Name string `json:"name" query:"name" form:"name"` - Age int `json:"age" query:"age" form:"age"` + Name string `json:"name" query:"name" form:"name" validate:"required,min=2"` + Age int `json:"age" query:"age" form:"age" validate:"min=18,max=120"` Admin bool `json:"admin" query:"admin" form:"admin"` Score float64 `json:"score" query:"score" form:"score"` Tags []string `json:"tags" query:"tags" form:"tags"` Ratings []int `json:"ratings" query:"ratings" form:"ratings"` PtrField *int `json:"ptr_field" query:"ptr_field" form:"ptr_field"` - ComplexVal complex128 `json:"complex" query:"complex" form:"complex"` + // Standard JSON does not support complex numbers, so we ignore it for JSON binding tests + ComplexVal complex128 `json:"-" query:"complex" form:"complex"` } func TestBindJSON(t *testing.T) { @@ -30,15 +31,12 @@ func TestBindJSON(t *testing.T) { Tags: []string{"go", "rust"}, Ratings: []int{5, 4}, PtrField: &ptrVal, - ComplexVal: 1 + 2i, } - // JSON marshaling of complex numbers is not supported by standard library - // So we omit it for JSON test or handle it specially if we wanted. - // We'll skip complex for JSON test as standard json doesn't support it without custom marshaller. - // But let's test the others. - user.ComplexVal = 0 - body, _ := json.Marshal(user) + body, err := json.Marshal(user) + if err != nil { + t.Fatalf("Failed to marshal user: %v", err) + } req := httptest.NewRequest("POST", "/", bytes.NewReader(body)) w := httptest.NewRecorder() @@ -52,12 +50,73 @@ func TestBindJSON(t *testing.T) { if boundUser.Name != user.Name { t.Errorf("Expected Name %v, got %v", user.Name, boundUser.Name) } - if boundUser.Score != user.Score { - t.Errorf("Expected Score %v, got %v", user.Score, boundUser.Score) - } - if len(boundUser.Tags) != 2 { - t.Errorf("Expected 2 tags, got %d", len(boundUser.Tags)) - } +} + +func TestBindValidation(t *testing.T) { + t.Run("Valid", func(t *testing.T) { + q := url.Values{} + q.Set("name", "Bob") + q.Set("age", "25") + + req := httptest.NewRequest("GET", "/?"+q.Encode(), nil) + w := httptest.NewRecorder() + c := NewContext(w, req) + + var u TestUser + if err := c.BindQuery(&u); err != nil { + t.Fatalf("Validation failed unexpectedly: %v", err) + } + }) + + t.Run("Missing Required", func(t *testing.T) { + q := url.Values{} + q.Set("age", "25") // Name missing + + req := httptest.NewRequest("GET", "/?"+q.Encode(), nil) + w := httptest.NewRecorder() + c := NewContext(w, req) + + var u TestUser + if err := c.BindQuery(&u); err == nil { + t.Fatal("Expected validation error for missing Name, got nil") + } else if !strings.Contains(err.Error(), "field 'Name' is required") { + t.Errorf("Expected 'required' error, got: %v", err) + } + }) + + t.Run("Min Violation", func(t *testing.T) { + q := url.Values{} + q.Set("name", "B") // Too short + q.Set("age", "25") + + req := httptest.NewRequest("GET", "/?"+q.Encode(), nil) + w := httptest.NewRecorder() + c := NewContext(w, req) + + var u TestUser + if err := c.BindQuery(&u); err == nil { + t.Fatal("Expected validation error for short Name, got nil") + } else if !strings.Contains(err.Error(), "field 'Name' must be at least 2") { + t.Errorf("Expected 'min' error, got: %v", err) + } + }) + + t.Run("Age Range Violation", func(t *testing.T) { + q := url.Values{} + q.Set("name", "Bob") + q.Set("age", "17") // Too young + + req := httptest.NewRequest("GET", "/?"+q.Encode(), nil) + w := httptest.NewRecorder() + c := NewContext(w, req) + + var u TestUser + if err := c.BindQuery(&u); err == nil { + t.Fatal("Expected validation error for young Age, got nil") + } else if !strings.Contains(err.Error(), "field 'Age' must be at least 18") { + t.Errorf("Expected 'min' error for Age, got: %v", err) + } + }) } func TestBindQuery(t *testing.T) { @@ -145,3 +204,29 @@ func TestBindForm(t *testing.T) { t.Errorf("Expected Complex 3+4i, got %v", boundUser.ComplexVal) } } + +func TestBindErrorOnNonPointer(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + c := NewContext(w, req) + + var u TestUser + // Pass by value instead of pointer + err := c.BindQuery(u) + if err == nil { + t.Fatal("Expected error when binding to non-pointer, got nil") + } + if err.Error() != "binding element must be a non-nil pointer" { + t.Errorf("Expected 'non-nil pointer' error, got: %v", err) + } + + // Pass nil pointer + var nilPtr *TestUser + err = c.BindQuery(nilPtr) + if err == nil { + t.Fatal("Expected error when binding to nil pointer, got nil") + } + if err.Error() != "binding element must be a non-nil pointer" { + t.Errorf("Expected 'non-nil pointer' error, got: %v", err) + } +} diff --git a/context.go b/context.go index 1f94948..b8d91c7 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,7 @@ package amaro import ( "encoding/json" "errors" + "fmt" "io" "mime/multipart" "net/http" @@ -10,6 +11,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" ) // FormFile returns the first file for the provided form key. @@ -190,23 +192,50 @@ func (c *Context) BindJSON(v interface{}) error { if c.Request.Body == nil { return errors.New("request body is empty") } - return json.NewDecoder(c.Request.Body).Decode(v) + if err := checkPtr(v); err != nil { + return err + } + if err := json.NewDecoder(c.Request.Body).Decode(v); err != nil { + return err + } + return validateStruct(v) } // BindQuery binds the query parameters to the provided struct. func (c *Context) BindQuery(v interface{}) error { - return bindData(v, c.Request.URL.Query(), "query") + if err := checkPtr(v); err != nil { + return err + } + if err := bindData(v, c.Request.URL.Query(), "query"); err != nil { + return err + } + return validateStruct(v) } // BindForm binds the form parameters to the provided struct. func (c *Context) BindForm(v interface{}) error { + if err := checkPtr(v); err != nil { + return err + } if err := c.Request.ParseForm(); err != nil { return err } - return bindData(v, c.Request.Form, "form") + if err := bindData(v, c.Request.Form, "form"); err != nil { + return err + } + return validateStruct(v) +} + +func checkPtr(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return errors.New("binding element must be a non-nil pointer") + } + return nil } func bindData(ptr interface{}, data map[string][]string, tag string) error { + // Ptr is guaranteed to be a non-nil pointer by checkPtr typ := reflect.TypeOf(ptr).Elem() val := reflect.ValueOf(ptr).Elem() @@ -300,3 +329,93 @@ func setField(val reflect.Value, inputs []string) error { } return nil } + +// validateStruct performs basic validation based on struct tags. +// Supported tags: validate:"required,min=X,max=Y" +func validateStruct(s interface{}) error { + // s is guaranteed to be a non-nil pointer by checkPtr + val := reflect.ValueOf(s).Elem() + typ := val.Type() + + if val.Kind() != reflect.Struct { + return nil // validation only works on structs + } + + var validationErrors []string + + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + tag := typ.Field(i).Tag.Get("validate") + if tag == "" { + continue + } + + rules := strings.Split(tag, ",") + for _, rule := range rules { + if rule == "required" { + if isZero(field) { + validationErrors = append(validationErrors, fmt.Sprintf("field '%s' is required", typ.Field(i).Name)) + } + } else if strings.HasPrefix(rule, "min=") { + minVal, _ := strconv.Atoi(strings.TrimPrefix(rule, "min=")) + if !checkMin(field, minVal) { + validationErrors = append(validationErrors, fmt.Sprintf("field '%s' must be at least %d", typ.Field(i).Name, minVal)) + } + } else if strings.HasPrefix(rule, "max=") { + maxVal, _ := strconv.Atoi(strings.TrimPrefix(rule, "max=")) + if !checkMax(field, maxVal) { + validationErrors = append(validationErrors, fmt.Sprintf("field '%s' must be at most %d", typ.Field(i).Name, maxVal)) + } + } + } + } + + if len(validationErrors) > 0 { + return errors.New(strings.Join(validationErrors, "; ")) + } + return nil +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String, reflect.Array, reflect.Slice, reflect.Map: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Ptr, reflect.Interface: + return v.IsNil() + } + return false +} + +func checkMin(v reflect.Value, min int) bool { + switch v.Kind() { + case reflect.String, reflect.Array, reflect.Slice, reflect.Map: + return v.Len() >= min + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() >= int64(min) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() >= uint64(min) + case reflect.Float32, reflect.Float64: + return v.Float() >= float64(min) + } + return true +} + +func checkMax(v reflect.Value, max int) bool { + switch v.Kind() { + case reflect.String, reflect.Array, reflect.Slice, reflect.Map: + return v.Len() <= max + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() <= int64(max) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() <= uint64(max) + case reflect.Float32, reflect.Float64: + return v.Float() <= float64(max) + } + return true +} diff --git a/middlewares/logger.go b/middlewares/logger.go index 314f892..347c410 100644 --- a/middlewares/logger.go +++ b/middlewares/logger.go @@ -102,30 +102,13 @@ func Logger(opts ...LoggerOption) amaro.Middleware { return func(next amaro.Handler) amaro.Handler { return func(c *amaro.Context) error { start := time.Now() - // Default status code to 200, if not written it stays 200 usually. - // But we need to capture the status code. - // amaro.Context doesn't expose the status code written unless we wrap the writer. - // Since amaro is zero-allocation focused, wrapping might be tricky without allocating. - // However, typically a logger wraps the writer. - // Let's check if amaro.Context has a way to get the status. - // It doesn't seem so from my read of context.go. - // So I need to wrap the response writer. - - // Simple ResponseWriter wrapper to capture status code + // Default status code to 200. lrw := &loggingResponseWriter{ResponseWriter: c.Writer, statusCode: http.StatusOK} c.Writer = lrw err := next(c) duration := time.Since(start) - // If error returned and not handled, status might be different? - // Typically middleware is outer layer. - // If next(c) returns error, the error handler eventually handles it. - // But the logger middleware is executed *around* next(c). - // If an error bubbles up, we might miss the final status code set by the error handler - // UNLESS the error handler is inside the chain or we handle it here. - // But usually Logger is the first middleware. - cfg.printFunc(cfg.logger, duration, c, lrw.statusCode) return err } @@ -141,3 +124,10 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) { lrw.statusCode = code lrw.ResponseWriter.WriteHeader(code) } + +// Flush implements the http.Flusher interface to allow streaming. +func (lrw *loggingResponseWriter) Flush() { + if flusher, ok := lrw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/router.go b/router.go index 7995b37..37423b9 100644 --- a/router.go +++ b/router.go @@ -66,6 +66,7 @@ type Router interface { Group(prefix string) *Group Find(method, path string, ctx *Context) (*Route, error) StaticFS(pathPrefix string, fs fs.FS) + Routes() []Route } // WithRouter returns an AppOption that configures the App to use the specified router. diff --git a/routers/introspection_test.go b/routers/introspection_test.go new file mode 100644 index 0000000..a235da0 --- /dev/null +++ b/routers/introspection_test.go @@ -0,0 +1,56 @@ +package routers_test + +import ( + "sort" + "testing" + + "github.com/buildwithgo/amaro" + "github.com/buildwithgo/amaro/routers" +) + +func TestRoutesIntrospection(t *testing.T) { + r := routers.NewTrieRouter() + + dummyHandler := func(c *amaro.Context) error { return nil } + + r.GET("/users", dummyHandler) + r.POST("/users", dummyHandler) + r.GET("/users/:id", dummyHandler) + r.GET("/assets/*filepath", dummyHandler) + + routes := r.Routes() + + // Expected routes + // GET /users + // POST /users + // GET /users/:id + // GET /assets/*filepath + + expected := []struct { + Method string + Path string + }{ + {"GET", "/assets/*filepath"}, + {"GET", "/users"}, + {"GET", "/users/:id"}, + {"POST", "/users"}, + } + + // Helper to sort actual routes for comparison (though implementation already sorts) + sort.Slice(routes, func(i, j int) bool { + if routes[i].Method != routes[j].Method { + return routes[i].Method < routes[j].Method + } + return routes[i].Path < routes[j].Path + }) + + if len(routes) != len(expected) { + t.Fatalf("Expected %d routes, got %d", len(expected), len(routes)) + } + + for i, exp := range expected { + if routes[i].Method != exp.Method || routes[i].Path != exp.Path { + t.Errorf("Index %d: Expected %s %s, got %s %s", i, exp.Method, exp.Path, routes[i].Method, routes[i].Path) + } + } +} diff --git a/routers/trie.go b/routers/trie.go index 9967edb..dfdfbf2 100644 --- a/routers/trie.go +++ b/routers/trie.go @@ -4,6 +4,7 @@ import ( "fmt" "io/fs" "net/http" + "sort" "strings" "github.com/buildwithgo/amaro" @@ -229,6 +230,47 @@ func (r *TrieRouter) Find(method, path string, ctx *amaro.Context) (*amaro.Route return nil, amaro.NewHTTPError(http.StatusNotFound, "route not found") } +func (r *TrieRouter) Routes() []amaro.Route { + var routes []amaro.Route + + // Sort methods for deterministic output + var methods []string + for m := range r.root { + methods = append(methods, m) + } + sort.Strings(methods) + + for _, method := range methods { + walkNode(r.root[method], &routes) + } + return routes +} + +func walkNode(n *node, routes *[]amaro.Route) { + if n == nil { + return + } + if n.Handler != nil { + *routes = append(*routes, n.Route) + } + + // Walk static children (sorted for determinism) + var staticKeys []string + for k := range n.children { + staticKeys = append(staticKeys, k) + } + sort.Strings(staticKeys) + for _, k := range staticKeys { + walkNode(n.children[k], routes) + } + + // Walk param + walkNode(n.paramNode, routes) + + // Walk wildcard + walkNode(n.catchAllNode, routes) +} + func (r *TrieRouter) StaticFS(pathPrefix string, fsys fs.FS) { handler := amaro.StaticHandler(amaro.StaticConfig{ Root: fsys, From b84eb4181db741a592b462af2737cfa10f972b28 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 3 Jan 2026 13:23:43 +0000 Subject: [PATCH 3/4] feat: Add HTMX addon and Pretty Recovery - Add `addons/htmx` for HTMX helpers (Is, Trigger, Redirect, etc). - Enhance `Recovery` middleware with `WithHTMLDebug(bool)` option. - Add simple but pretty HTML template for panic debug page. - Verified with tests. --- addons/htmx/htmx.go | 52 ++++++++++++++++++++++++ addons/htmx/htmx_test.go | 86 ++++++++++++++++++++++++++++++++++++++++ recovery.go | 79 +++++++++++++++++++++++++++++++++++- recovery_test.go | 75 +++++++++++++++++++++++++++++++++++ 4 files changed, 290 insertions(+), 2 deletions(-) create mode 100644 addons/htmx/htmx.go create mode 100644 addons/htmx/htmx_test.go create mode 100644 recovery_test.go diff --git a/addons/htmx/htmx.go b/addons/htmx/htmx.go new file mode 100644 index 0000000..e20b06e --- /dev/null +++ b/addons/htmx/htmx.go @@ -0,0 +1,52 @@ +package htmx + +import ( + "encoding/json" + + "github.com/buildwithgo/amaro" +) + +// Is returns true if the request is an HTMX request. +func Is(c *amaro.Context) bool { + return c.GetHeader("HX-Request") == "true" +} + +// Trigger sets the HX-Trigger header to trigger a client-side event. +func Trigger(c *amaro.Context, event string) { + c.SetHeader("HX-Trigger", event) +} + +// TriggerJSON sets the HX-Trigger header with a JSON object for passing data to events. +func TriggerJSON(c *amaro.Context, events map[string]any) error { + b, err := json.Marshal(events) + if err != nil { + return err + } + c.SetHeader("HX-Trigger", string(b)) + return nil +} + +// PushURL sets the HX-Push-Url header to push a new URL into the history stack. +func PushURL(c *amaro.Context, url string) { + c.SetHeader("HX-Push-Url", url) +} + +// Redirect sets the HX-Redirect header to force a client-side redirect. +func Redirect(c *amaro.Context, url string) { + c.SetHeader("HX-Redirect", url) +} + +// Refresh sets the HX-Refresh header to force a full page refresh. +func Refresh(c *amaro.Context) { + c.SetHeader("HX-Refresh", "true") +} + +// Retarget sets the HX-Retarget header to update a different element than the one triggering the request. +func Retarget(c *amaro.Context, target string) { + c.SetHeader("HX-Retarget", target) +} + +// Reswap sets the HX-Reswap header to specify how the response should be swapped in. +func Reswap(c *amaro.Context, swap string) { + c.SetHeader("HX-Reswap", swap) +} diff --git a/addons/htmx/htmx_test.go b/addons/htmx/htmx_test.go new file mode 100644 index 0000000..05a5763 --- /dev/null +++ b/addons/htmx/htmx_test.go @@ -0,0 +1,86 @@ +package htmx_test + +import ( + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/buildwithgo/amaro" + "github.com/buildwithgo/amaro/addons/htmx" +) + +func TestHTMX(t *testing.T) { + t.Run("Is", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("HX-Request", "true") + w := httptest.NewRecorder() + c := amaro.NewContext(w, req) + + if !htmx.Is(c) { + t.Error("Expected IsHTMX to return true") + } + }) + + t.Run("Trigger", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + c := amaro.NewContext(w, req) + + htmx.Trigger(c, "myEvent") + if w.Header().Get("HX-Trigger") != "myEvent" { + t.Errorf("Expected HX-Trigger header 'myEvent', got %s", w.Header().Get("HX-Trigger")) + } + }) + + t.Run("TriggerJSON", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + c := amaro.NewContext(w, req) + + events := map[string]any{ + "event1": "data1", + "event2": 123, + } + if err := htmx.TriggerJSON(c, events); err != nil { + t.Fatalf("TriggerJSON failed: %v", err) + } + + header := w.Header().Get("HX-Trigger") + var decoded map[string]any + if err := json.Unmarshal([]byte(header), &decoded); err != nil { + t.Fatalf("Failed to unmarshal HX-Trigger header: %v", err) + } + + if decoded["event1"] != "data1" { + t.Errorf("Expected event1 data1, got %v", decoded["event1"]) + } + }) + + t.Run("Headers", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + c := amaro.NewContext(w, req) + + htmx.PushURL(c, "/new-url") + htmx.Redirect(c, "/redirect") + htmx.Refresh(c) + htmx.Retarget(c, "#target") + htmx.Reswap(c, "outerHTML") + + if w.Header().Get("HX-Push-Url") != "/new-url" { + t.Error("PushURL failed") + } + if w.Header().Get("HX-Redirect") != "/redirect" { + t.Error("Redirect failed") + } + if w.Header().Get("HX-Refresh") != "true" { + t.Error("Refresh failed") + } + if w.Header().Get("HX-Retarget") != "#target" { + t.Error("Retarget failed") + } + if w.Header().Get("HX-Reswap") != "outerHTML" { + t.Error("Reswap failed") + } + }) +} diff --git a/recovery.go b/recovery.go index db99b10..77c7bff 100644 --- a/recovery.go +++ b/recovery.go @@ -2,12 +2,34 @@ package amaro import ( "fmt" + "html/template" "net/http" "runtime" + "strings" ) +// RecoveryOption configures the Recovery middleware. +type RecoveryOption func(*recoveryConfig) + +type recoveryConfig struct { + htmlDebug bool +} + +// WithHTMLDebug enables rendering a pretty HTML debug page for panics. +// WARNING: Do not use this in production as it exposes stack traces. +func WithHTMLDebug(enabled bool) RecoveryOption { + return func(c *recoveryConfig) { + c.htmlDebug = enabled + } +} + // Recovery recovers from panics, logs the stack trace, and returns an Internal Server Error. -func Recovery() Middleware { +func Recovery(opts ...RecoveryOption) Middleware { + cfg := &recoveryConfig{htmlDebug: false} + for _, opt := range opts { + opt(cfg) + } + return func(next Handler) Handler { return func(c *Context) error { defer func() { @@ -18,10 +40,63 @@ func Recovery() Middleware { fmt.Printf("panic: %v\nStack trace:\n%s\n", err, stackTrace) - c.String(http.StatusInternalServerError, "Internal Server Error") + if cfg.htmlDebug { + c.HTML(http.StatusInternalServerError, renderDebugPage(err, stackTrace)) + } else { + c.String(http.StatusInternalServerError, "Internal Server Error") + } } }() return next(c) } } } + +func renderDebugPage(err interface{}, stack string) string { + tmpl := ` + + + + + + Internal Server Error - Amaro + + + +
+

Internal Server Error

+
Panic: {{.Error}}
+

Stack Trace:

+
{{.Stack}}
+ +
+ + +` + data := struct { + Error interface{} + Stack string + }{ + Error: err, + Stack: stack, + } + + t, parseErr := template.New("debug").Parse(tmpl) + if parseErr != nil { + return "Internal Server Error (Failed to render debug page)" + } + + var buf strings.Builder + if execErr := t.Execute(&buf, data); execErr != nil { + return "Internal Server Error (Failed to execute debug template)" + } + + return buf.String() +} diff --git a/recovery_test.go b/recovery_test.go new file mode 100644 index 0000000..4b3162c --- /dev/null +++ b/recovery_test.go @@ -0,0 +1,75 @@ +package amaro_test + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/buildwithgo/amaro" + "github.com/buildwithgo/amaro/routers" +) + +func TestRecoveryMiddleware(t *testing.T) { + t.Run("Default", func(t *testing.T) { + // Let's test the middleware function directly to be sure about config. + router := routers.NewTrieRouter() + mw := amaro.Recovery() + + router.Use(mw) + router.GET("/panic", func(c *amaro.Context) error { + panic("oops") + }) + + // We need to simulate the app/router execution manually since we aren't using amaro.New + // But router.ServeHTTP isn't a thing, we need App or to wrap it. + // amaro.App wraps the router. + + // Let's use a dummy handler wrapped by the middleware + handler := mw(func(c *amaro.Context) error { + panic("oops") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + c := amaro.NewContext(w, req) + + err := handler(c) + if err != nil { + t.Errorf("Expected nil error (recovered), got %v", err) + } + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected 500, got %d", w.Code) + } + if w.Body.String() != "Internal Server Error" { + t.Errorf("Expected 'Internal Server Error', got %s", w.Body.String()) + } + }) + + t.Run("HTMLDebug", func(t *testing.T) { + mw := amaro.Recovery(amaro.WithHTMLDebug(true)) + + handler := mw(func(c *amaro.Context) error { + panic("debug me") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + c := amaro.NewContext(w, req) + + _ = handler(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("Expected 500, got %d", w.Code) + } + + body := w.Body.String() + if !strings.Contains(body, "") { + t.Error("Expected HTML response") + } + if !strings.Contains(body, "debug me") { + t.Error("Expected panic message in body") + } + }) +} From 7c66537ef30437865a1cc34d6e72c9af25b1683e Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 3 Jan 2026 13:35:23 +0000 Subject: [PATCH 4/4] feat: Add Testing Helper and Remove Swagger UI - Add `app.Test(req)` helper method to simplify integration testing for developers. - Removed Swagger UI addon based on user feedback. - Ensured all tests pass for new features (HTMX, Binding, Recovery). --- amaro.go | 9 +++++++++ amaro_test_helpers_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 amaro_test_helpers_test.go diff --git a/amaro.go b/amaro.go index cb3f665..d5f8374 100644 --- a/amaro.go +++ b/amaro.go @@ -7,6 +7,7 @@ import ( "io/fs" "log" "net/http" + "net/http/httptest" "os" "os/signal" "strings" @@ -100,6 +101,14 @@ func (a *App) Find(method, path string) (*Route, error) { return a.router.Find(method, path, nil) } +// Test executes a request against the application and returns the response recorder. +// This is a helper for writing tests. +func (a *App) Test(req *http.Request) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + a.ServeHTTP(w, req) + return w +} + // AppOption defines a function to configure the App during initialization. type AppOption func(*App) diff --git a/amaro_test_helpers_test.go b/amaro_test_helpers_test.go new file mode 100644 index 0000000..48b6c04 --- /dev/null +++ b/amaro_test_helpers_test.go @@ -0,0 +1,28 @@ +package amaro_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/buildwithgo/amaro" + "github.com/buildwithgo/amaro/routers" +) + +func TestAppTestHelper(t *testing.T) { + app := amaro.New(amaro.WithRouter(routers.NewTrieRouter())) + + app.GET("/hello", func(c *amaro.Context) error { + return c.String(http.StatusOK, "world") + }) + + req := httptest.NewRequest("GET", "/hello", nil) + w := app.Test(req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 OK, got %d", w.Code) + } + if w.Body.String() != "world" { + t.Errorf("Expected 'world', got %s", w.Body.String()) + } +}