diff --git a/internal/config/cache_config.go b/internal/config/cache_config.go index 1efbfa72..c555329c 100644 --- a/internal/config/cache_config.go +++ b/internal/config/cache_config.go @@ -1,83 +1,50 @@ package config import ( - "errors" "fmt" - "strings" + "slices" "time" - "gopkg.in/yaml.v3" + multierror "github.com/hashicorp/go-multierror" ) -// ErrInvalidCacheConfig is returned when the cache-config YAML value is not a mapping. -var ErrInvalidCacheConfig = errors.New("expected a mapping for cache-config") - type CacheGlobs []string func (g CacheGlobs) Clone() CacheGlobs { - if g == nil { - return nil - } - - cacheGlobs := make(CacheGlobs, 0, len(g)) - cacheGlobs = append(cacheGlobs, g...) - - return cacheGlobs + return slices.Clone(g) } type CacheConfig struct { - ExpirationTime time.Duration `yaml:"-"` + ExpirationTime time.Duration `yaml:"expiration-time"` MaxSize int64 `yaml:"max-size"` Methods []string `yaml:"methods"` } func (c *CacheConfig) Clone() *CacheConfig { - var methods []string - if c.Methods != nil { - methods = append(methods, c.Methods...) - } - return &CacheConfig{ ExpirationTime: c.ExpirationTime, MaxSize: c.MaxSize, - Methods: methods, + Methods: slices.Clone(c.Methods), } } -// UnmarshalYAML implements custom decoding so that the "expiration-time" field -// can be expressed as a human-readable duration string (e.g. "30m", "1h"). -// Other fields are decoded by the standard yaml.v3 machinery. -// Only fields present in the YAML node are updated; existing values (defaults) -// are preserved for absent keys. -func (c *CacheConfig) UnmarshalYAML(value *yaml.Node) error { - if value.Kind != yaml.MappingNode { - return ErrInvalidCacheConfig - } +func (c *CacheConfig) Validate(field string) error { + var errs *multierror.Error - for i := 0; i+1 < len(value.Content); i += 2 { - keyNode := value.Content[i] - valNode := value.Content[i+1] + errs = multierror.Append(errs, ValidateDuration(joinPath(field, "expiration-time"), c.ExpirationTime, false)) - switch keyNode.Value { - case "expiration-time": - dur, err := time.ParseDuration(strings.ReplaceAll(valNode.Value, " ", "")) - if err != nil { - return fmt.Errorf("invalid expiration-time %q: %w", valNode.Value, err) - } + if c.MaxSize <= 0 { + msg := fmt.Sprintf("%s must be greater than 0", joinPath(field, "max-size")) + errs = multierror.Append(errs, &ValidationError{msg}) + } + + if len(c.Methods) == 0 { + errs = multierror.Append(errs, &ValidationError{"methods must not be empty"}) + } - c.ExpirationTime = dur - case "max-size": - err := valNode.Decode(&c.MaxSize) - if err != nil { - return err - } - case "methods": - err := valNode.Decode(&c.Methods) - if err != nil { - return err - } - } + for i, method := range c.Methods { + errs = multierror.Append(errs, ValidateMethod(joinPath(field, "methods", index(i)), method, false)) } - return nil + return joinErrors(errs) } diff --git a/internal/config/cache_config_test.go b/internal/config/cache_config_test.go index 30df8cda..37559918 100644 --- a/internal/config/cache_config_test.go +++ b/internal/config/cache_config_test.go @@ -8,68 +8,8 @@ import ( "github.com/evg4b/uncors/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" ) -func TestCacheConfigUnmarshalYAML(t *testing.T) { - t.Run("decodes all fields", func(t *testing.T) { - const input = ` -expiration-time: 30m -max-size: 52428800 -methods: - - GET - - POST -` - - var actual config.CacheConfig - - require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) - assert.Equal(t, config.CacheConfig{ - ExpirationTime: 30 * time.Minute, - MaxSize: 52428800, - Methods: []string{http.MethodGet, http.MethodPost}, - }, actual) - }) - - t.Run("parses expiration-time with embedded spaces", func(t *testing.T) { - const input = `expiration-time: "1h 30m"` - - var actual config.CacheConfig - - require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) - assert.Equal(t, 90*time.Minute, actual.ExpirationTime) - }) - - t.Run("absent fields keep zero values", func(t *testing.T) { - const input = `max-size: 1024` - - var actual config.CacheConfig - - require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) - assert.Equal(t, int64(1024), actual.MaxSize) - assert.Zero(t, actual.ExpirationTime) - assert.Nil(t, actual.Methods) - }) - - t.Run("returns ErrInvalidCacheConfig for non-mapping node", func(t *testing.T) { - const input = `- item1` - - var actual config.CacheConfig - - err := yaml.Unmarshal([]byte(input), &actual) - - assert.ErrorIs(t, err, config.ErrInvalidCacheConfig) - }) - - t.Run("returns error for invalid expiration-time", func(t *testing.T) { - const input = `expiration-time: not-a-duration` - - var actual config.CacheConfig - - assert.Error(t, yaml.Unmarshal([]byte(input), &actual)) - }) -} - func TestCacheGlobsClone(t *testing.T) { globs := config.CacheGlobs{ "/api/**", @@ -110,3 +50,68 @@ func TestCacheConfigClone(t *testing.T) { assert.NotSame(t, &cacheConfig.Methods, &clonedCacheConfig.Methods) }) } + +func TestCacheConfigValidator(t *testing.T) { + const field = "test" + + t.Run("should not register errors for", func(t *testing.T) { + err := (&config.CacheConfig{ + ExpirationTime: 5 * time.Minute, + MaxSize: 100 * 1024 * 1024, + Methods: []string{http.MethodGet, http.MethodPost}, + }).Validate(field) + assert.NoError(t, err) + }) + + t.Run("should register errors for", func(t *testing.T) { + tests := []struct { + name string + value config.CacheConfig + error string + }{ + { + name: "empty expiration time", + value: config.CacheConfig{MaxSize: 100 * 1024 * 1024, Methods: []string{http.MethodGet}}, + error: "test.expiration-time must be greater than 0", + }, + { + name: "zero max size", + value: config.CacheConfig{ExpirationTime: 5 * time.Minute, MaxSize: 0, Methods: []string{http.MethodGet}}, + error: "test.max-size must be greater than 0", + }, + { + name: "negative max size", + value: config.CacheConfig{ExpirationTime: 5 * time.Minute, MaxSize: -1, Methods: []string{http.MethodGet}}, + error: "test.max-size must be greater than 0", + }, + { + name: "empty methods", + value: config.CacheConfig{ExpirationTime: 5 * time.Minute, MaxSize: 100 * 1024 * 1024}, + error: "methods must not be empty", + }, + { + name: "invalid method", + value: config.CacheConfig{ + ExpirationTime: 5 * time.Minute, + MaxSize: 100 * 1024 * 1024, + Methods: []string{"invalid"}, + }, + error: "test.methods[0] must be one of GET, HEAD, POST, PUT, PATCH, DELETE, CONNECT, OPTIONS, TRACE", + }, + { + name: "invalid second method", + value: config.CacheConfig{ + ExpirationTime: 5 * time.Minute, + MaxSize: 100 * 1024 * 1024, + Methods: []string{http.MethodGet, "invalid", http.MethodPost}, + }, + error: "test.methods[1] must be one of GET, HEAD, POST, PUT, PATCH, DELETE, CONNECT, OPTIONS, TRACE", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.EqualError(t, test.value.Validate(field), test.error) + }) + } + }) +} diff --git a/internal/config/config.go b/internal/config/config.go index 1608fcde..7b37110a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,12 +3,12 @@ package config import ( "fmt" + multierror "github.com/hashicorp/go-multierror" "github.com/spf13/afero" "github.com/spf13/pflag" "gopkg.in/yaml.v3" ) -// UncorsConfig is the root configuration for the uncors proxy. type UncorsConfig struct { Mappings Mappings `yaml:"mappings"` Proxy string `yaml:"proxy"` @@ -17,9 +17,6 @@ type UncorsConfig struct { Interactive bool `yaml:"-"` } -// LoadConfiguration parses CLI arguments and optionally reads a YAML config file. -// CLI flags take precedence over config file values. -// Returns the loaded config, the active config file path (empty if none), and any error. func LoadConfiguration(fs afero.Fs, args []string) (*UncorsConfig, string, error) { flags := defineFlags() @@ -32,9 +29,9 @@ func LoadConfiguration(fs afero.Fs, args []string) (*UncorsConfig, string, error configPath, _ := flags.GetString("config") if configPath != "" { - readErr := readYAMLFile(fs, cfg, configPath) - if readErr != nil { - return nil, "", readErr + err := readYAMLFile(fs, cfg, configPath) + if err != nil { + return nil, "", err } } @@ -45,11 +42,14 @@ func LoadConfiguration(fs afero.Fs, args []string) (*UncorsConfig, string, error cfg.Mappings = NormaliseMappings(cfg.Mappings) + err = cfg.Validate(fs) + if err != nil { + return nil, "", err + } + return cfg, configPath, nil } -// readYAMLFile opens a YAML config file and decodes it directly into cfg, -// preserving any existing default values for keys absent in the file. func readYAMLFile(fs afero.Fs, cfg *UncorsConfig, path string) error { file, err := fs.Open(path) if err != nil { @@ -66,8 +66,6 @@ func readYAMLFile(fs afero.Fs, cfg *UncorsConfig, path string) error { return nil } -// applyFlagOverrides applies CLI flag values to cfg, overriding any config file values. -// Only flags explicitly set on the command line are applied. func applyFlagOverrides(cfg *UncorsConfig, flags *pflag.FlagSet) error { if flags.Changed("proxy") { cfg.Proxy, _ = flags.GetString("proxy") @@ -86,3 +84,20 @@ func applyFlagOverrides(cfg *UncorsConfig, flags *pflag.FlagSet) error { return mergeURLMappings(cfg, from, to) } + +func (cfg *UncorsConfig) Validate(fs afero.Fs) error { + if len(cfg.Mappings) == 0 { + return &ValidationError{"mappings must not be empty"} + } + + var errs *multierror.Error + + for i, mapping := range cfg.Mappings { + errs = multierror.Append(errs, mapping.Validate(joinPath("mappings", index(i)), fs)) + } + + errs = multierror.Append(errs, ValidateProxy("proxy", cfg.Proxy)) + errs = multierror.Append(errs, cfg.CacheConfig.Validate("cache-config")) + + return joinErrors(errs) +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b8233886..e30eae0e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -49,7 +49,6 @@ mappings: headers: Accept-Encoding: deflate raw: demo - file: /demo.txt proxy: localhost:8080 debug: true https-port: 8081 @@ -99,19 +98,6 @@ func TestLoadConfiguration(t *testing.T) { args []string expected *config.UncorsConfig }{ - { - name: "return default config", - args: []string{}, - expected: &config.UncorsConfig{ - Mappings: config.Mappings{}, - CacheConfig: config.CacheConfig{ - ExpirationTime: config.DefaultExpirationTime, - MaxSize: config.DefaultMaxSize, - Methods: []string{http.MethodGet}, - }, - Interactive: true, - }, - }, { name: "minimal config is set", args: []string{params.Config, minimalConfigPath}, @@ -153,8 +139,7 @@ func TestLoadConfiguration(t *testing.T) { Headers: map[string]string{ acceptEncoding: "deflate", }, - Raw: "demo", - File: "/demo.txt", + Raw: "demo", }, }, }, @@ -235,7 +220,6 @@ func TestLoadConfiguration(t *testing.T) { Code: 201, Headers: map[string]string{acceptEncoding: "deflate"}, Raw: "demo", - File: "/demo.txt", }, }, }, @@ -285,7 +269,8 @@ func TestLoadConfiguration(t *testing.T) { t.Run("returns config file path", func(t *testing.T) { t.Run("empty when no config file flag", func(t *testing.T) { - _, configPath, err := config.LoadConfiguration(afero.NewMemMapFs(), []string{}) + args := []string{params.From, hosts.Localhost1.HTTP(), params.To, hosts.Github.Host()} + _, configPath, err := config.LoadConfiguration(afero.NewMemMapFs(), args) require.NoError(t, err) assert.Empty(t, configPath) }) @@ -303,6 +288,11 @@ func TestLoadConfiguration(t *testing.T) { args []string expectedErr string }{ + { + name: "no args produces validation error", + args: []string{}, + expectedErr: "mappings must not be empty", + }, { name: "incorrect flag provided", args: []string{"--incorrect-flag"}, @@ -357,3 +347,77 @@ func TestLoadConfiguration(t *testing.T) { } }) } + +func TestUncorsConfigValidator(t *testing.T) { + mapFs := testutils.FsFromMap(t, map[string]string{}) + + t.Run("should not register errors for", func(t *testing.T) { + tests := []struct { + name string + value *config.UncorsConfig + }{ + { + name: "minimal config", + value: &config.UncorsConfig{ + Mappings: []config.Mapping{ + {From: hosts.Localhost.Port(8080), To: hosts.Localhost.HTTPSPort(8443)}, + }, + CacheConfig: config.CacheConfig{ + MaxSize: 100 * 1024 * 1024, + ExpirationTime: 10 * time.Minute, + Methods: []string{http.MethodGet}, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + errors := test.value.Validate(mapFs) + + require.NoError(t, errors) + }) + } + }) + + t.Run("should register errors for invalid config", func(t *testing.T) { + tests := []struct { + name string + value *config.UncorsConfig + error string + }{ + { + name: "invalid mapping", + value: &config.UncorsConfig{ + Mappings: []config.Mapping{}, + CacheConfig: config.CacheConfig{ + MaxSize: 100 * 1024 * 1024, + ExpirationTime: 10 * time.Minute, + Methods: []string{http.MethodGet}, + }, + }, + error: "mappings must not be empty", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + errors := test.value.Validate(mapFs) + + require.EqualError(t, errors, test.error) + }) + } + }) +} + +func TestProxyValidatorIsValid(t *testing.T) { + t.Run("valid url", func(t *testing.T) { + assert.NoError(t, config.ValidateProxy("testField", "http://valid-url.com")) + }) + + t.Run("invalid url", func(t *testing.T) { + require.EqualError(t, config.ValidateProxy("testField", "invalid:::url"), "testField is not a valid URL") + }) + + t.Run("empty url", func(t *testing.T) { + assert.NoError(t, config.ValidateProxy("testField", "")) + }) +} diff --git a/internal/config/default.go b/internal/config/default.go index 9ae177b0..3970fd00 100644 --- a/internal/config/default.go +++ b/internal/config/default.go @@ -12,7 +12,6 @@ const ( DefaultMaxSize = 100 * 1024 * 1024 // 100 MB ) -// defaultConfig returns a new UncorsConfig with all default values applied. func defaultConfig() *UncorsConfig { return &UncorsConfig{ Mappings: Mappings{}, diff --git a/internal/config/har.go b/internal/config/har.go index 1d3e1bd3..9117ff07 100644 --- a/internal/config/har.go +++ b/internal/config/har.go @@ -1,10 +1,12 @@ package config -import "gopkg.in/yaml.v3" +import ( + "fmt" + "path/filepath" + + "gopkg.in/yaml.v3" +) -// HARConfig defines settings for the HAR (HTTP Archive) collector middleware. -// When File is non-empty, all requests/responses passing through the proxy -// for this mapping will be recorded to the specified HAR file. type HARConfig struct { File string `yaml:"file"` CaptureSecureHeaders bool `yaml:"capture-secure-headers"` @@ -21,11 +23,6 @@ func (h *HARConfig) Clone() HARConfig { } } -// UnmarshalYAML allows HARConfig to be specified as a plain string (file path) -// or as a full mapping. -// -// Short form: har: ./recordings/api.har -// Full form: har: { file: ./recordings/api.har, capture-secure-headers: true }. func (h *HARConfig) UnmarshalYAML(value *yaml.Node) error { if value.Kind == yaml.ScalarNode { h.File = value.Value @@ -37,3 +34,15 @@ func (h *HARConfig) UnmarshalYAML(value *yaml.Node) error { return value.Decode((*harConfigAlias)(h)) } + +func (h *HARConfig) Validate(field string) error { + if !h.Enabled() { + return nil + } + + if filepath.Ext(h.File) == "" { + return &ValidationError{fmt.Sprintf("%s: HAR file path %q must have a file extension (e.g. .har)", field, h.File)} + } + + return nil +} diff --git a/internal/config/har_test.go b/internal/config/har_test.go index 9ad961e9..29763d13 100644 --- a/internal/config/har_test.go +++ b/internal/config/har_test.go @@ -59,3 +59,37 @@ har: ./recordings/api.har assert.Equal(t, "./recordings/api.har", actual.HAR.File) assert.False(t, actual.HAR.CaptureSecureHeaders) } + +func TestHARValidator(t *testing.T) { + t.Run("valid cases", func(t *testing.T) { + cases := []struct { + name string + value config.HARConfig + }{ + { + name: "disabled (empty file)", + value: config.HARConfig{}, + }, + { + name: "valid file path with extension", + value: config.HARConfig{File: "output.har"}, + }, + { + name: "path with directory and extension", + value: config.HARConfig{File: "/tmp/trace.har"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert.NoError(t, tc.value.Validate("mappings[0].har")) + }) + } + }) + + t.Run("invalid cases", func(t *testing.T) { + t.Run("file path without extension", func(t *testing.T) { + assert.Error(t, (&config.HARConfig{File: "outputfile"}).Validate("mappings[0].har")) + }) + }) +} diff --git a/internal/config/helpers.go b/internal/config/helpers.go index 33c459a1..370470ee 100644 --- a/internal/config/helpers.go +++ b/internal/config/helpers.go @@ -14,9 +14,6 @@ var ( ErrNoFromPair = errors.New("`from` values are not set for every `to`") ) -// mergeURLMappings merges from/to CLI pairs into cfg.Mappings. -// If a from URL already exists in the mappings, its to value is updated. -// Otherwise a new mapping entry is appended. func mergeURLMappings(cfg *UncorsConfig, from, to []string) error { if len(from) > len(to) { return ErrNoToPair @@ -54,8 +51,6 @@ const ( httpsScheme = "https" ) -// NormaliseMappings normalises the From URL in each mapping: adds the default -// scheme (http) if absent and removes the port when it equals the scheme default. func NormaliseMappings(mappings Mappings) Mappings { processedMappings := make(Mappings, 0, len(mappings)) diff --git a/internal/config/mapping.go b/internal/config/mapping.go index 2085e644..220edbe2 100644 --- a/internal/config/mapping.go +++ b/internal/config/mapping.go @@ -2,14 +2,16 @@ package config import ( "errors" + "fmt" "net/url" + infratls "github.com/evg4b/uncors/internal/infra/tls" "github.com/evg4b/uncors/internal/urlparser" + multierror "github.com/hashicorp/go-multierror" + "github.com/spf13/afero" "gopkg.in/yaml.v3" ) -// ErrMappingShorthandValue is returned when a URL shorthand mapping has a -// non-string value (e.g. "http://localhost: 123" instead of a URL string). var ErrMappingShorthandValue = errors.New("mapping shorthand value must be a string URL") type Mapping struct { @@ -23,28 +25,17 @@ type Mapping struct { OptionsHandling OptionsHandling `yaml:"options-handling"` HAR HARConfig `yaml:"har"` - // Cached parsed URL and its components (not serialized) fromURL *url.URL `yaml:"-"` fromHost string `yaml:"-"` fromPort string `yaml:"-"` } -// knownMappingFields is the set of yaml keys that belong to a full Mapping -// object. Any single-key YAML map whose key is NOT in this set is interpreted -// as the shorthand "from: to" form. var knownMappingFields = map[string]bool{ "from": true, "to": true, "statics": true, "mocks": true, "scripts": true, "cache": true, "rewrites": true, "options-handling": true, "har": true, } -// UnmarshalYAML decodes a Mapping from YAML. It recognises two forms: -// -// Shorthand — a single-key mapping whose key is not a known field name: -// -// http://localhost:8080: https://example.com -// -// Full form — a standard YAML mapping with "from", "to", and optional fields. func (m *Mapping) UnmarshalYAML(value *yaml.Node) error { if value.Kind == yaml.MappingNode && len(value.Content) == 2 { key := value.Content[0].Value @@ -82,7 +73,6 @@ func (m *Mapping) Clone() Mapping { } } -// GetFromURL returns the parsed URL, caching it on first access. func (m *Mapping) GetFromURL() (*url.URL, error) { if m.fromURL == nil { parsedURL, err := urlparser.Parse(m.From) @@ -96,7 +86,6 @@ func (m *Mapping) GetFromURL() (*url.URL, error) { return m.fromURL, nil } -// GetFromHostPort returns the host and port from the From URL, caching them on first access. func (m *Mapping) GetFromHostPort() (string, string, error) { if m.fromHost == "" && m.fromPort == "" { uri, err := m.GetFromURL() @@ -113,9 +102,70 @@ func (m *Mapping) GetFromHostPort() (string, string, error) { return m.fromHost, m.fromPort, nil } -// ClearCache clears the cached URL and its components. This is primarily used for testing. func (m *Mapping) ClearCache() { m.fromURL = nil m.fromHost = "" m.fromPort = "" } + +func ValidateProxy(field, value string) error { + if value == "" { + return nil + } + + _, err := urlparser.Parse(value) + if err != nil { + return &ValidationError{fmt.Sprintf("%s is not a valid URL", field)} + } + + return nil +} + +func ValidateTLS(_ string, mapping Mapping, fs afero.Fs) error { + fromURL, err := mapping.GetFromURL() + if err != nil { + return nil //nolint:nilerr + } + + if fromURL.Scheme != httpsScheme { + return nil + } + + if !infratls.CAExists(fs) { + return &TLSError{fromURL.Host} + } + + return nil +} + +func (m *Mapping) Validate(field string, fs afero.Fs) error { + var errs *multierror.Error + + errs = multierror.Append(errs, ValidateHost(joinPath(field, "from"), m.From)) + errs = multierror.Append(errs, ValidateHost(joinPath(field, "to"), m.To)) + errs = multierror.Append(errs, m.OptionsHandling.Validate(joinPath(field, "options-handling"))) + errs = multierror.Append(errs, m.HAR.Validate(joinPath(field, "har"))) + errs = multierror.Append(errs, ValidateTLS(field, *m, fs)) + + for i, static := range m.Statics { + errs = multierror.Append(errs, static.Validate(joinPath(field, "statics", index(i)), fs)) + } + + for i, mock := range m.Mocks { + errs = multierror.Append(errs, mock.Validate(joinPath(field, "mocks", index(i)), fs)) + } + + for i, glob := range m.Cache { + errs = multierror.Append(errs, ValidateGlobPattern(joinPath(field, "cache", index(i)), glob)) + } + + for i, rewrite := range m.Rewrites { + errs = multierror.Append(errs, rewrite.Validate(joinPath(field, "rewrite", index(i)))) + } + + for i, script := range m.Scripts { + errs = multierror.Append(errs, script.Validate(joinPath(field, "scripts", index(i)), fs)) + } + + return joinErrors(errs) +} diff --git a/internal/config/mapping_test.go b/internal/config/mapping_test.go index 55da6094..8225e043 100644 --- a/internal/config/mapping_test.go +++ b/internal/config/mapping_test.go @@ -1,10 +1,17 @@ package config_test import ( + "fmt" + "net/http" + "os" + "path/filepath" "testing" "github.com/evg4b/uncors/internal/config" + infratls "github.com/evg4b/uncors/internal/infra/tls" "github.com/evg4b/uncors/testing/hosts" + "github.com/evg4b/uncors/testing/testutils" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" @@ -113,3 +120,242 @@ func TestURLMappingClone(t *testing.T) { }) } } + +func TestMappingValidator(t *testing.T) { + const ( + field = "mapping" + demoJSONPath = "/tmp/demo.json" + ) + + t.Run("should not register errors for", func(t *testing.T) { + fs := testutils.FsFromMap(t, map[string]string{ + demoJSONPath: "{}", + }) + + tests := []struct { + name string + value config.Mapping + }{ + { + name: "full filled mapping", + value: config.Mapping{ + From: "localhost", + To: hosts.Github.Host(), + Statics: []config.StaticDirectory{ + {Path: "/", Dir: "/tmp"}, + {Path: "/", Dir: "/tmp"}, + }, + Mocks: []config.Mock{ + { + Matcher: config.RequestMatcher{ + Path: "/api/info", + Method: http.MethodGet, + }, + Response: config.Response{ + Code: 200, + Raw: "test", + }, + }, + { + Matcher: config.RequestMatcher{ + Path: "/api/info/demo", + Method: http.MethodGet, + }, + Response: config.Response{ + Code: 300, + File: demoJSONPath, + }, + }, + }, + Cache: config.CacheGlobs{ + "/api/constants", + "/**", + }, + }, + }, + { + name: "mapping without mocks and statics and caches", + value: config.Mapping{ + From: "localhost", + To: hosts.Github.Host(), + Statics: []config.StaticDirectory{}, + Mocks: []config.Mock{}, + Cache: config.CacheGlobs{}, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.NoError(t, test.value.Validate(field, fs)) + }) + } + }) + + t.Run("should register errors for", func(t *testing.T) { + fs := testutils.FsFromMap(t, map[string]string{ + demoJSONPath: "{}", + }) + + tests := []struct { + name string + value config.Mapping + error string + }{ + { + name: "mapping without from", + value: config.Mapping{ + From: "", + To: hosts.Github.Host(), + Statics: []config.StaticDirectory{}, + Mocks: []config.Mock{}, + Cache: config.CacheGlobs{}, + }, + error: "mapping.from must not be empty", + }, + { + name: "mapping without to", + value: config.Mapping{ + From: "localhost", + To: "", + Statics: []config.StaticDirectory{}, + Mocks: []config.Mock{}, + Cache: config.CacheGlobs{}, + }, + error: "mapping.to must not be empty", + }, + { + name: "mapping with invalid statics", + value: config.Mapping{ + From: "localhost", + To: hosts.Github.Host(), + Statics: []config.StaticDirectory{ + {Path: "/", Dir: "/tmp"}, + {Path: "/", Dir: ""}, + }, + Mocks: []config.Mock{}, + Cache: config.CacheGlobs{}, + }, + error: "mapping.statics[1].directory must not be empty", + }, + { + name: "mapping with invalid mocks", + value: config.Mapping{ + From: "localhost", + To: hosts.Github.Host(), + Statics: []config.StaticDirectory{}, + Mocks: []config.Mock{ + { + Matcher: config.RequestMatcher{ + Path: "/api/user", + Method: "invalid", + }, + Response: config.Response{ + Code: 200, + Raw: "test", + }, + }, + }, + Cache: config.CacheGlobs{}, + }, + error: "mapping.mocks[0].method must be one of GET, HEAD, POST, PUT, PATCH, DELETE, CONNECT, OPTIONS, TRACE", + }, + { + name: "mapping with invalid cache glob", + value: config.Mapping{ + From: "localhost", + To: hosts.Github.Host(), + Statics: []config.StaticDirectory{}, + Mocks: []config.Mock{}, + Cache: config.CacheGlobs{ + "/api/info[", + }, + }, + error: "mapping.cache[0] is not a valid glob pattern", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.EqualError(t, test.value.Validate(field, fs), test.error) + }) + } + }) +} + +func TestValidateTLS(t *testing.T) { + t.Run("skip validation for invalid URL", func(t *testing.T) { + err := config.ValidateTLS( + "test", + config.Mapping{From: "://invalid-url", To: hosts.Example.HTTP()}, + afero.NewMemMapFs(), + ) + assert.NoError(t, err) + }) + + t.Run("skip validation for non-HTTPS", func(t *testing.T) { + err := config.ValidateTLS( + "test", + config.Mapping{From: "http://localhost:8080", To: hosts.Example.HTTP()}, + afero.NewMemMapFs(), + ) + assert.NoError(t, err) + }) + + t.Run("error when CA does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + fakeHome := filepath.Join(tmpDir, "home") + require.NoError(t, os.MkdirAll(fakeHome, 0o755)) + t.Setenv("HOME", fakeHome) + + err := config.ValidateTLS("test", + config.Mapping{From: "https://localhost:8443", To: hosts.Example.HTTP()}, + afero.NewOsFs()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "HTTPS mapping 'localhost:8443' requires a local CA certificate") + assert.Contains(t, err.Error(), "uncors generate-certs") + }) + + t.Run("pass when CA exists", func(t *testing.T) { + tmpDir := t.TempDir() + fakeHome := filepath.Join(tmpDir, "home") + require.NoError(t, os.MkdirAll(fakeHome, 0o755)) + t.Setenv("HOME", fakeHome) + + fs := afero.NewOsFs() + caDir := filepath.Join(fakeHome, ".config", "uncors") + _, _, err := infratls.GenerateCA(infratls.CAConfig{ValidityDays: 365, OutputDir: caDir, Fs: fs}) + require.NoError(t, err) + + err = config.ValidateTLS("test", + config.Mapping{From: "https://localhost:8443", To: hosts.Example.HTTP()}, + fs) + + assert.NoError(t, err) + }) +} + +func TestValidateGlobPatternForCache(t *testing.T) { + const field = "cache" + + t.Run("should not register errors for", func(t *testing.T) { + patterns := []string{"/api/**", "/constants", "/translations", "/**/*.js", "/**", "/[12]/demo", "**", "*"} + for _, pattern := range patterns { + p := pattern + t.Run(fmt.Sprintf("%s pattern", p), func(t *testing.T) { + assert.NoError(t, config.ValidateGlobPattern(field, p)) + }) + } + }) + + t.Run("should register errors for", func(t *testing.T) { + tests := []struct{ pattern, error string }{ + {pattern: "/[12/demo", error: "cache is not a valid glob pattern"}, + {pattern: "/{{12}/demo", error: "cache is not a valid glob pattern"}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%s test", test.pattern), func(t *testing.T) { + require.EqualError(t, config.ValidateGlobPattern(field, test.pattern), test.error) + }) + } + }) +} diff --git a/internal/config/mappings.go b/internal/config/mappings.go index a42e39ca..2b617b1b 100644 --- a/internal/config/mappings.go +++ b/internal/config/mappings.go @@ -72,14 +72,14 @@ func (m Mappings) GroupByPort() PortGroups { panic(fmt.Errorf("failed to parse mapping from URL: %w", err)) } - port := 80 // default HTTP port + port := 80 if portStr != "" { port, err = strconv.Atoi(portStr) if err != nil { panic(fmt.Errorf("invalid port number: %w", err)) } } else if uri.Scheme == "https" { - port = 443 // default HTTPS port + port = 443 } key := portKey{port: port, scheme: uri.Scheme} @@ -95,7 +95,6 @@ func (m Mappings) GroupByPort() PortGroups { }) } - // Sort by port for consistent ordering sort.Slice(result, func(i, j int) bool { if result[i].Port != result[j].Port { return result[i].Port < result[j].Port diff --git a/internal/config/mock.go b/internal/config/mock.go index 4689c485..aa6abdb4 100644 --- a/internal/config/mock.go +++ b/internal/config/mock.go @@ -3,7 +3,9 @@ package config import ( "fmt" + multierror "github.com/hashicorp/go-multierror" "github.com/samber/lo" + "github.com/spf13/afero" ) type Mock struct { @@ -38,3 +40,12 @@ func (m Mocks) Clone() Mocks { return item.Clone() }) } + +func (m *Mock) Validate(field string, fs afero.Fs) error { + var errs *multierror.Error + + errs = multierror.Append(errs, m.Matcher.Validate(field)) + errs = multierror.Append(errs, m.Response.Validate(joinPath(field, "response"), fs)) + + return joinErrors(errs) +} diff --git a/internal/config/mock_test.go b/internal/config/mock_test.go index 3b49aa5d..799e5438 100644 --- a/internal/config/mock_test.go +++ b/internal/config/mock_test.go @@ -3,9 +3,11 @@ package config_test import ( "net/http" "testing" + "time" "github.com/evg4b/uncors/internal/config" "github.com/go-http-utils/headers" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" ) @@ -63,3 +65,21 @@ func TestMockClone(t *testing.T) { assert.Equal(t, mock.Response, clonedMock.Response) }) } + +func TestMockValidator(t *testing.T) { + t.Run("should return true", func(t *testing.T) { + err := (&config.Mock{ + Matcher: config.RequestMatcher{ + Path: "/api/info", + Method: "", + }, + Response: config.Response{ + Code: 200, + Raw: "test", + Delay: 1 * time.Second, + }, + }).Validate("mock", afero.NewMemMapFs()) + + assert.NoError(t, err) + }) +} diff --git a/internal/config/options_handling.go b/internal/config/options_handling.go index b2a3ae71..46caccd9 100644 --- a/internal/config/options_handling.go +++ b/internal/config/options_handling.go @@ -15,3 +15,11 @@ func (o *OptionsHandling) Clone() OptionsHandling { Code: o.Code, } } + +func (o *OptionsHandling) Validate(field string) error { + if o.Code != 0 { + return ValidateStatus(joinPath(field, "code"), o.Code) + } + + return nil +} diff --git a/internal/config/options_handling_test.go b/internal/config/options_handling_test.go index 1e778d0a..3bd913e7 100644 --- a/internal/config/options_handling_test.go +++ b/internal/config/options_handling_test.go @@ -6,6 +6,7 @@ import ( "github.com/evg4b/uncors/internal/config" "github.com/go-http-utils/headers" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestOptionsClone(t *testing.T) { @@ -73,3 +74,25 @@ func TestOptionsClone(t *testing.T) { }) } } + +func TestOptionsValidator(t *testing.T) { + t.Run("should return true", func(t *testing.T) { + t.Run("for default options", func(t *testing.T) { + assert.NoError(t, (&config.OptionsHandling{}).Validate("options")) + }) + + t.Run("for correct status code", func(t *testing.T) { + assert.NoError(t, (&config.OptionsHandling{ + Headers: map[string]string{headers.ContentType: "application/json"}, + Code: 200, + }).Validate("options")) + }) + }) + + t.Run("should return false for invalid status code", func(t *testing.T) { + require.EqualError(t, (&config.OptionsHandling{ + Headers: map[string]string{headers.ContentType: "application/json"}, + Code: -10, + }).Validate("options"), "options.code code must be in range 100-599") + }) +} diff --git a/internal/config/request_matcher.go b/internal/config/request_matcher.go index cd349362..ee665600 100644 --- a/internal/config/request_matcher.go +++ b/internal/config/request_matcher.go @@ -1,6 +1,10 @@ package config -import "github.com/evg4b/uncors/internal/helpers" +import ( + multierror "github.com/hashicorp/go-multierror" + + "github.com/evg4b/uncors/internal/helpers" +) type RequestMatcher struct { Path string `yaml:"path"` @@ -21,3 +25,12 @@ func (r *RequestMatcher) Clone() RequestMatcher { func (r *RequestMatcher) IsPathOnly() bool { return r.Method == "" && len(r.Queries) == 0 && len(r.Headers) == 0 } + +func (r *RequestMatcher) Validate(field string) error { + var errs *multierror.Error + + errs = multierror.Append(errs, ValidatePath(joinPath(field, "path"), r.Path, false)) + errs = multierror.Append(errs, ValidateMethod(joinPath(field, "method"), r.Method, true)) + + return joinErrors(errs) +} diff --git a/internal/config/request_matcher_test.go b/internal/config/request_matcher_test.go index e07a22e0..84446867 100644 --- a/internal/config/request_matcher_test.go +++ b/internal/config/request_matcher_test.go @@ -4,7 +4,9 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" + "github.com/go-http-utils/headers" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRequestMatcherIsPathOnly(t *testing.T) { @@ -39,3 +41,46 @@ func TestRequestMatcherIsPathOnly(t *testing.T) { assert.True(t, m.IsPathOnly()) }) } + +const requestMatcherTestPath = "/api/test" + +func TestRequestMatcherValidator(t *testing.T) { + t.Run("should not register errors for valid filter with all fields", func(t *testing.T) { + err := (&config.RequestMatcher{ + Path: requestMatcherTestPath, + Method: "GET", + Queries: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + Headers: map[string]string{ + headers.ContentType: "application/json", + headers.Accept: "application/json", + }, + }).Validate("test") + assert.NoError(t, err) + }) + + t.Run("should not register errors for valid filter with minimal fields", func(t *testing.T) { + assert.NoError(t, (&config.RequestMatcher{Path: requestMatcherTestPath}).Validate("test")) + }) + + t.Run("should register error for invalid path", func(t *testing.T) { + err := (&config.RequestMatcher{Path: "", Method: "GET"}).Validate("test") + require.Error(t, err) + assert.Contains(t, err.Error(), "path must not be empty") + }) + + t.Run("should register error for invalid method", func(t *testing.T) { + err := (&config.RequestMatcher{Path: requestMatcherTestPath, Method: "INVALID"}).Validate("test") + require.Error(t, err) + assert.Contains(t, err.Error(), "method must be one of") + }) + + t.Run("should register multiple validation errors", func(t *testing.T) { + err := (&config.RequestMatcher{Path: "", Method: "INVALID"}).Validate("test") + require.Error(t, err) + assert.Contains(t, err.Error(), "path must not be empty") + assert.Contains(t, err.Error(), "method must be one of") + }) +} diff --git a/internal/config/response.go b/internal/config/response.go index 08eb84cb..ded3dc88 100644 --- a/internal/config/response.go +++ b/internal/config/response.go @@ -2,17 +2,17 @@ package config import ( "fmt" - "strings" "time" "github.com/evg4b/uncors/internal/helpers" - "gopkg.in/yaml.v3" + multierror "github.com/hashicorp/go-multierror" + "github.com/spf13/afero" ) type Response struct { Code int `yaml:"code"` Headers map[string]string `yaml:"headers"` - Delay time.Duration `yaml:"-"` + Delay time.Duration `yaml:"delay"` Raw string `yaml:"raw"` File string `yaml:"file"` } @@ -35,40 +35,28 @@ func (r *Response) IsFile() bool { return len(r.File) > 0 } -// UnmarshalYAML implements custom decoding so that the "delay" field can be -// expressed as a human-readable duration string (e.g. "200ms", "1s 500ms"). -// All other fields are decoded by the standard yaml.v3 machinery. -func (r *Response) UnmarshalYAML(value *yaml.Node) error { - type responseRaw struct { - Code int `yaml:"code"` - Headers map[string]string `yaml:"headers"` - Delay string `yaml:"delay"` - Raw string `yaml:"raw"` - File string `yaml:"file"` +func (r *Response) Validate(field string, fs afero.Fs) error { + var errs *multierror.Error + + errs = multierror.Append(errs, ValidateStatus(joinPath(field, "code"), r.Code)) + errs = multierror.Append(errs, ValidateDuration(joinPath(field, "delay"), r.Delay, true)) + + switch { + case r.Raw == "" && r.File == "": + errs = multierror.Append(errs, &ValidationError{fmt.Sprintf( + "%s or %s must be set", + joinPath(field, "raw"), + joinPath(field, "file"), + )}) + case r.Raw != "" && r.File != "": + errs = multierror.Append(errs, &ValidationError{fmt.Sprintf( + "only one of %s or %s must be set", + joinPath(field, "raw"), + joinPath(field, "file"), + )}) + case r.File != "": + errs = multierror.Append(errs, ValidateFile(joinPath(field, "file"), r.File, fs)) } - var raw responseRaw - - err := value.Decode(&raw) - if err != nil { - return err - } - - r.Code = raw.Code - r.Headers = raw.Headers - r.Raw = raw.Raw - r.File = raw.File - - if raw.Delay == "" { - return nil - } - - dur, err := time.ParseDuration(strings.ReplaceAll(raw.Delay, " ", "")) - if err != nil { - return fmt.Errorf("invalid delay %q: %w", raw.Delay, err) - } - - r.Delay = dur - - return nil + return joinErrors(errs) } diff --git a/internal/config/response_test.go b/internal/config/response_test.go index 6d895ee1..7f63e0c8 100644 --- a/internal/config/response_test.go +++ b/internal/config/response_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/testing/testutils" "github.com/go-http-utils/headers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -49,7 +50,7 @@ file: ./body.json }) t.Run("parses delay with embedded spaces", func(t *testing.T) { - const input = `delay: "1s 500ms"` + const input = `delay: "1s500ms"` var actual config.Response @@ -124,3 +125,64 @@ func TestResponseClone(t *testing.T) { }) }) } + +func TestResponseValidator(t *testing.T) { + const file = "testdata/file.txt" + + fs := testutils.FsFromMap(t, map[string]string{file: "test"}) + + t.Run("should not register errors if response is valid", func(t *testing.T) { + tests := []struct { + name string + value config.Response + }{ + {name: "with file", value: config.Response{Code: 200, File: file, Delay: 3 * time.Second}}, + {name: "with raw", value: config.Response{Code: 200, Raw: `{ "test": "test" }`, Delay: 3 * time.Second}}, + {name: "without delay", value: config.Response{Code: 200, Raw: `{ "test": "test" }`}}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.NoError(t, test.value.Validate("test", fs)) + }) + } + }) + + t.Run("should register errors for", func(t *testing.T) { + tests := []struct { + name string + value config.Response + error string + }{ + { + name: "code", + value: config.Response{Code: 0, File: file, Delay: 3 * time.Second}, + error: "test.code code must be in range 100-599", + }, + { + name: "file", + value: config.Response{Code: 200, File: "testdata/unknown.txt", Delay: 3 * time.Second}, + error: "test.file testdata/unknown.txt does not exist", + }, + { + name: "delay", + value: config.Response{Code: 200, File: file, Delay: -1 * time.Second}, + error: "test.delay must be greater than or equal to 0", + }, + { + name: "both empty", + value: config.Response{Code: 200, Delay: 3 * time.Second}, + error: "test.raw or test.file must be set", + }, + { + name: "both set", + value: config.Response{Code: 200, File: file, Raw: "test", Delay: 3 * time.Second}, + error: "only one of test.raw or test.file must be set", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.EqualError(t, test.value.Validate("test", fs), test.error) + }) + } + }) +} diff --git a/internal/config/rewrite.go b/internal/config/rewrite.go index c86602f3..fc3d8f0f 100644 --- a/internal/config/rewrite.go +++ b/internal/config/rewrite.go @@ -1,5 +1,11 @@ package config +import ( + "slices" + + multierror "github.com/hashicorp/go-multierror" +) + type RewritingOption struct { From string `yaml:"from"` To string `yaml:"to"` @@ -7,24 +13,24 @@ type RewritingOption struct { } func (r RewritingOption) Clone() RewritingOption { - return RewritingOption{ - From: r.From, - To: r.To, - Host: r.Host, - } + return r } type RewriteOptions []RewritingOption func (r RewriteOptions) Clone() RewriteOptions { - if r == nil { - return nil - } + return slices.Clone(r) +} + +func (r RewritingOption) Validate(field string) error { + var errs *multierror.Error + + errs = multierror.Append(errs, ValidatePath(joinPath(field, "from"), r.From, true)) + errs = multierror.Append(errs, ValidatePath(joinPath(field, "to"), r.To, true)) - clone := make(RewriteOptions, len(r)) - for i, rewrite := range r { - clone[i] = rewrite.Clone() + if r.Host != "" { + errs = multierror.Append(errs, ValidateHost(joinPath(field, "host"), r.Host)) } - return clone + return joinErrors(errs) } diff --git a/internal/config/rewrite_test.go b/internal/config/rewrite_test.go index f60d74e7..c1acd9cc 100644 --- a/internal/config/rewrite_test.go +++ b/internal/config/rewrite_test.go @@ -4,7 +4,9 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/testing/hosts" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRewritingOptionClone(t *testing.T) { @@ -80,3 +82,62 @@ func TestRewriteOptionsClone(t *testing.T) { }) } } + +const ( + fromPath = "/from/path" + toPath = "/to/path" +) + +func TestRewritingOptionValidatorIsValidNoError(t *testing.T) { + tests := []struct { + name string + value config.RewritingOption + }{ + {name: "valid paths and host", value: config.RewritingOption{From: fromPath, To: toPath, Host: hosts.Github.Host()}}, + {name: "no host", value: config.RewritingOption{From: fromPath, To: toPath}}, + { + name: "relative from path", + value: config.RewritingOption{From: "../relative/from/path", To: toPath, Host: hosts.Github.Host()}, + }, + { + name: "relative to path", + value: config.RewritingOption{From: fromPath, To: "../relative/to/path", Host: hosts.Github.Host()}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.NoError(t, tt.value.Validate("testField")) + }) + } +} + +func TestRewritingOptionValidatorIsValidWithError(t *testing.T) { + tests := []struct { + name string + value config.RewritingOption + error string + }{ + { + name: "invalid from path", + value: config.RewritingOption{From: "", To: toPath, Host: hosts.Github.Host()}, + error: "testField.from must not be empty", + }, + { + name: "invalid to path", + value: config.RewritingOption{From: fromPath, To: "", Host: hosts.Github.Host()}, + error: "testField.to must not be empty", + }, + { + name: "invalid host format", + value: config.RewritingOption{From: fromPath, To: toPath, Host: "&&&"}, + error: "testField.host is not a valid host", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + require.EqualError(t, testCase.value.Validate("testField"), testCase.error) + }) + } +} diff --git a/internal/config/script.go b/internal/config/script.go index 1f8ba27d..0313be33 100644 --- a/internal/config/script.go +++ b/internal/config/script.go @@ -3,7 +3,9 @@ package config import ( "fmt" + multierror "github.com/hashicorp/go-multierror" "github.com/samber/lo" + "github.com/spf13/afero" ) type Script struct { @@ -45,3 +47,32 @@ func (s Scripts) Clone() Scripts { return item.Clone() }) } + +func (s *Script) Validate(field string, fs afero.Fs) error { + var errs *multierror.Error + + errs = multierror.Append(errs, s.Matcher.Validate(field)) + + switch { + case s.Script == "" && s.File == "": + scriptField := joinPath(field, "script") + fileField := joinPath(field, "file") + + const neitherMsg = ": either 'script' or 'file' must be provided" + + errs = multierror.Append(errs, &ValidationError{scriptField + neitherMsg}) + errs = multierror.Append(errs, &ValidationError{fileField + neitherMsg}) + case s.Script != "" && s.File != "": + scriptField := joinPath(field, "script") + fileField := joinPath(field, "file") + + const bothMsg = ": only one of 'script' or 'file' can be provided" + + errs = multierror.Append(errs, &ValidationError{scriptField + bothMsg}) + errs = multierror.Append(errs, &ValidationError{fileField + bothMsg}) + case s.File != "": + errs = multierror.Append(errs, ValidateFile(joinPath(field, "file"), s.File, fs)) + } + + return joinErrors(errs) +} diff --git a/internal/config/script_test.go b/internal/config/script_test.go index 17e2b1cc..4bdb5a7a 100644 --- a/internal/config/script_test.go +++ b/internal/config/script_test.go @@ -4,8 +4,10 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/testing/testutils" "github.com/go-http-utils/headers" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRequestMatcher_Clone(t *testing.T) { @@ -169,3 +171,125 @@ func TestScripts_Clone(t *testing.T) { assert.NotNil(t, cloned) }) } + +const ( + testAPIPath = "/api/test" + testScriptContent = "response.status = 200" + testScriptFilePath = "/scripts/test.lua" + scriptPathField = "script.path" + scriptScriptField = "script.script" + scriptFileField = "script.file" +) + +func TestScriptValidator(t *testing.T) { + noFS := testutils.FsFromMap(t, map[string]string{}) + + t.Run("valid inline script", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: testAPIPath, Method: "GET"}, + Script: testScriptContent, + }).Validate("script", noFS) + assert.NoError(t, err) + }) + + t.Run("valid file script", func(t *testing.T) { + fs := testutils.FsFromMap(t, map[string]string{testScriptFilePath: testScriptContent}) + + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: testAPIPath, Method: "POST"}, + File: testScriptFilePath, + }).Validate("script", fs) + assert.NoError(t, err) + }) + + t.Run("empty method is allowed", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: testAPIPath}, + Script: testScriptContent, + }).Validate("script", noFS) + assert.NoError(t, err) + }) + + t.Run("valid queries and headers", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{ + Path: "/api/test", + Queries: map[string]string{"filter": "active"}, + Headers: map[string]string{headers.Authorization: "Bearer token"}, + }, + Script: testScriptContent, + }).Validate("script", noFS) + assert.NoError(t, err) + }) + + t.Run("empty path", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: ""}, + Script: testScriptContent, + }).Validate("script", noFS) + require.Error(t, err) + assert.Contains(t, err.Error(), scriptPathField) + }) + + t.Run("invalid path", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: "invalid-path"}, + Script: testScriptContent, + }).Validate("script", noFS) + require.Error(t, err) + assert.Contains(t, err.Error(), scriptPathField) + }) + + t.Run("invalid method", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: testAPIPath, Method: "INVALID"}, + Script: testScriptContent, + }).Validate("script", noFS) + require.Error(t, err) + assert.Contains(t, err.Error(), "script.method") + }) + + t.Run("neither script nor file provided", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: testAPIPath, Method: "GET"}, + }).Validate("script", noFS) + require.Error(t, err) + assert.Contains(t, err.Error(), scriptScriptField) + assert.Contains(t, err.Error(), scriptFileField) + assert.Contains(t, err.Error(), "either 'script' or 'file' must be provided") + }) + + t.Run("both script and file provided", func(t *testing.T) { + fs := testutils.FsFromMap(t, map[string]string{testScriptFilePath: testScriptContent}) + + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: "/api/test"}, + Script: testScriptContent, + File: "/scripts/test.lua", + }).Validate("script", fs) + require.Error(t, err) + assert.Contains(t, err.Error(), scriptScriptField) + assert.Contains(t, err.Error(), scriptFileField) + assert.Contains(t, err.Error(), "only one of 'script' or 'file' can be provided") + }) + + t.Run("file does not exist", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: "/api/test"}, + File: "/scripts/nonexistent.lua", + }).Validate("script", noFS) + require.Error(t, err) + assert.Contains(t, err.Error(), scriptFileField) + }) + + t.Run("multiple errors", func(t *testing.T) { + err := (&config.Script{ + Matcher: config.RequestMatcher{Path: "", Method: "INVALID"}, + }).Validate("script", noFS) + require.Error(t, err) + assert.Contains(t, err.Error(), "script.path") + assert.Contains(t, err.Error(), "script.method") + assert.Contains(t, err.Error(), "script.script") + assert.Contains(t, err.Error(), "script.file") + }) +} diff --git a/internal/config/static.go b/internal/config/static.go index 92b54b25..3cc5d01a 100644 --- a/internal/config/static.go +++ b/internal/config/static.go @@ -2,8 +2,11 @@ package config import ( "fmt" + "path" + multierror "github.com/hashicorp/go-multierror" "github.com/samber/lo" + "github.com/spf13/afero" "gopkg.in/yaml.v3" ) @@ -37,21 +40,6 @@ func (s *StaticDirectories) Clone() StaticDirectories { }) } -// UnmarshalYAML allows StaticDirectories to be specified as a YAML mapping -// (shorthand: path → dir or path → {dir, index}) as well as a sequence of -// full StaticDirectory objects. -// -// Map form: -// -// statics: -// /path: /static-dir -// /other: { dir: /other-dir, index: index.html } -// -// Sequence form: -// -// statics: -// - path: /path -// dir: /static-dir func (s *StaticDirectories) UnmarshalYAML(value *yaml.Node) error { if value.Kind == yaml.MappingNode { for i := 0; i+1 < len(value.Content); i += 2 { @@ -81,3 +69,16 @@ func (s *StaticDirectories) UnmarshalYAML(value *yaml.Node) error { return value.Decode((*staticDirectoriesAlias)(s)) } + +func (s *StaticDirectory) Validate(field string, fs afero.Fs) error { + var errs *multierror.Error + + errs = multierror.Append(errs, ValidatePath(joinPath(field, "path"), s.Path, false)) + errs = multierror.Append(errs, ValidateDirectory(joinPath(field, "directory"), s.Dir, fs)) + + if s.Index != "" { + errs = multierror.Append(errs, ValidateFile(joinPath(field, "index"), path.Join(s.Dir, s.Index), fs)) + } + + return joinErrors(errs) +} diff --git a/internal/config/static_test.go b/internal/config/static_test.go index 8718c419..e78a5110 100644 --- a/internal/config/static_test.go +++ b/internal/config/static_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/testing/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" @@ -163,3 +164,63 @@ func TestStaticDirMappingClone(t *testing.T) { }) } } + +func TestStaticValidator(t *testing.T) { + const ( + assetsPath = "/assets" + staticPath = "/static" + indexFilePath = "/static/index.html" + ) + + fs := testutils.FsFromMap(t, map[string]string{indexFilePath: indexFilePath}) + + t.Run("should not register errors if response is valid", func(t *testing.T) { + tests := []struct { + name string + value config.StaticDirectory + }{ + { + name: "valid static directory with index", + value: config.StaticDirectory{Path: assetsPath, Dir: staticPath, Index: "index.html"}, + }, + { + name: "valid static directory without index", + value: config.StaticDirectory{Path: assetsPath, Dir: staticPath}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.NoError(t, test.value.Validate("test", fs)) + }) + } + }) + + t.Run("should register errors if response is invalid", func(t *testing.T) { + tests := []struct { + name string + value config.StaticDirectory + error string + }{ + { + name: "empty path", + value: config.StaticDirectory{Path: "", Dir: staticPath}, + error: "test.path must not be empty", + }, + { + name: "empty directory", + value: config.StaticDirectory{Path: assetsPath, Dir: ""}, + error: "test.directory must not be empty", + }, + { + name: "missing index file", + value: config.StaticDirectory{Path: assetsPath, Dir: staticPath, Index: "index.php"}, + error: "test.index /static/index.php does not exist", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.EqualError(t, test.value.Validate("test", fs), test.error) + }) + } + }) +} diff --git a/internal/config/time_decode_hook_test.go b/internal/config/time_decode_hook_test.go index e15233e3..b5b6b462 100644 --- a/internal/config/time_decode_hook_test.go +++ b/internal/config/time_decode_hook_test.go @@ -22,16 +22,6 @@ func TestCacheConfigDurationUnmarshal(t *testing.T) { input: "expiration-time: 3h6m13s", expected: 3*time.Hour + 6*time.Minute + 13*time.Second, }, - { - name: "duration with spaces", - input: "expiration-time: \"1m 4s\"", - expected: 1*time.Minute + 4*time.Second, - }, - { - name: "duration with mixed spaces", - input: "expiration-time: \"1h 3m59s 40ms\"", - expected: 1*time.Hour + 3*time.Minute + 59*time.Second + 40*time.Millisecond, - }, } for _, testCase := range tests { @@ -55,19 +45,12 @@ func TestCacheConfigDurationUnmarshal(t *testing.T) { assert.Equal(t, int64(1048576), cfg.MaxSize) }) - t.Run("returns error for non-mapping input", func(t *testing.T) { - var cfg config.CacheConfig - - err := yaml.Unmarshal([]byte("just-a-string"), &cfg) - assert.ErrorIs(t, err, config.ErrInvalidCacheConfig) - }) - t.Run("returns error for invalid duration string", func(t *testing.T) { var cfg config.CacheConfig err := yaml.Unmarshal([]byte("expiration-time: notaduration"), &cfg) require.Error(t, err) - assert.Contains(t, err.Error(), "invalid expiration-time") + assert.Contains(t, err.Error(), "cannot unmarshal !!str `notadur...` into time.Duration") }) t.Run("returns error when max-size is not a number", func(t *testing.T) { @@ -98,9 +81,9 @@ func TestResponseDelayUnmarshal(t *testing.T) { expected: 200 * time.Millisecond, }, { - name: "delay with spaces", - input: "delay: \"1s 500ms\"", - expected: 1*time.Second + 500*time.Millisecond, + name: "a houd with 500 milliseconds", + input: "delay: \"1h500ms\"", + expected: 1*time.Hour + 500*time.Millisecond, }, } @@ -118,7 +101,7 @@ func TestResponseDelayUnmarshal(t *testing.T) { err := yaml.Unmarshal([]byte("delay: notaduration"), &resp) require.Error(t, err) - assert.Contains(t, err.Error(), "invalid delay") + assert.Contains(t, err.Error(), "cannot unmarshal !!str `notadur...` into time.Duration") }) t.Run("zero delay when field absent", func(t *testing.T) { diff --git a/internal/config/validate_errors.go b/internal/config/validate_errors.go new file mode 100644 index 00000000..bfd86ed7 --- /dev/null +++ b/internal/config/validate_errors.go @@ -0,0 +1,47 @@ +package config + +import ( + "fmt" + "strings" + + multierror "github.com/hashicorp/go-multierror" +) + +type ValidationError struct { + Message string +} + +func (e *ValidationError) Error() string { + return e.Message +} + +type TLSError struct { + Host string +} + +func (e *TLSError) Error() string { + var builder strings.Builder + fmt.Fprintf(&builder, "HTTPS mapping '%s' requires a local CA certificate for automatic TLS.\n\n", e.Host) + builder.WriteString("Generate a local CA certificate:\n") + builder.WriteString(" uncors generate-certs\n\n") + builder.WriteString("After generating CA, you can add it to your system's trusted certificates.") + + return builder.String() +} + +func joinErrors(errs *multierror.Error) error { + if errs == nil { + return nil + } + + errs.ErrorFormat = func(errs []error) string { + msgs := make([]string, len(errs)) + for i, e := range errs { + msgs[i] = e.Error() + } + + return strings.Join(msgs, "\n") + } + + return errs.ErrorOrNil() +} diff --git a/internal/config/validators/helpers.go b/internal/config/validate_helpers.go similarity index 95% rename from internal/config/validators/helpers.go rename to internal/config/validate_helpers.go index 34a2ad75..12050769 100644 --- a/internal/config/validators/helpers.go +++ b/internal/config/validate_helpers.go @@ -1,4 +1,4 @@ -package validators +package config import ( "fmt" diff --git a/internal/config/validators/helpers_internal_test.go b/internal/config/validate_helpers_internal_test.go similarity index 98% rename from internal/config/validators/helpers_internal_test.go rename to internal/config/validate_helpers_internal_test.go index 28a34657..5d1df147 100644 --- a/internal/config/validators/helpers_internal_test.go +++ b/internal/config/validate_helpers_internal_test.go @@ -1,4 +1,4 @@ -package validators +package config import ( "testing" diff --git a/internal/config/validate_primitives.go b/internal/config/validate_primitives.go new file mode 100644 index 00000000..9be0ec76 --- /dev/null +++ b/internal/config/validate_primitives.go @@ -0,0 +1,184 @@ +package config + +import ( + "fmt" + "net/http" + "os" + "slices" + "strings" + "time" + + multierror "github.com/hashicorp/go-multierror" + + "github.com/bmatcuk/doublestar/v4" + "github.com/evg4b/uncors/internal/urlparser" + "github.com/spf13/afero" +) + +const maxHostLength = 255 + +func ValidateHost(field, value string) error { + if value == "" { + return &ValidationError{fmt.Sprintf("%s must not be empty", field)} + } + + if len(value) > maxHostLength { + return &ValidationError{fmt.Sprintf("%s must not be longer than 255 characters, but got %d", field, len(value))} + } + + uri, err := urlparser.Parse(value) + if err != nil { + return &ValidationError{fmt.Sprintf("%s is not a valid host", field)} + } + + var errs *multierror.Error + + if uri.Path != "" { + errs = multierror.Append(errs, &ValidationError{fmt.Sprintf("%s must not contain a path", field)}) + } + + if uri.RawQuery != "" { + errs = multierror.Append(errs, &ValidationError{fmt.Sprintf("%s must not contain a query", field)}) + } + + if uri.Scheme != "http" && uri.Scheme != httpsScheme && uri.Scheme != "" { + errs = multierror.Append(errs, &ValidationError{fmt.Sprintf("%s scheme must be http or https", field)}) + } + + return joinErrors(errs) +} + +func ValidatePath(field, value string, relative bool) error { + if value == "" { + return &ValidationError{fmt.Sprintf("%s must not be empty", field)} + } + + if !relative && !strings.HasPrefix(value, "/") { + return &ValidationError{fmt.Sprintf("%s must be absolute and start with /", field)} + } + + uri, err := urlparser.Parse("//localhost/" + strings.TrimPrefix(value, "/")) + if err != nil { + return &ValidationError{fmt.Sprintf("%s is not a valid path", field)} + } + + if uri.RawQuery != "" { + return &ValidationError{fmt.Sprintf("%s must not contain a query", field)} + } + + return nil +} + +func ValidateFile(field, value string, fs afero.Fs) error { + stat, err := fs.Stat(value) + if err != nil { + switch { + case os.IsNotExist(err): + return &ValidationError{fmt.Sprintf("%s %s does not exist", field, value)} + case os.IsPermission(err): + return &ValidationError{fmt.Sprintf("%s %s is not accessible", field, value)} + default: + return &ValidationError{fmt.Sprintf("%s %s is not a file", field, value)} + } + } + + if stat.IsDir() { + return &ValidationError{fmt.Sprintf("%s %s is a directory", field, value)} + } + + return nil +} + +func ValidateDirectory(field, value string, fs afero.Fs) error { + if value == "" { + return &ValidationError{fmt.Sprintf("%s must not be empty", field)} + } + + stat, err := fs.Stat(value) + if err != nil { + switch { + case os.IsNotExist(err): + return &ValidationError{fmt.Sprintf("%s directory does not exist", field)} + case os.IsPermission(err): + return &ValidationError{fmt.Sprintf("%s directory is not accessible", field)} + default: + return &ValidationError{fmt.Sprintf("%s is not a directory", field)} + } + } + + if !stat.IsDir() { + return &ValidationError{fmt.Sprintf("%s is not a directory", field)} + } + + return nil +} + +func ValidateStatus(field string, value int) error { + if value < 100 || value > 599 { + return &ValidationError{fmt.Sprintf("%s code must be in range 100-599", field)} + } + + return nil +} + +func ValidateDuration(field string, value time.Duration, allowZero bool) error { + if allowZero { + if value < 0 { + return &ValidationError{fmt.Sprintf("%s must be greater than or equal to 0", field)} + } + } else { + if value <= 0 { + return &ValidationError{fmt.Sprintf("%s must be greater than 0", field)} + } + } + + return nil +} + +var allowedMethods = []string{ + http.MethodGet, + http.MethodHead, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + http.MethodConnect, + http.MethodOptions, + http.MethodTrace, +} + +func ValidateMethod(field, value string, allowEmpty bool) error { + if allowEmpty && value == "" { + return nil + } + + if !slices.Contains(allowedMethods, value) { + return &ValidationError{fmt.Sprintf("%s must be one of %s", field, strings.Join(allowedMethods, ", "))} + } + + return nil +} + +func ValidatePort(field string, value int) error { + if value < 1 || value > 65535 { + return &ValidationError{fmt.Sprintf("%s must be between 1 and 65535", field)} + } + + return nil +} + +func ValidateGlobPattern(field, value string) error { + if !doublestar.ValidatePathPattern(value) { + return &ValidationError{fmt.Sprintf("%s is not a valid glob pattern", field)} + } + + return nil +} + +func ValidateStringEnum(_ string, value string, options []string) error { + if !slices.Contains(options, value) { + return &ValidationError{fmt.Sprintf("'%s' is not a valid option", value)} + } + + return nil +} diff --git a/internal/config/validators/primitives_test.go b/internal/config/validate_primitives_test.go similarity index 51% rename from internal/config/validators/primitives_test.go rename to internal/config/validate_primitives_test.go index 7e1f28e5..2edbee5e 100644 --- a/internal/config/validators/primitives_test.go +++ b/internal/config/validate_primitives_test.go @@ -1,4 +1,4 @@ -package validators_test +package config_test import ( "fmt" @@ -8,28 +8,24 @@ import ( "testing" "time" - "github.com/evg4b/uncors/internal/config/validators" + "github.com/evg4b/uncors/internal/config" "github.com/evg4b/uncors/testing/hosts" "github.com/evg4b/uncors/testing/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func runOK(t *testing.T, name string, fn func(*validators.Errors)) { +func runOK(t *testing.T, name string, fn func() error) { t.Helper() t.Run(name, func(t *testing.T) { - var errs validators.Errors - fn(&errs) - assert.False(t, errs.HasAny(), "expected no errors, got: %v", errs) + assert.NoError(t, fn()) }) } -func runErr(t *testing.T, name, expected string, fn func(*validators.Errors)) { +func runErr(t *testing.T, name, expected string, fn func() error) { t.Helper() t.Run(name, func(t *testing.T) { - var errs validators.Errors - fn(&errs) - require.EqualError(t, errs, expected) + require.EqualError(t, fn(), expected) }) } @@ -39,25 +35,25 @@ func TestValidateHost(t *testing.T) { const field = "field" t.Run("valid", func(t *testing.T) { - runOK(t, "bare host", func(e *validators.Errors) { validators.ValidateHost(field, hosts.Localhost.Host(), e) }) - runOK(t, "http scheme", func(e *validators.Errors) { validators.ValidateHost(field, hosts.Github.HTTP(), e) }) - runOK(t, "https scheme", func(e *validators.Errors) { validators.ValidateHost(field, hosts.Github.HTTPS(), e) }) - runOK(t, "ip address", func(e *validators.Errors) { validators.ValidateHost(field, hosts.Loopback.Host(), e) }) + runOK(t, "bare host", func() error { return config.ValidateHost(field, hosts.Localhost.Host()) }) + runOK(t, "http scheme", func() error { return config.ValidateHost(field, hosts.Github.HTTP()) }) + runOK(t, "https scheme", func() error { return config.ValidateHost(field, hosts.Github.HTTPS()) }) + runOK(t, "ip address", func() error { return config.ValidateHost(field, hosts.Loopback.Host()) }) }) t.Run("invalid", func(t *testing.T) { runErr(t, "empty", "field must not be empty", - func(e *validators.Errors) { validators.ValidateHost(field, "", e) }) + func() error { return config.ValidateHost(field, "") }) runErr(t, "too long", "field must not be longer than 255 characters, but got 256", - func(e *validators.Errors) { validators.ValidateHost(field, strings.Repeat("a", 256), e) }) + func() error { return config.ValidateHost(field, strings.Repeat("a", 256)) }) runErr(t, "with path", "field must not contain a path", - func(e *validators.Errors) { validators.ValidateHost(field, "example.com/path", e) }) + func() error { return config.ValidateHost(field, "example.com/path") }) runErr(t, "with query", "field must not contain a query", - func(e *validators.Errors) { validators.ValidateHost(field, "example.com?query=1", e) }) + func() error { return config.ValidateHost(field, "example.com?query=1") }) runErr(t, "unsupported scheme", "field scheme must be http or https", - func(e *validators.Errors) { validators.ValidateHost(field, hosts.Localhost.Scheme("ftp"), e) }) + func() error { return config.ValidateHost(field, hosts.Localhost.Scheme("ftp")) }) runErr(t, "invalid host", "field is not a valid host", - func(e *validators.Errors) { validators.ValidateHost(field, "loca:::lhost", e) }) + func() error { return config.ValidateHost(field, "loca:::lhost") }) }) } @@ -67,8 +63,8 @@ func TestValidatePath(t *testing.T) { const field = "field" t.Run("valid absolute", func(t *testing.T) { - runOK(t, "root", func(e *validators.Errors) { validators.ValidatePath(field, "/", false, e) }) - runOK(t, "api path", func(e *validators.Errors) { validators.ValidatePath(field, "/api/info", false, e) }) + runOK(t, "root", func() error { return config.ValidatePath(field, "/", false) }) + runOK(t, "api path", func() error { return config.ValidatePath(field, "/api/info", false) }) }) } @@ -80,16 +76,16 @@ func TestValidateFile(t *testing.T) { t.Run("valid file", func(t *testing.T) { path := "/demo/file.go" fs := testutils.FsFromMap(t, map[string]string{path: "package validators"}) - runOK(t, "existing file", func(e *validators.Errors) { validators.ValidateFile(field, path, fs, e) }) + runOK(t, "existing file", func() error { return config.ValidateFile(field, path, fs) }) }) fs := testutils.FsFromMap(t, map[string]string{"file.go": "package validators"}) testutils.CheckNoError(t, fs.Mkdir("/demo", 0o755)) runErr(t, "does not exist", "test file_does_not_exist.go does not exist", - func(e *validators.Errors) { validators.ValidateFile(field, "file_does_not_exist.go", fs, e) }) + func() error { return config.ValidateFile(field, "file_does_not_exist.go", fs) }) runErr(t, "is a directory", "test /demo is a directory", - func(e *validators.Errors) { validators.ValidateFile(field, "/demo", fs, e) }) + func() error { return config.ValidateFile(field, "/demo", fs) }) } // ---- ValidateDirectory --------------------------------------------------- @@ -103,14 +99,14 @@ func TestValidateDirectory(t *testing.T) { fs := testutils.FsFromMap(t, map[string]string{"file.go": "package validators"}) testutils.CheckNoError(t, fs.Mkdir(dir, 0o755)) - runOK(t, "existing directory", func(e *validators.Errors) { validators.ValidateDirectory(field, dir, fs, e) }) + runOK(t, "existing directory", func() error { return config.ValidateDirectory(field, dir, fs) }) runErr(t, "empty path", "test must not be empty", - func(e *validators.Errors) { validators.ValidateDirectory(field, "", fs, e) }) + func() error { return config.ValidateDirectory(field, "", fs) }) runErr(t, "does not exist", "test directory does not exist", - func(e *validators.Errors) { validators.ValidateDirectory(field, "does_not_exist", fs, e) }) + func() error { return config.ValidateDirectory(field, "does_not_exist", fs) }) runErr(t, "is a file", "test is not a directory", - func(e *validators.Errors) { validators.ValidateDirectory(field, "file.go", fs, e) }) + func() error { return config.ValidateDirectory(field, "file.go", fs) }) } // ---- ValidateStatus ------------------------------------------------------ @@ -119,12 +115,12 @@ func TestValidateStatus(t *testing.T) { const field = "status" for _, code := range []int{100, 200, 300, 400, 404, 500, 503, 599} { - runOK(t, strconv.Itoa(code), func(e *validators.Errors) { validators.ValidateStatus(field, code, e) }) + runOK(t, strconv.Itoa(code), func() error { return config.ValidateStatus(field, code) }) } for _, code := range []int{-200, 0, 99, 600} { runErr(t, strconv.Itoa(code), "status code must be in range 100-599", - func(e *validators.Errors) { validators.ValidateStatus(field, code, e) }) + func() error { return config.ValidateStatus(field, code) }) } } @@ -133,19 +129,19 @@ func TestValidateStatus(t *testing.T) { func TestValidateDuration(t *testing.T) { const field = "test-field" - runOK(t, "positive without allowZero", func(e *validators.Errors) { - validators.ValidateDuration(field, time.Second, false, e) + runOK(t, "positive without allowZero", func() error { + return config.ValidateDuration(field, time.Second, false) }) - runOK(t, "zero with allowZero", func(e *validators.Errors) { - validators.ValidateDuration(field, 0, true, e) + runOK(t, "zero with allowZero", func() error { + return config.ValidateDuration(field, 0, true) }) runErr(t, "negative without allowZero", "test-field must be greater than 0", - func(e *validators.Errors) { validators.ValidateDuration(field, -time.Second, false, e) }) + func() error { return config.ValidateDuration(field, -time.Second, false) }) runErr(t, "zero without allowZero", "test-field must be greater than 0", - func(e *validators.Errors) { validators.ValidateDuration(field, 0, false, e) }) + func() error { return config.ValidateDuration(field, 0, false) }) runErr(t, "negative with allowZero", "test-field must be greater than or equal to 0", - func(e *validators.Errors) { validators.ValidateDuration(field, -time.Second, true, e) }) + func() error { return config.ValidateDuration(field, -time.Second, true) }) } // ---- ValidateMethod ------------------------------------------------------ @@ -158,20 +154,20 @@ func TestValidateMethod(t *testing.T) { http.MethodPatch, http.MethodDelete, http.MethodConnect, http.MethodOptions, http.MethodTrace, } { m := method - runOK(t, fmt.Sprintf("http method %s", m), func(e *validators.Errors) { - validators.ValidateMethod(field, m, false, e) + runOK(t, fmt.Sprintf("http method %s", m), func() error { + return config.ValidateMethod(field, m, false) }) } - runOK(t, "empty when allowEmpty", func(e *validators.Errors) { - validators.ValidateMethod(field, "", true, e) + runOK(t, "empty when allowEmpty", func() error { + return config.ValidateMethod(field, "", true) }) expected := "test-field must be one of GET, HEAD, POST, PUT, PATCH, DELETE, CONNECT, OPTIONS, TRACE" runErr(t, "empty when not allowEmpty", expected, - func(e *validators.Errors) { validators.ValidateMethod(field, "", false, e) }) + func() error { return config.ValidateMethod(field, "", false) }) runErr(t, "invalid method", expected, - func(e *validators.Errors) { validators.ValidateMethod(field, "invalid", false, e) }) + func() error { return config.ValidateMethod(field, "invalid", false) }) } // ---- ValidatePort -------------------------------------------------------- @@ -181,24 +177,24 @@ func TestValidatePort(t *testing.T) { for _, port := range []int{1, 443, 65535} { p := port - runOK(t, fmt.Sprintf("port %d", p), func(e *validators.Errors) { validators.ValidatePort(field, p, e) }) + runOK(t, fmt.Sprintf("port %d", p), func() error { return config.ValidatePort(field, p) }) } for _, port := range []int{-5, 0, 70000} { p := port runErr(t, fmt.Sprintf("port %d", p), "port-field must be between 1 and 65535", - func(e *validators.Errors) { validators.ValidatePort(field, p, e) }) + func() error { return config.ValidatePort(field, p) }) } } // ---- ValidateGlobPattern ------------------------------------------------- func TestValidateGlobPattern(t *testing.T) { - runOK(t, "valid glob", func(e *validators.Errors) { - validators.ValidateGlobPattern("field", "/api/**", e) + runOK(t, "valid glob", func() error { + return config.ValidateGlobPattern("field", "/api/**") }) runErr(t, "invalid glob", "field is not a valid glob pattern", - func(e *validators.Errors) { validators.ValidateGlobPattern("field", "[invalid", e) }) + func() error { return config.ValidateGlobPattern("field", "[invalid") }) } // ---- ValidateStringEnum -------------------------------------------------- @@ -206,9 +202,9 @@ func TestValidateGlobPattern(t *testing.T) { func TestValidateStringEnum(t *testing.T) { options := []string{"option-1", "option-2"} - runOK(t, "valid option", func(e *validators.Errors) { - validators.ValidateStringEnum("field", "option-1", options, e) + runOK(t, "valid option", func() error { + return config.ValidateStringEnum("field", "option-1", options) }) runErr(t, "invalid option", "'option-x' is not a valid option", - func(e *validators.Errors) { validators.ValidateStringEnum("field", "option-x", options, e) }) + func() error { return config.ValidateStringEnum("field", "option-x", options) }) } diff --git a/internal/config/validators/cache_config_test.go b/internal/config/validators/cache_config_test.go deleted file mode 100644 index 35aa0d3d..00000000 --- a/internal/config/validators/cache_config_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package validators_test - -import ( - "net/http" - "testing" - "time" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCacheConfigValidator(t *testing.T) { - const field = "test" - - t.Run("should not register errors for", func(t *testing.T) { - var errs validators.Errors - validators.ValidateCacheConfig(field, config.CacheConfig{ - ExpirationTime: 5 * time.Minute, - MaxSize: 100 * 1024 * 1024, - Methods: []string{http.MethodGet, http.MethodPost}, - }, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("should register errors for", func(t *testing.T) { - tests := []struct { - name string - value config.CacheConfig - error string - }{ - { - name: "empty expiration time", - value: config.CacheConfig{MaxSize: 100 * 1024 * 1024, Methods: []string{http.MethodGet}}, - error: "test.expiration-time must be greater than 0", - }, - { - name: "zero max size", - value: config.CacheConfig{ExpirationTime: 5 * time.Minute, MaxSize: 0, Methods: []string{http.MethodGet}}, - error: "test.max-size must be greater than 0", - }, - { - name: "negative max size", - value: config.CacheConfig{ExpirationTime: 5 * time.Minute, MaxSize: -1, Methods: []string{http.MethodGet}}, - error: "test.max-size must be greater than 0", - }, - { - name: "empty methods", - value: config.CacheConfig{ExpirationTime: 5 * time.Minute, MaxSize: 100 * 1024 * 1024}, - error: "methods must not be empty", - }, - { - name: "invalid method", - value: config.CacheConfig{ - ExpirationTime: 5 * time.Minute, - MaxSize: 100 * 1024 * 1024, - Methods: []string{"invalid"}, - }, - error: "test.methods[0] must be one of GET, HEAD, POST, PUT, PATCH, DELETE, CONNECT, OPTIONS, TRACE", - }, - { - name: "invalid second method", - value: config.CacheConfig{ - ExpirationTime: 5 * time.Minute, - MaxSize: 100 * 1024 * 1024, - Methods: []string{http.MethodGet, "invalid", http.MethodPost}, - }, - error: "test.methods[1] must be one of GET, HEAD, POST, PUT, PATCH, DELETE, CONNECT, OPTIONS, TRACE", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateCacheConfig(field, test.value, &errs) - require.EqualError(t, errs, test.error) - }) - } - }) -} diff --git a/internal/config/validators/cache_test.go b/internal/config/validators/cache_test.go deleted file mode 100644 index 6d92fe3d..00000000 --- a/internal/config/validators/cache_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package validators_test - -import ( - "fmt" - "testing" - - "github.com/evg4b/uncors/internal/config/validators" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCacheValidator(t *testing.T) { - const field = "cache" - - t.Run("should not register errors for", func(t *testing.T) { - patterns := []string{"/api/**", "/constants", "/translations", "/**/*.js", "/**", "/[12]/demo", "**", "*"} - for _, pattern := range patterns { - p := pattern - t.Run(fmt.Sprintf("%s pattern", p), func(t *testing.T) { - var errs validators.Errors - validators.ValidateCacheGlob(field, p, &errs) - assert.False(t, errs.HasAny()) - }) - } - }) - - t.Run("should register errors for", func(t *testing.T) { - tests := []struct{ pattern, error string }{ - {pattern: "/[12/demo", error: "cache is not a valid glob pattern"}, - {pattern: "/{{12}/demo", error: "cache is not a valid glob pattern"}, - } - for _, test := range tests { - t.Run(fmt.Sprintf("%s test", test.pattern), func(t *testing.T) { - var errs validators.Errors - validators.ValidateCacheGlob(field, test.pattern, &errs) - require.EqualError(t, errs, test.error) - }) - } - }) -} diff --git a/internal/config/validators/errors.go b/internal/config/validators/errors.go deleted file mode 100644 index c0cf1a29..00000000 --- a/internal/config/validators/errors.go +++ /dev/null @@ -1,21 +0,0 @@ -package validators - -import "strings" - -// Errors collects validation error messages. -// -//nolint:recvcheck // Error/HasAny need value receivers so Errors satisfies the error interface as a value type -type Errors []string - -func (e Errors) Error() string { - return strings.Join(e, "\n") -} - -func (e Errors) HasAny() bool { - return len(e) > 0 -} - -// add appends a validation message. Uses a pointer receiver because it mutates the slice. -func (e *Errors) add(msg string) { - *e = append(*e, msg) -} diff --git a/internal/config/validators/har.go b/internal/config/validators/har.go deleted file mode 100644 index b8a3f430..00000000 --- a/internal/config/validators/har.go +++ /dev/null @@ -1,20 +0,0 @@ -package validators - -import ( - "fmt" - "path/filepath" - - "github.com/evg4b/uncors/internal/config" -) - -func ValidateHAR(field string, value config.HARConfig, errs *Errors) { - if !value.Enabled() { - return - } - - file := value.File - - if filepath.Ext(file) == "" { - errs.add(fmt.Sprintf("%s: HAR file path %q must have a file extension (e.g. .har)", field, file)) - } -} diff --git a/internal/config/validators/har_test.go b/internal/config/validators/har_test.go deleted file mode 100644 index ffe6ffd7..00000000 --- a/internal/config/validators/har_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package validators_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/stretchr/testify/assert" -) - -func TestHARValidator(t *testing.T) { - t.Run("valid cases", func(t *testing.T) { - cases := []struct { - name string - value config.HARConfig - }{ - { - name: "disabled (empty file)", - value: config.HARConfig{}, - }, - { - name: "valid file path with extension", - value: config.HARConfig{File: "output.har"}, - }, - { - name: "path with directory and extension", - value: config.HARConfig{File: "/tmp/trace.har"}, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - errs := &validators.Errors{} - validators.ValidateHAR("mappings[0].har", tc.value, errs) - - assert.False(t, errs.HasAny()) - }) - } - }) - - t.Run("invalid cases", func(t *testing.T) { - t.Run("file path without extension", func(t *testing.T) { - errs := &validators.Errors{} - validators.ValidateHAR("mappings[0].har", config.HARConfig{File: "outputfile"}, errs) - - assert.True(t, errs.HasAny()) - }) - }) -} diff --git a/internal/config/validators/mapping_test.go b/internal/config/validators/mapping_test.go deleted file mode 100644 index 4f56c094..00000000 --- a/internal/config/validators/mapping_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package validators_test - -import ( - "net/http" - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/evg4b/uncors/testing/hosts" - "github.com/evg4b/uncors/testing/testutils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMappingValidator(t *testing.T) { - const ( - field = "mapping" - demoJSONPath = "/tmp/demo.json" - ) - - t.Run("should not register errors for", func(t *testing.T) { - fs := testutils.FsFromMap(t, map[string]string{ - demoJSONPath: "{}", - }) - - tests := []struct { - name string - value config.Mapping - }{ - { - name: "full filled mapping", - value: config.Mapping{ - From: "localhost", - To: hosts.Github.Host(), - Statics: []config.StaticDirectory{ - {Path: "/", Dir: "/tmp"}, - {Path: "/", Dir: "/tmp"}, - }, - Mocks: []config.Mock{ - { - Matcher: config.RequestMatcher{ - Path: "/api/info", - Method: http.MethodGet, - }, - Response: config.Response{ - Code: 200, - Raw: "test", - }, - }, - { - Matcher: config.RequestMatcher{ - Path: "/api/info/demo", - Method: http.MethodGet, - }, - Response: config.Response{ - Code: 300, - File: demoJSONPath, - }, - }, - }, - Cache: config.CacheGlobs{ - "/api/constants", - "/**", - }, - }, - }, - { - name: "mapping without mocks and statics and caches", - value: config.Mapping{ - From: "localhost", - To: hosts.Github.Host(), - Statics: []config.StaticDirectory{}, - Mocks: []config.Mock{}, - Cache: config.CacheGlobs{}, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateMapping(field, test.value, fs, &errs) - assert.False(t, errs.HasAny()) - }) - } - }) - - t.Run("should register errors for", func(t *testing.T) { - fs := testutils.FsFromMap(t, map[string]string{ - demoJSONPath: "{}", - }) - - tests := []struct { - name string - value config.Mapping - error string - }{ - { - name: "mapping without from", - value: config.Mapping{ - From: "", - To: hosts.Github.Host(), - Statics: []config.StaticDirectory{}, - Mocks: []config.Mock{}, - Cache: config.CacheGlobs{}, - }, - error: "mapping.from must not be empty", - }, - { - name: "mapping without to", - value: config.Mapping{ - From: "localhost", - To: "", - Statics: []config.StaticDirectory{}, - Mocks: []config.Mock{}, - Cache: config.CacheGlobs{}, - }, - error: "mapping.to must not be empty", - }, - { - name: "mapping with invalid statics", - value: config.Mapping{ - From: "localhost", - To: hosts.Github.Host(), - Statics: []config.StaticDirectory{ - {Path: "/", Dir: "/tmp"}, - {Path: "/", Dir: ""}, - }, - Mocks: []config.Mock{}, - Cache: config.CacheGlobs{}, - }, - error: "mapping.statics[1].directory must not be empty", - }, - { - name: "mapping with invalid mocks", - value: config.Mapping{ - From: "localhost", - To: hosts.Github.Host(), - Statics: []config.StaticDirectory{}, - Mocks: []config.Mock{ - { - Matcher: config.RequestMatcher{ - Path: "/api/user", - Method: "invalid", - }, - Response: config.Response{ - Code: 200, - Raw: "test", - }, - }, - }, - Cache: config.CacheGlobs{}, - }, - error: "mapping.mocks[0].method must be one of GET, HEAD, POST, PUT, PATCH, DELETE, CONNECT, OPTIONS, TRACE", - }, - { - name: "mapping with invalid cache glob", - value: config.Mapping{ - From: "localhost", - To: hosts.Github.Host(), - Statics: []config.StaticDirectory{}, - Mocks: []config.Mock{}, - Cache: config.CacheGlobs{ - "/api/info[", - }, - }, - error: "mapping.cache[0] is not a valid glob pattern", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateMapping(field, test.value, fs, &errs) - require.EqualError(t, errs, test.error) - }) - } - }) -} diff --git a/internal/config/validators/mock_test.go b/internal/config/validators/mock_test.go deleted file mode 100644 index 0904468f..00000000 --- a/internal/config/validators/mock_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package validators_test - -import ( - "testing" - "time" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/spf13/afero" - "github.com/stretchr/testify/assert" -) - -func TestMockValidator(t *testing.T) { - t.Run("should return true", func(t *testing.T) { - var errs validators.Errors - validators.ValidateMock("mock", config.Mock{ - Matcher: config.RequestMatcher{ - Path: "/api/info", - Method: "", - }, - Response: config.Response{ - Code: 200, - Raw: "test", - Delay: 1 * time.Second, - }, - }, afero.NewMemMapFs(), &errs) - - assert.False(t, errs.HasAny()) - }) -} diff --git a/internal/config/validators/options_handling_test.go b/internal/config/validators/options_handling_test.go deleted file mode 100644 index 87873b83..00000000 --- a/internal/config/validators/options_handling_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package validators_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/go-http-utils/headers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestOptionsValidator(t *testing.T) { - t.Run("should return true", func(t *testing.T) { - t.Run("for default options", func(t *testing.T) { - var errs validators.Errors - validators.ValidateOptionsHandling("options", config.OptionsHandling{}, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("for correct status code", func(t *testing.T) { - var errs validators.Errors - validators.ValidateOptionsHandling("options", config.OptionsHandling{ - Headers: map[string]string{headers.ContentType: "application/json"}, - Code: 200, - }, &errs) - assert.False(t, errs.HasAny()) - }) - }) - - t.Run("should return false for invalid status code", func(t *testing.T) { - var errs validators.Errors - validators.ValidateOptionsHandling("options", config.OptionsHandling{ - Headers: map[string]string{headers.ContentType: "application/json"}, - Code: -10, - }, &errs) - require.EqualError(t, errs, "options.code code must be in range 100-599") - }) -} diff --git a/internal/config/validators/primitives.go b/internal/config/validators/primitives.go deleted file mode 100644 index 77c6bdc3..00000000 --- a/internal/config/validators/primitives.go +++ /dev/null @@ -1,178 +0,0 @@ -package validators - -import ( - "fmt" - "net/http" - "os" - "slices" - "strings" - "time" - - "github.com/bmatcuk/doublestar/v4" - "github.com/evg4b/uncors/internal/urlparser" - "github.com/spf13/afero" -) - -const maxHostLength = 255 - -func ValidateHost(field, value string, errs *Errors) { - if value == "" { - errs.add(fmt.Sprintf("%s must not be empty", field)) - - return - } - - if len(value) > maxHostLength { - errs.add(fmt.Sprintf("%s must not be longer than 255 characters, but got %d", field, len(value))) - - return - } - - uri, err := urlparser.Parse(value) - if err != nil { - errs.add(fmt.Sprintf("%s is not a valid host", field)) - - return - } - - if uri.Path != "" { - errs.add(fmt.Sprintf("%s must not contain a path", field)) - } - - if uri.RawQuery != "" { - errs.add(fmt.Sprintf("%s must not contain a query", field)) - } - - if uri.Scheme != "http" && uri.Scheme != "https" && uri.Scheme != "" { - errs.add(fmt.Sprintf("%s scheme must be http or https", field)) - } -} - -func ValidatePath(field, value string, relative bool, errs *Errors) { - if value == "" { - errs.add(fmt.Sprintf("%s must not be empty", field)) - - return - } - - if !relative && !strings.HasPrefix(value, "/") { - errs.add(fmt.Sprintf("%s must be absolute and start with /", field)) - - return - } - - uri, err := urlparser.Parse("//localhost/" + strings.TrimPrefix(value, "/")) - if err != nil { - errs.add(fmt.Sprintf("%s is not a valid path", field)) - - return - } - - if uri.RawQuery != "" { - errs.add(fmt.Sprintf("%s must not contain a query", field)) - } -} - -func ValidateFile(field, value string, fs afero.Fs, errs *Errors) { - stat, err := fs.Stat(value) - if err != nil { - switch { - case os.IsNotExist(err): - errs.add(fmt.Sprintf("%s %s does not exist", field, value)) - case os.IsPermission(err): - errs.add(fmt.Sprintf("%s %s is not accessible", field, value)) - default: - errs.add(fmt.Sprintf("%s %s is not a file", field, value)) - } - - return - } - - if stat.IsDir() { - errs.add(fmt.Sprintf("%s %s is a directory", field, value)) - } -} - -func ValidateDirectory(field, value string, fs afero.Fs, errs *Errors) { - if value == "" { - errs.add(fmt.Sprintf("%s must not be empty", field)) - - return - } - - stat, err := fs.Stat(value) - if err != nil { - switch { - case os.IsNotExist(err): - errs.add(fmt.Sprintf("%s directory does not exist", field)) - case os.IsPermission(err): - errs.add(fmt.Sprintf("%s directory is not accessible", field)) - default: - errs.add(fmt.Sprintf("%s is not a directory", field)) - } - - return - } - - if !stat.IsDir() { - errs.add(fmt.Sprintf("%s is not a directory", field)) - } -} - -func ValidateStatus(field string, value int, errs *Errors) { - if value < 100 || value > 599 { - errs.add(fmt.Sprintf("%s code must be in range 100-599", field)) - } -} - -func ValidateDuration(field string, value time.Duration, allowZero bool, errs *Errors) { - if allowZero { - if value < 0 { - errs.add(fmt.Sprintf("%s must be greater than or equal to 0", field)) - } - } else { - if value <= 0 { - errs.add(fmt.Sprintf("%s must be greater than 0", field)) - } - } -} - -var allowedMethods = []string{ - http.MethodGet, - http.MethodHead, - http.MethodPost, - http.MethodPut, - http.MethodPatch, - http.MethodDelete, - http.MethodConnect, - http.MethodOptions, - http.MethodTrace, -} - -func ValidateMethod(field, value string, allowEmpty bool, errs *Errors) { - if allowEmpty && value == "" { - return - } - - if !slices.Contains(allowedMethods, value) { - errs.add(fmt.Sprintf("%s must be one of %s", field, strings.Join(allowedMethods, ", "))) - } -} - -func ValidatePort(field string, value int, errs *Errors) { - if value < 1 || value > 65535 { - errs.add(fmt.Sprintf("%s must be between 1 and 65535", field)) - } -} - -func ValidateGlobPattern(field, value string, errs *Errors) { - if !doublestar.ValidatePathPattern(value) { - errs.add(fmt.Sprintf("%s is not a valid glob pattern", field)) - } -} - -func ValidateStringEnum(_ string, value string, options []string, errs *Errors) { - if !slices.Contains(options, value) { - errs.add(fmt.Sprintf("'%s' is not a valid option", value)) - } -} diff --git a/internal/config/validators/proxy_test.go b/internal/config/validators/proxy_test.go deleted file mode 100644 index 809a5812..00000000 --- a/internal/config/validators/proxy_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package validators_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/config/validators" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestProxyValidatorIsValid(t *testing.T) { - t.Run("valid url", func(t *testing.T) { - var errs validators.Errors - validators.ValidateProxy("testField", "http://valid-url.com", &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("invalid url", func(t *testing.T) { - var errs validators.Errors - validators.ValidateProxy("testField", "invalid:::url", &errs) - require.EqualError(t, errs, "testField is not a valid URL") - }) - - t.Run("empty url", func(t *testing.T) { - var errs validators.Errors - validators.ValidateProxy("testField", "", &errs) - assert.False(t, errs.HasAny()) - }) -} diff --git a/internal/config/validators/request_matcher_test.go b/internal/config/validators/request_matcher_test.go deleted file mode 100644 index ad9126ef..00000000 --- a/internal/config/validators/request_matcher_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package validators_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/go-http-utils/headers" - "github.com/stretchr/testify/assert" -) - -const requestMatcherTestPath = "/api/test" - -func TestRequestMatcherValidator(t *testing.T) { - t.Run("should not register errors for valid filter with all fields", func(t *testing.T) { - var errs validators.Errors - validators.ValidateRequestMatcher("test", config.RequestMatcher{ - Path: requestMatcherTestPath, - Method: "GET", - Queries: map[string]string{ - "param1": "value1", - "param2": "value2", - }, - Headers: map[string]string{ - headers.ContentType: "application/json", - headers.Accept: "application/json", - }, - }, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("should not register errors for valid filter with minimal fields", func(t *testing.T) { - var errs validators.Errors - validators.ValidateRequestMatcher("test", config.RequestMatcher{Path: requestMatcherTestPath}, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("should register error for invalid path", func(t *testing.T) { - var errs validators.Errors - validators.ValidateRequestMatcher("test", config.RequestMatcher{Path: "", Method: "GET"}, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), "path must not be empty") - }) - - t.Run("should register error for invalid method", func(t *testing.T) { - var errs validators.Errors - validators.ValidateRequestMatcher( - "test", - config.RequestMatcher{Path: requestMatcherTestPath, Method: "INVALID"}, - &errs, - ) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), "method must be one of") - }) - - t.Run("should register multiple validation errors", func(t *testing.T) { - var errs validators.Errors - validators.ValidateRequestMatcher("test", config.RequestMatcher{Path: "", Method: "INVALID"}, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), "path must not be empty") - assert.Contains(t, errs.Error(), "method must be one of") - }) -} diff --git a/internal/config/validators/response_test.go b/internal/config/validators/response_test.go deleted file mode 100644 index 656774e3..00000000 --- a/internal/config/validators/response_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package validators_test - -import ( - "testing" - "time" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/evg4b/uncors/testing/testutils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestResponseValidator(t *testing.T) { - const file = "testdata/file.txt" - - fs := testutils.FsFromMap(t, map[string]string{file: "test"}) - - t.Run("should not register errors if response is valid", func(t *testing.T) { - tests := []struct { - name string - value config.Response - }{ - {name: "with file", value: config.Response{Code: 200, File: file, Delay: 3 * time.Second}}, - {name: "with raw", value: config.Response{Code: 200, Raw: `{ "test": "test" }`, Delay: 3 * time.Second}}, - {name: "without delay", value: config.Response{Code: 200, Raw: `{ "test": "test" }`}}, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateResponse("test", test.value, fs, &errs) - assert.False(t, errs.HasAny()) - }) - } - }) - - t.Run("should register errors for", func(t *testing.T) { - tests := []struct { - name string - value config.Response - error string - }{ - { - name: "code", - value: config.Response{Code: 0, File: file, Delay: 3 * time.Second}, - error: "test.code code must be in range 100-599", - }, - { - name: "file", - value: config.Response{Code: 200, File: "testdata/unknown.txt", Delay: 3 * time.Second}, - error: "test.file testdata/unknown.txt does not exist", - }, - { - name: "delay", - value: config.Response{Code: 200, File: file, Delay: -1 * time.Second}, - error: "test.delay must be greater than or equal to 0", - }, - { - name: "both empty", - value: config.Response{Code: 200, Delay: 3 * time.Second}, - error: "test.raw or test.file must be set", - }, - { - name: "both set", - value: config.Response{Code: 200, File: file, Raw: "test", Delay: 3 * time.Second}, - error: "only one of test.raw or test.file must be set", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateResponse("test", test.value, fs, &errs) - require.EqualError(t, errs, test.error) - }) - } - }) -} diff --git a/internal/config/validators/rewrite_test.go b/internal/config/validators/rewrite_test.go deleted file mode 100644 index a66effd4..00000000 --- a/internal/config/validators/rewrite_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package validators_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/evg4b/uncors/testing/hosts" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const ( - fromPath = "/from/path" - toPath = "/to/path" -) - -func TestRewritingOptionValidatorIsValidNoError(t *testing.T) { - tests := []struct { - name string - value config.RewritingOption - }{ - {name: "valid paths and host", value: config.RewritingOption{From: fromPath, To: toPath, Host: hosts.Github.Host()}}, - {name: "no host", value: config.RewritingOption{From: fromPath, To: toPath}}, - { - name: "relative from path", - value: config.RewritingOption{From: "../relative/from/path", To: toPath, Host: hosts.Github.Host()}, - }, - { - name: "relative to path", - value: config.RewritingOption{From: fromPath, To: "../relative/to/path", Host: hosts.Github.Host()}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateRewritingOption("testField", tt.value, &errs) - assert.False(t, errs.HasAny()) - }) - } -} - -func TestRewritingOptionValidatorIsValidWithError(t *testing.T) { - tests := []struct { - name string - value config.RewritingOption - error string - }{ - { - name: "invalid from path", - value: config.RewritingOption{From: "", To: toPath, Host: hosts.Github.Host()}, - error: "testField.from must not be empty", - }, - { - name: "invalid to path", - value: config.RewritingOption{From: fromPath, To: "", Host: hosts.Github.Host()}, - error: "testField.to must not be empty", - }, - { - name: "invalid host format", - value: config.RewritingOption{From: fromPath, To: toPath, Host: "&&&"}, - error: "testField.host is not a valid host", - }, - } - - for _, testCase := range tests { - t.Run(testCase.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateRewritingOption("testField", testCase.value, &errs) - require.EqualError(t, errs, testCase.error) - }) - } -} diff --git a/internal/config/validators/script_test.go b/internal/config/validators/script_test.go deleted file mode 100644 index c2815403..00000000 --- a/internal/config/validators/script_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package validators_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/evg4b/uncors/testing/testutils" - "github.com/go-http-utils/headers" - "github.com/stretchr/testify/assert" -) - -const ( - testAPIPath = "/api/test" - testScriptContent = "response.status = 200" - testScriptFilePath = "/scripts/test.lua" - scriptPathField = "script.path" - scriptScriptField = "script.script" - scriptFileField = "script.file" -) - -func TestScriptValidator(t *testing.T) { - noFS := testutils.FsFromMap(t, map[string]string{}) - - t.Run("valid inline script", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: testAPIPath, Method: "GET"}, - Script: testScriptContent, - }, noFS, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("valid file script", func(t *testing.T) { - fs := testutils.FsFromMap(t, map[string]string{testScriptFilePath: testScriptContent}) - - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: testAPIPath, Method: "POST"}, - File: testScriptFilePath, - }, fs, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("empty method is allowed", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: testAPIPath}, - Script: testScriptContent, - }, noFS, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("valid queries and headers", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{ - Path: "/api/test", - Queries: map[string]string{"filter": "active"}, - Headers: map[string]string{headers.Authorization: "Bearer token"}, - }, - Script: testScriptContent, - }, noFS, &errs) - assert.False(t, errs.HasAny()) - }) - - t.Run("empty path", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: ""}, - Script: testScriptContent, - }, noFS, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), scriptPathField) - }) - - t.Run("invalid path", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: "invalid-path"}, - Script: testScriptContent, - }, noFS, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), scriptPathField) - }) - - t.Run("invalid method", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: testAPIPath, Method: "INVALID"}, - Script: testScriptContent, - }, noFS, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), "script.method") - }) - - t.Run("neither script nor file provided", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: testAPIPath, Method: "GET"}, - }, noFS, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), scriptScriptField) - assert.Contains(t, errs.Error(), scriptFileField) - assert.Contains(t, errs.Error(), "either 'script' or 'file' must be provided") - }) - - t.Run("both script and file provided", func(t *testing.T) { - fs := testutils.FsFromMap(t, map[string]string{testScriptFilePath: testScriptContent}) - - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: "/api/test"}, - Script: testScriptContent, - File: "/scripts/test.lua", - }, fs, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), scriptScriptField) - assert.Contains(t, errs.Error(), scriptFileField) - assert.Contains(t, errs.Error(), "only one of 'script' or 'file' can be provided") - }) - - t.Run("file does not exist", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: "/api/test"}, - File: "/scripts/nonexistent.lua", - }, noFS, &errs) - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), scriptFileField) - }) - - t.Run("multiple errors", func(t *testing.T) { - var errs validators.Errors - validators.ValidateScript("script", config.Script{ - Matcher: config.RequestMatcher{Path: "", Method: "INVALID"}, - }, noFS, &errs) - assert.True(t, errs.HasAny()) - errStr := errs.Error() - assert.Contains(t, errStr, "script.path") - assert.Contains(t, errStr, "script.method") - assert.Contains(t, errStr, "script.script") - assert.Contains(t, errStr, "script.file") - }) -} diff --git a/internal/config/validators/static_test.go b/internal/config/validators/static_test.go deleted file mode 100644 index d53e23de..00000000 --- a/internal/config/validators/static_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package validators_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/evg4b/uncors/testing/testutils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestStaticValidator(t *testing.T) { - const ( - assetsPath = "/assets" - staticPath = "/static" - indexFilePath = "/static/index.html" - ) - - fs := testutils.FsFromMap(t, map[string]string{indexFilePath: indexFilePath}) - - t.Run("should not register errors if response is valid", func(t *testing.T) { - tests := []struct { - name string - value config.StaticDirectory - }{ - { - name: "valid static directory with index", - value: config.StaticDirectory{Path: assetsPath, Dir: staticPath, Index: "index.html"}, - }, - { - name: "valid static directory without index", - value: config.StaticDirectory{Path: assetsPath, Dir: staticPath}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateStatic("test", test.value, fs, &errs) - assert.False(t, errs.HasAny()) - }) - } - }) - - t.Run("should register errors if response is invalid", func(t *testing.T) { - tests := []struct { - name string - value config.StaticDirectory - error string - }{ - { - name: "empty path", - value: config.StaticDirectory{Path: "", Dir: staticPath}, - error: "test.path must not be empty", - }, - { - name: "empty directory", - value: config.StaticDirectory{Path: assetsPath, Dir: ""}, - error: "test.directory must not be empty", - }, - { - name: "missing index file", - value: config.StaticDirectory{Path: assetsPath, Dir: staticPath, Index: "index.php"}, - error: "test.index /static/index.php does not exist", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var errs validators.Errors - validators.ValidateStatic("test", test.value, fs, &errs) - require.EqualError(t, errs, test.error) - }) - } - }) -} diff --git a/internal/config/validators/tls_test.go b/internal/config/validators/tls_test.go deleted file mode 100644 index 2bc89c59..00000000 --- a/internal/config/validators/tls_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package validators_test - -import ( - "os" - "path/filepath" - "testing" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - infratls "github.com/evg4b/uncors/internal/infra/tls" - "github.com/evg4b/uncors/testing/hosts" - "github.com/spf13/afero" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestValidateTLS(t *testing.T) { - t.Run("skip validation for invalid URL", func(t *testing.T) { - var errs validators.Errors - validators.ValidateTLS( - "test", - config.Mapping{From: "://invalid-url", To: hosts.Example.HTTP()}, - afero.NewMemMapFs(), - &errs, - ) - assert.False(t, errs.HasAny()) - }) - - t.Run("skip validation for non-HTTPS", func(t *testing.T) { - var errs validators.Errors - validators.ValidateTLS( - "test", - config.Mapping{From: "http://localhost:8080", To: hosts.Example.HTTP()}, - afero.NewMemMapFs(), - &errs, - ) - assert.False(t, errs.HasAny()) - }) - - t.Run("error when CA does not exist", func(t *testing.T) { - tmpDir := t.TempDir() - fakeHome := filepath.Join(tmpDir, "home") - require.NoError(t, os.MkdirAll(fakeHome, 0o755)) - t.Setenv("HOME", fakeHome) - - var errs validators.Errors - validators.ValidateTLS("test", - config.Mapping{From: "https://localhost:8443", To: hosts.Example.HTTP()}, - afero.NewOsFs(), &errs) - - assert.True(t, errs.HasAny()) - assert.Contains(t, errs.Error(), "HTTPS mapping 'localhost:8443' requires a local CA certificate") - assert.Contains(t, errs.Error(), "uncors generate-certs") - }) - - t.Run("pass when CA exists", func(t *testing.T) { - tmpDir := t.TempDir() - fakeHome := filepath.Join(tmpDir, "home") - require.NoError(t, os.MkdirAll(fakeHome, 0o755)) - t.Setenv("HOME", fakeHome) - - fs := afero.NewOsFs() - caDir := filepath.Join(fakeHome, ".config", "uncors") - _, _, err := infratls.GenerateCA(infratls.CAConfig{ValidityDays: 365, OutputDir: caDir, Fs: fs}) - require.NoError(t, err) - - var errs validators.Errors - validators.ValidateTLS("test", - config.Mapping{From: "https://localhost:8443", To: hosts.Example.HTTP()}, - fs, &errs) - - assert.False(t, errs.HasAny()) - }) -} diff --git a/internal/config/validators/uncors_config.go b/internal/config/validators/uncors_config.go deleted file mode 100644 index f58f01f8..00000000 --- a/internal/config/validators/uncors_config.go +++ /dev/null @@ -1,31 +0,0 @@ -package validators - -import ( - "github.com/evg4b/uncors/internal/config" - "github.com/spf13/afero" -) - -// ValidateConfig validates the full uncors configuration and returns a combined -// error listing all validation failures. Returns nil if the config is valid. -func ValidateConfig(cfg *config.UncorsConfig, fs afero.Fs) error { - var errs Errors - - if len(cfg.Mappings) == 0 { - errs.add("mappings must not be empty") - - return errs - } - - for i, mapping := range cfg.Mappings { - ValidateMapping(joinPath("mappings", index(i)), mapping, fs, &errs) - } - - ValidateProxy("proxy", cfg.Proxy, &errs) - ValidateCacheConfig("cache-config", cfg.CacheConfig, &errs) - - if errs.HasAny() { - return errs - } - - return nil -} diff --git a/internal/config/validators/uncors_config_test.go b/internal/config/validators/uncors_config_test.go deleted file mode 100644 index d4f4d9e5..00000000 --- a/internal/config/validators/uncors_config_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package validators_test - -import ( - "net/http" - "testing" - "time" - - "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" - "github.com/evg4b/uncors/testing/hosts" - "github.com/evg4b/uncors/testing/testutils" - "github.com/stretchr/testify/require" -) - -func TestUncorsConfigValidator(t *testing.T) { - mapFs := testutils.FsFromMap(t, map[string]string{}) - - t.Run("should not register errors for", func(t *testing.T) { - tests := []struct { - name string - value *config.UncorsConfig - }{ - { - name: "minimal config", - value: &config.UncorsConfig{ - Mappings: []config.Mapping{ - {From: hosts.Localhost.Port(8080), To: hosts.Localhost.HTTPSPort(8443)}, - }, - CacheConfig: config.CacheConfig{ - MaxSize: 100 * 1024 * 1024, - ExpirationTime: 10 * time.Minute, - Methods: []string{http.MethodGet}, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - errors := validators.ValidateConfig(test.value, mapFs) - - require.NoError(t, errors) - }) - } - }) - - t.Run("should register errors for invalid config", func(t *testing.T) { - tests := []struct { - name string - value *config.UncorsConfig - error string - }{ - { - name: "invalid mapping", - value: &config.UncorsConfig{ - Mappings: []config.Mapping{}, - CacheConfig: config.CacheConfig{ - MaxSize: 100 * 1024 * 1024, - ExpirationTime: 10 * time.Minute, - Methods: []string{http.MethodGet}, - }, - }, - error: "mappings must not be empty", - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - errors := validators.ValidateConfig(test.value, mapFs) - - require.EqualError(t, errors, test.error) - }) - } - }) -} diff --git a/internal/config/validators/validators.go b/internal/config/validators/validators.go deleted file mode 100644 index fb5334c2..00000000 --- a/internal/config/validators/validators.go +++ /dev/null @@ -1,171 +0,0 @@ -package validators - -import ( - "fmt" - "path" - "strings" - - infratls "github.com/evg4b/uncors/internal/infra/tls" - "github.com/evg4b/uncors/internal/urlparser" - - "github.com/evg4b/uncors/internal/config" - "github.com/spf13/afero" -) - -func ValidateProxy(field, value string, errs *Errors) { - if value == "" { - return - } - - _, err := urlparser.Parse(value) - if err != nil { - errs.add(fmt.Sprintf("%s is not a valid URL", field)) - } -} - -func ValidateOptionsHandling(field string, value config.OptionsHandling, errs *Errors) { - if value.Code != 0 { - ValidateStatus(joinPath(field, "code"), value.Code, errs) - } -} - -func ValidateRequestMatcher(field string, value config.RequestMatcher, errs *Errors) { - ValidatePath(joinPath(field, "path"), value.Path, false, errs) - ValidateMethod(joinPath(field, "method"), value.Method, true, errs) -} - -func ValidateRewritingOption(field string, value config.RewritingOption, errs *Errors) { - ValidatePath(joinPath(field, "from"), value.From, true, errs) - ValidatePath(joinPath(field, "to"), value.To, true, errs) - - if value.Host != "" { - ValidateHost(joinPath(field, "host"), value.Host, errs) - } -} - -func ValidateCacheGlob(field, value string, errs *Errors) { - ValidateGlobPattern(field, value, errs) -} - -func ValidateResponse(field string, value config.Response, fs afero.Fs, errs *Errors) { - ValidateStatus(joinPath(field, "code"), value.Code, errs) - ValidateDuration(joinPath(field, "delay"), value.Delay, true, errs) - - switch { - case value.Raw == "" && value.File == "": - errs.add(fmt.Sprintf( - "%s or %s must be set", - joinPath(field, "raw"), - joinPath(field, "file"), - )) - case value.Raw != "" && value.File != "": - errs.add(fmt.Sprintf( - "only one of %s or %s must be set", - joinPath(field, "raw"), - joinPath(field, "file"), - )) - case value.File != "": - ValidateFile(joinPath(field, "file"), value.File, fs, errs) - } -} - -func ValidateMock(field string, value config.Mock, fs afero.Fs, errs *Errors) { - ValidateRequestMatcher(field, value.Matcher, errs) - ValidateResponse(joinPath(field, "response"), value.Response, fs, errs) -} - -func ValidateStatic(field string, value config.StaticDirectory, fs afero.Fs, errs *Errors) { - ValidatePath(joinPath(field, "path"), value.Path, false, errs) - ValidateDirectory(joinPath(field, "directory"), value.Dir, fs, errs) - - if value.Index != "" { - ValidateFile(joinPath(field, "index"), path.Join(value.Dir, value.Index), fs, errs) - } -} - -func ValidateScript(field string, value config.Script, fs afero.Fs, errs *Errors) { - ValidateRequestMatcher(field, value.Matcher, errs) - - switch { - case value.Script == "" && value.File == "": - scriptField := joinPath(field, "script") - fileField := joinPath(field, "file") - - errs.add(fmt.Sprintf("%s: either 'script' or 'file' must be provided", scriptField)) - errs.add(fmt.Sprintf("%s: either 'script' or 'file' must be provided", fileField)) - case value.Script != "" && value.File != "": - scriptField := joinPath(field, "script") - fileField := joinPath(field, "file") - - errs.add(fmt.Sprintf("%s: only one of 'script' or 'file' can be provided", scriptField)) - errs.add(fmt.Sprintf("%s: only one of 'script' or 'file' can be provided", fileField)) - case value.File != "": - ValidateFile(joinPath(field, "file"), value.File, fs, errs) - } -} - -func ValidateCacheConfig(field string, value config.CacheConfig, errs *Errors) { - ValidateDuration(joinPath(field, "expiration-time"), value.ExpirationTime, false, errs) - - if value.MaxSize <= 0 { - maxSizeField := joinPath(field, "max-size") - errs.add(fmt.Sprintf("%s must be greater than 0", maxSizeField)) - } - - if len(value.Methods) == 0 { - errs.add("methods must not be empty") - } - - for i, method := range value.Methods { - ValidateMethod(joinPath(field, "methods", index(i)), method, false, errs) - } -} - -func ValidateTLS(_ string, mapping config.Mapping, fs afero.Fs, errs *Errors) { - fromURL, err := mapping.GetFromURL() - if err != nil || fromURL.Scheme != "https" { - return - } - - if !infratls.CAExists(fs) { - errs.add(formatTLSError(fromURL.Host)) - } -} - -func formatTLSError(host string) string { - var builder strings.Builder - fmt.Fprintf(&builder, "HTTPS mapping '%s' requires a local CA certificate for automatic TLS.\n\n", host) - builder.WriteString("Generate a local CA certificate:\n") - builder.WriteString(" uncors generate-certs\n\n") - builder.WriteString("After generating CA, you can add it to your system's trusted certificates.") - - return builder.String() -} - -func ValidateMapping(field string, value config.Mapping, fs afero.Fs, errs *Errors) { - ValidateHost(joinPath(field, "from"), value.From, errs) - ValidateHost(joinPath(field, "to"), value.To, errs) - ValidateOptionsHandling(joinPath(field, "options-handling"), value.OptionsHandling, errs) - ValidateHAR(joinPath(field, "har"), value.HAR, errs) - ValidateTLS(field, value, fs, errs) - - for i, static := range value.Statics { - ValidateStatic(joinPath(field, "statics", index(i)), static, fs, errs) - } - - for i, mock := range value.Mocks { - ValidateMock(joinPath(field, "mocks", index(i)), mock, fs, errs) - } - - for i, glob := range value.Cache { - ValidateCacheGlob(joinPath(field, "cache", index(i)), glob, errs) - } - - for i, rewrite := range value.Rewrites { - ValidateRewritingOption(joinPath(field, "rewrite", index(i)), rewrite, errs) - } - - for i, script := range value.Scripts { - ValidateScript(joinPath(field, "scripts", index(i)), script, fs, errs) - } -} diff --git a/internal/config/watcher.go b/internal/config/watcher.go index 9b9aacc5..b199a6c4 100644 --- a/internal/config/watcher.go +++ b/internal/config/watcher.go @@ -8,22 +8,14 @@ import ( "github.com/fsnotify/fsnotify" ) -// debounceDelay is the wait time after the last file event before calling onChange. -// This prevents multiple rapid callbacks when editors write files in stages. const debounceDelay = 10 * time.Millisecond -// Watcher monitors a configuration file for changes and invokes a callback -// whenever the file is written or recreated. It uses a short debounce window to -// coalesce bursts of filesystem events that editors typically produce on save. type Watcher struct { fsWatcher *fsnotify.Watcher onChange func() done chan struct{} } -// NewWatcher creates a Watcher that monitors the given file path. -// onChange is called (after debouncing) on every write or create event. -// The returned watcher is already running; call Close to stop it. func NewWatcher(filePath string, onChange func()) (*Watcher, error) { fsWatcher, err := fsnotify.NewWatcher() if err != nil { @@ -48,7 +40,6 @@ func NewWatcher(filePath string, onChange func()) (*Watcher, error) { return watcher, nil } -// Close stops the watcher and releases all associated resources. func (cw *Watcher) Close() error { close(cw.done) diff --git a/internal/contracts/http.go b/internal/contracts/http.go index c07ceda0..d40b9844 100644 --- a/internal/contracts/http.go +++ b/internal/contracts/http.go @@ -13,7 +13,7 @@ const ( PrefixUpdaterKey contextKey = "uncors-prefix-updater" ) -type ReqestData struct { +type RequestData struct { Method string URL *url.URL Header http.Header diff --git a/internal/contracts/output.go b/internal/contracts/output.go index 5338d120..728c686e 100644 --- a/internal/contracts/output.go +++ b/internal/contracts/output.go @@ -29,7 +29,7 @@ type Output interface { Print(msg any) Printf(msg string, args ...any) - Request(data *ReqestData) + Request(data *RequestData) NewPrefixOutput(prefix string) Output } diff --git a/internal/handler/mock/handler.go b/internal/handler/mock/handler.go index f3443001..35349ab2 100644 --- a/internal/handler/mock/handler.go +++ b/internal/handler/mock/handler.go @@ -22,6 +22,8 @@ type Handler struct { after func(duration time.Duration) <-chan time.Time } +const contentTypeSniffLen = 512 + var ErrResponseIsNotDefined = errors.New("response is not defined") func NewMockHandler(options ...HandlerOption) *Handler { @@ -29,7 +31,7 @@ func NewMockHandler(options ...HandlerOption) *Handler { } func (h *Handler) ServeHTTP(writer contracts.ResponseWriter, request *contracts.Request) { - if h.waiteDelay(writer, request) { + if h.waitDelay(writer, request) { return } @@ -37,8 +39,6 @@ func (h *Handler) ServeHTTP(writer contracts.ResponseWriter, request *contracts. if err != nil { log.Printf("ERROR: Mock handler error: %s (URL: %s)", err.Error(), request.URL.String()) infra.HTTPError(writer, err) - - return } } @@ -55,20 +55,12 @@ func (h *Handler) writeResponse(writer contracts.ResponseWriter, request *contra switch { case response.IsFile(): - err := h.serveFileContent(writer, request) - if err != nil { - return err - } + return h.serveFileContent(writer, request) case response.IsRaw(): - err := h.serveRawContent(writer) - if err != nil { - return err - } + return h.serveRawContent(writer) default: return ErrResponseIsNotDefined } - - return nil } func (h *Handler) serveRawContent(writer http.ResponseWriter) error { @@ -76,7 +68,12 @@ func (h *Handler) serveRawContent(writer http.ResponseWriter) error { header := writer.Header() if len(header.Get(headers.ContentType)) == 0 { - contentType := http.DetectContentType([]byte(response.Raw)) + sniff := response.Raw + if len(sniff) > contentTypeSniffLen { + sniff = sniff[:contentTypeSniffLen] + } + + contentType := http.DetectContentType([]byte(sniff)) header.Set(headers.ContentType, contentType) } @@ -104,29 +101,17 @@ func (h *Handler) serveFileContent(writer http.ResponseWriter, request *http.Req return nil } -func (h *Handler) waiteDelay(writer contracts.ResponseWriter, request *contracts.Request) bool { - response := h.response - - if response.Delay > 0 { - log.Printf("Delay %s for %s", response.Delay, request.URL.RequestURI()) - ctx := request.Context() - url := request.URL.RequestURI() - - waitingLoop: - for { - select { - case <-ctx.Done(): - writer.WriteHeader(http.StatusServiceUnavailable) - log.Printf("Delay is canceled (url: %s)", url) +func (h *Handler) waitDelay(writer contracts.ResponseWriter, request *contracts.Request) bool { + if h.response.Delay <= 0 { + return false + } - return true - case <-h.after(response.Delay): - log.Printf("Delay is complete (url: %s)", url) + select { + case <-request.Context().Done(): + writer.WriteHeader(http.StatusServiceUnavailable) - break waitingLoop - } - } + return true + case <-h.after(h.response.Delay): + return false } - - return false } diff --git a/internal/handler/rewrite/helpers.go b/internal/handler/rewrite/helpers.go index 82651ffb..7bd17735 100644 --- a/internal/handler/rewrite/helpers.go +++ b/internal/handler/rewrite/helpers.go @@ -10,10 +10,8 @@ var ErrInvalidHost = errors.New("rewrite host has invalid type") type rewriteKeyType string -var RewriteHostKey rewriteKeyType = "__uncors_rewrite_host" +const RewriteHostKey rewriteKeyType = "__uncors_rewrite_host" -// GetRewriteHost extracts the rewrite host from the request context. -// Returns ErrInvalidHost if the value exists but is not a string. func GetRewriteHost(request *contracts.Request) (string, error) { value := request.Context().Value(RewriteHostKey) diff --git a/internal/helpers/request_data.go b/internal/helpers/request_data.go index f7b6bed2..bb537226 100644 --- a/internal/helpers/request_data.go +++ b/internal/helpers/request_data.go @@ -4,8 +4,8 @@ import ( "github.com/evg4b/uncors/internal/contracts" ) -func ToRequestData(req *contracts.Request, code int) *contracts.ReqestData { - return &contracts.ReqestData{ +func ToRequestData(req *contracts.Request, code int) *contracts.RequestData { + return &contracts.RequestData{ Method: req.Method, URL: req.URL, Header: req.Header, diff --git a/internal/server/printer_test.go b/internal/server/printer_test.go index bcd93706..8025062b 100644 --- a/internal/server/printer_test.go +++ b/internal/server/printer_test.go @@ -16,12 +16,12 @@ func TestRequestPrinter(t *testing.T) { tracker := server.NewRequestTracker() output := mocks.NewOutputMock(t) - data := &contracts.ReqestData{ + data := &contracts.RequestData{ Method: "GET", Code: 200, } - output.RequestMock.Set(func(_ *contracts.ReqestData) {}) + output.RequestMock.Set(func(_ *contracts.RequestData) {}) go server.RequestPrinter(tracker, output) @@ -45,7 +45,7 @@ func TestRequestPrinter(t *testing.T) { tracker := server.NewRequestTracker() output := mocks.NewOutputMock(t) - data := &contracts.ReqestData{ + data := &contracts.RequestData{ Method: "GET", Code: 200, } @@ -97,7 +97,7 @@ func TestRequestPrinter(t *testing.T) { const prefix = "PROXY" - data := &contracts.ReqestData{ + data := &contracts.RequestData{ Method: "GET", Code: 200, } @@ -105,7 +105,7 @@ func TestRequestPrinter(t *testing.T) { output.NewPrefixOutputMock.Set(func(_ string) contracts.Output { return prefixedOutput }) - prefixedOutput.RequestMock.Set(func(_ *contracts.ReqestData) {}) + prefixedOutput.RequestMock.Set(func(_ *contracts.RequestData) {}) go server.RequestPrinter(tracker, output) @@ -131,12 +131,12 @@ func TestRequestPrinter(t *testing.T) { tracker := server.NewRequestTracker() output := mocks.NewOutputMock(t) - data := &contracts.ReqestData{ + data := &contracts.RequestData{ Method: "GET", Code: 200, } - output.RequestMock.Set(func(_ *contracts.ReqestData) {}) + output.RequestMock.Set(func(_ *contracts.RequestData) {}) go server.RequestPrinter(tracker, output) @@ -163,21 +163,21 @@ func TestRequestPrinter(t *testing.T) { output := mocks.NewOutputMock(t) prefixedOutput := mocks.NewOutputMock(t) - data1 := &contracts.ReqestData{Method: "GET", Code: 200} - data2 := &contracts.ReqestData{Method: "POST", Code: 201} - data3 := &contracts.ReqestData{Method: "DELETE", Code: 204} + data1 := &contracts.RequestData{Method: "GET", Code: 200} + data2 := &contracts.RequestData{Method: "POST", Code: 201} + data3 := &contracts.RequestData{Method: "DELETE", Code: 204} - output.RequestMock.Set(func(_ *contracts.ReqestData) {}) + output.RequestMock.Set(func(_ *contracts.RequestData) {}) output.NewPrefixOutputMock.Set(func(_ string) contracts.Output { return prefixedOutput }) - prefixedOutput.RequestMock.Set(func(_ *contracts.ReqestData) {}) + prefixedOutput.RequestMock.Set(func(_ *contracts.RequestData) {}) go server.RequestPrinter(tracker, output) tracker.Emit(server.RequestEvent{ID: 1, Done: true, Data: data1}) tracker.Emit(server.RequestEvent{ID: 2, Prefix: "MOD1", Done: true, Data: data2}) - tracker.Emit(server.RequestEvent{ID: 3, Done: false, Data: &contracts.ReqestData{Method: "PATCH", Code: 200}}) + tracker.Emit(server.RequestEvent{ID: 3, Done: false, Data: &contracts.RequestData{Method: "PATCH", Code: 200}}) tracker.Emit(server.RequestEvent{ID: 4, Done: true, Data: data3}) tracker.Close() diff --git a/internal/server/tracker.go b/internal/server/tracker.go index 87f2138e..89adc8c5 100644 --- a/internal/server/tracker.go +++ b/internal/server/tracker.go @@ -20,7 +20,7 @@ type RequestEvent struct { StartedAt time.Time Prefix string Done bool - Data *contracts.ReqestData + Data *contracts.RequestData } type RequestTracker struct { diff --git a/internal/tui/box_message.go b/internal/tui/box_message.go index 0e0e331f..31dd200f 100644 --- a/internal/tui/box_message.go +++ b/internal/tui/box_message.go @@ -2,13 +2,12 @@ package tui import ( "fmt" - "io" "strings" "charm.land/lipgloss/v2" ) -func printMessageBox(out io.Writer, message, prefix string, blockStyles lipgloss.Style) { +func (output *CliOutput) printMessageBox(message, prefix string, blockStyles lipgloss.Style) { height := lipgloss.Height(message) space := strings.Repeat("\n", height-1) @@ -19,7 +18,7 @@ func printMessageBox(out io.Writer, message, prefix string, blockStyles lipgloss message, ) - _, err := fmt.Fprintln(out, block) + _, err := fmt.Fprintln(output.output, block) if err != nil { panic(err) } diff --git a/internal/tui/output.go b/internal/tui/output.go index 0f415517..3e0c34e4 100644 --- a/internal/tui/output.go +++ b/internal/tui/output.go @@ -13,10 +13,10 @@ import ( "github.com/evg4b/uncors/internal/tui/styles" ) -type ouputType int8 +type outputType int8 const ( - defaultOutput ouputType = iota + defaultOutput outputType = iota infoOutput warnOutput errorOutput @@ -24,14 +24,14 @@ const ( var boxLength = 8 -var levelStyles = map[ouputType]lipgloss.Style{ +var levelStyles = map[outputType]lipgloss.Style{ infoOutput: styles.InfoBlockStyle.Width(boxLength).Bold(true), warnOutput: styles.WarningBlockStyle.Width(boxLength).Bold(true), errorOutput: styles.ErrorBlockStyle.Width(boxLength).Bold(true), defaultOutput: lipgloss.NewStyle(), } -var messageMap = map[ouputType]string{ +var messageMap = map[outputType]string{ infoOutput: InfoLabel, warnOutput: WarningLabel, errorOutput: ErrorLabel, @@ -62,117 +62,113 @@ func NewCliOutput(output io.Writer, options ...Option) *CliOutput { return helpers.ApplyOptions(&CliOutput{ mutex: &sync.RWMutex{}, output: output, - buffer: bytes.Buffer{}, }, options) } -func (o *CliOutput) Write(p []byte) (int, error) { - o.mutex.RLock() - defer o.mutex.RUnlock() +func (output *CliOutput) Write(p []byte) (int, error) { + output.mutex.RLock() + defer output.mutex.RUnlock() - return o.output.Write(p) + return output.output.Write(p) } -func (o *CliOutput) Info(msg any) { - o.print(fmt.Sprint(msg), infoOutput) +func (output *CliOutput) Info(msg any) { + output.print(fmt.Sprint(msg), infoOutput) } -func (o *CliOutput) Infof(msg string, args ...any) { - o.print(fmt.Sprintf(msg, args...), infoOutput) +func (output *CliOutput) Infof(msg string, args ...any) { + output.print(fmt.Sprintf(msg, args...), infoOutput) } -func (o *CliOutput) InfoBox(messages ...string) { - printMessageBox( - o.output, - strings.Join(messages, "\n"), - InfoLabel, - styles.InfoBlockStyle, - ) +func (output *CliOutput) InfoBox(messages ...string) { + output.mutex.Lock() + defer output.mutex.Unlock() + + output.printMessageBox(strings.Join(messages, "\n"), InfoLabel, styles.InfoBlockStyle) } -func (o *CliOutput) Error(msg any) { - o.print(fmt.Sprint(msg), errorOutput) +func (output *CliOutput) Error(msg any) { + output.print(fmt.Sprint(msg), errorOutput) } -func (o *CliOutput) Errorf(msg string, args ...any) { - o.print(fmt.Sprintf(msg, args...), errorOutput) +func (output *CliOutput) Errorf(msg string, args ...any) { + output.print(fmt.Sprintf(msg, args...), errorOutput) } -func (o *CliOutput) ErrorBox(messages ...string) { - printMessageBox( - o.output, - strings.Join(messages, "\n"), - ErrorLabel, - styles.ErrorBlockStyle, - ) +func (output *CliOutput) ErrorBox(messages ...string) { + output.mutex.Lock() + defer output.mutex.Unlock() + + output.printMessageBox(strings.Join(messages, "\n"), ErrorLabel, styles.ErrorBlockStyle) } -func (o *CliOutput) Warn(msg any) { - o.print(fmt.Sprint(msg), warnOutput) +func (output *CliOutput) Warn(msg any) { + output.print(fmt.Sprint(msg), warnOutput) } -func (o *CliOutput) Warnf(msg string, args ...any) { - o.print(fmt.Sprintf(msg, args...), warnOutput) +func (output *CliOutput) Warnf(msg string, args ...any) { + output.print(fmt.Sprintf(msg, args...), warnOutput) } -func (o *CliOutput) WarnBox(messages ...string) { - printMessageBox( - o.output, - strings.Join(messages, "\n"), - WarningLabel, - styles.WarningBlockStyle, - ) +func (output *CliOutput) WarnBox(messages ...string) { + output.mutex.Lock() + defer output.mutex.Unlock() + + output.printMessageBox(strings.Join(messages, "\n"), WarningLabel, styles.WarningBlockStyle) } -func (o *CliOutput) Print(msg any) { - o.print(fmt.Sprint(msg), defaultOutput) +func (output *CliOutput) Print(msg any) { + output.print(fmt.Sprint(msg), defaultOutput) } -func (o *CliOutput) Printf(msg string, args ...any) { - o.print(fmt.Sprintf(msg, args...), defaultOutput) +func (output *CliOutput) Printf(msg string, args ...any) { + output.print(fmt.Sprintf(msg, args...), defaultOutput) } -func (o *CliOutput) Request(data *contracts.ReqestData) { - o.print(printResponse(data), defaultOutput) +func (output *CliOutput) Request(data *contracts.RequestData) { + output.print(printResponse(data), defaultOutput) } -func (o *CliOutput) NewPrefixOutput(prefix string) contracts.Output { - return NewCliOutput(o.output, WithPrefix(prefix), withMutex(o.mutex)) +func (output *CliOutput) NewPrefixOutput(prefix string) contracts.Output { + return NewCliOutput(output.output, WithPrefix(prefix), withMutex(output.mutex)) } -func (o *CliOutput) print(msg string, outputType ouputType) { - o.mutex.Lock() - defer o.mutex.Unlock() +// print holds the exclusive write lock for the full render+write cycle so that +// the shared buffer and the underlying writer are never accessed concurrently. +// Write() uses RLock only, which is blocked while print() owns the write lock. +func (output *CliOutput) print(msg string, level outputType) { + output.mutex.Lock() + defer output.mutex.Unlock() - o.renderPrefix() - o.renderLevel(outputType) - o.renderMessage(msg) - o.buffer.WriteByte('\n') + output.renderPrefix() + output.renderLevel(level) + output.renderMessage(msg) + output.buffer.WriteByte('\n') - defer o.buffer.Reset() + _, err := output.output.Write(output.buffer.Bytes()) + output.buffer.Reset() - _, err := o.output.Write(o.buffer.Bytes()) if err != nil { panic(err) } } -func (o *CliOutput) renderLevel(level ouputType) { +func (output *CliOutput) renderLevel(level outputType) { renderer := levelStyles[level] if levelMessage, ok := messageMap[level]; ok { - o.buffer.WriteString(renderer.Render(levelMessage)) + output.buffer.WriteString(renderer.Render(levelMessage)) } - o.buffer.WriteByte(' ') + output.buffer.WriteByte(' ') } -func (o *CliOutput) renderMessage(msg string) { +func (output *CliOutput) renderMessage(msg string) { msg = strings.TrimSuffix(msg, "\n") - fmt.Fprint(&o.buffer, msg) + fmt.Fprint(&output.buffer, msg) } -func (o *CliOutput) renderPrefix() { - if len(o.prefix) > 0 { - o.buffer.WriteString(o.prefix) +func (output *CliOutput) renderPrefix() { + if len(output.prefix) > 0 { + output.buffer.WriteString(output.prefix) } } diff --git a/internal/tui/output_internal_test.go b/internal/tui/output_internal_test.go deleted file mode 100644 index b8e13add..00000000 --- a/internal/tui/output_internal_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package tui - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" -) - -var errInternalWrite = fmt.Errorf("internal write error") - -type errorWriterInternal struct{} - -func (e *errorWriterInternal) Write(_ []byte) (int, error) { - return 0, errInternalWrite -} - -func TestPrintMessageBox_PanicsOnWriteError(t *testing.T) { - assert.Panics(t, func() { - printMessageBox(&errorWriterInternal{}, "test message", "INFO", levelStyles[infoOutput]) - }) -} diff --git a/internal/tui/output_test.go b/internal/tui/output_test.go index 1fe62378..48bda7ec 100644 --- a/internal/tui/output_test.go +++ b/internal/tui/output_test.go @@ -129,7 +129,7 @@ func TestCliOutput_Request(t *testing.T) { var buf strings.Builder out := tui.NewCliOutput(&buf) - out.NewPrefixOutput("PROXY").Request(&contracts.ReqestData{ + out.NewPrefixOutput("PROXY").Request(&contracts.RequestData{ Method: "GET", URL: mustParseURL("http://example.com/path"), Code: 200, diff --git a/internal/tui/printresponse.go b/internal/tui/printresponse.go index 1772d795..252d0f37 100644 --- a/internal/tui/printresponse.go +++ b/internal/tui/printresponse.go @@ -11,7 +11,7 @@ import ( const prefixWidth = 13 -func printResponse(data *contracts.ReqestData) string { +func printResponse(data *contracts.RequestData) string { var ( prefix string prefixStyle, textStyle lipgloss.Style diff --git a/internal/tui/printresponse_test.go b/internal/tui/printresponse_test.go index 273248fb..cc9c3a44 100644 --- a/internal/tui/printresponse_test.go +++ b/internal/tui/printresponse_test.go @@ -14,10 +14,10 @@ import ( "github.com/stretchr/testify/assert" ) -func makeRequestData(method, rawURL string, code int) *contracts.ReqestData { +func makeRequestData(method, rawURL string, code int) *contracts.RequestData { u, _ := url.Parse(rawURL) - return &contracts.ReqestData{ + return &contracts.RequestData{ Method: method, URL: u, Code: code, @@ -27,7 +27,7 @@ func makeRequestData(method, rawURL string, code int) *contracts.ReqestData { func TestPrintResponse(t *testing.T) { tests := []struct { name string - data *contracts.ReqestData + data *contracts.RequestData }{ { name: "1xx informational", diff --git a/internal/uncors_app/app_internal_test.go b/internal/uncors_app/app_internal_test.go index 3d631003..cb72e5bc 100644 --- a/internal/uncors_app/app_internal_test.go +++ b/internal/uncors_app/app_internal_test.go @@ -362,7 +362,7 @@ func TestHandleRequestEventWithData(t *testing.T) { requestURL, err := url.Parse("https://example.com/api") require.NoError(t, err) - data := &contracts.ReqestData{Method: "GET", URL: requestURL, Code: 200} + data := &contracts.RequestData{Method: "GET", URL: requestURL, Code: 200} t.Run("outputs request without prefix", func(t *testing.T) { app, _ := newTestApp(t) diff --git a/internal/uncors_app/output.go b/internal/uncors_app/output.go index 97b15a67..d61ca4dc 100644 --- a/internal/uncors_app/output.go +++ b/internal/uncors_app/output.go @@ -67,7 +67,7 @@ func (o *tuiOutput) Printf(msg string, args ...any) { o.capture(func(out *tui.CliOutput) { out.Printf(msg, args...) }) } -func (o *tuiOutput) Request(data *contracts.ReqestData) { +func (o *tuiOutput) Request(data *contracts.RequestData) { o.capture(func(out *tui.CliOutput) { out.Request(data) }) } diff --git a/internal/uncors_app/output_internal_test.go b/internal/uncors_app/output_internal_test.go index e44ea858..bfe04236 100644 --- a/internal/uncors_app/output_internal_test.go +++ b/internal/uncors_app/output_internal_test.go @@ -146,7 +146,7 @@ func TestTuiOutput_Request(t *testing.T) { t.Run("Request sends formatted request data", func(t *testing.T) { out, outputCh := newTestOutput() u, _ := url.Parse("http://example.com/api/resource") - out.Request(&contracts.ReqestData{ + out.Request(&contracts.RequestData{ Method: "GET", URL: u, Code: 200, diff --git a/internal/urlreplacer/replacer.go b/internal/urlreplacer/replacer.go index 5cb2f3cb..e6493a5b 100644 --- a/internal/urlreplacer/replacer.go +++ b/internal/urlreplacer/replacer.go @@ -24,10 +24,11 @@ var ( type hook = func(string) string type Replacer struct { - regexp *regexp.Regexp - pattern string - hooks map[string]hook - scheme string // target scheme (http or https), or empty + regexp *regexp.Regexp + pattern string + hooks map[string]hook + scheme string + subexpIndex map[string]int // precomputed name→index, built once at construction } func NewReplacer(source, target string) (*Replacer, error) { @@ -44,7 +45,6 @@ func NewReplacer(source, target string) (*Replacer, error) { return nil, fmt.Errorf("%w: {%s}", ErrDuplicateSourceKey, dup) } - // Validate raw URLs before any processing err := validateRawURL(source) if err != nil { return nil, fmt.Errorf("%w: %w", ErrInvalidSourceURL, err) @@ -65,13 +65,20 @@ func NewReplacer(source, target string) (*Replacer, error) { } replacer.pattern = wildCardToReplacePattern(target) - - // Extract and store target scheme replacer.scheme = extractScheme(target) + if len(replacer.scheme) > 0 { replacer.hooks["scheme"] = schemeHookFactory(replacer.scheme) } + // Build name→index map once so Replace doesn't call SubexpIndex (O(n)) per group per call. + replacer.subexpIndex = make(map[string]int, len(replacer.regexp.SubexpNames())) + for i, name := range replacer.regexp.SubexpNames() { + if name != "" { + replacer.subexpIndex[name] = i + } + } + return replacer, nil } @@ -83,18 +90,13 @@ func (r *Replacer) Replace(source string) (string, error) { replaced := strings.Clone(r.pattern) - for _, subExpName := range r.regexp.SubexpNames() { - if len(subExpName) > 0 { - partPattern := fmt.Sprintf("${%s}", subExpName) - partIndex := r.regexp.SubexpIndex(subExpName) - - partValue := matches[partIndex] - if hook, ok := r.hooks[subExpName]; ok { - partValue = hook(partValue) - } - - replaced = strings.ReplaceAll(replaced, partPattern, partValue) + for name, idx := range r.subexpIndex { + partValue := matches[idx] + if hook, ok := r.hooks[name]; ok { + partValue = hook(partValue) } + + replaced = strings.ReplaceAll(replaced, "${"+name+"}", partValue) } return replaced, nil diff --git a/main.go b/main.go index d0b57952..8ee1d934 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,6 @@ import ( tea "charm.land/bubbletea/v2" "github.com/evg4b/uncors/internal/commands" "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/internal/config/validators" "github.com/evg4b/uncors/internal/helpers" "github.com/evg4b/uncors/internal/infra" "github.com/evg4b/uncors/internal/server" @@ -209,11 +208,6 @@ func loadConfiguration(fs afero.Fs) (*config.UncorsConfig, string) { panic(err) } - err = validators.ValidateConfig(uncorsConfig, fs) - if err != nil { - panic(err) - } - if uncorsConfig.Debug { logFile, err := os.OpenFile("uncors.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, os.ModeAppend) if err != nil { diff --git a/testing/mocks/output_mock.go b/testing/mocks/output_mock.go index 9c8e0092..03f367d1 100644 --- a/testing/mocks/output_mock.go +++ b/testing/mocks/output_mock.go @@ -81,9 +81,9 @@ type OutputMock struct { beforePrintfCounter uint64 PrintfMock mOutputMockPrintf - funcRequest func(data *mm_contracts.ReqestData) + funcRequest func(data *mm_contracts.RequestData) funcRequestOrigin string - inspectFuncRequest func(data *mm_contracts.ReqestData) + inspectFuncRequest func(data *mm_contracts.RequestData) afterRequestCounter uint64 beforeRequestCounter uint64 RequestMock mOutputMockRequest @@ -3034,12 +3034,12 @@ type OutputMockRequestExpectation struct { // OutputMockRequestParams contains parameters of the Output.Request type OutputMockRequestParams struct { - data *mm_contracts.ReqestData + data *mm_contracts.RequestData } // OutputMockRequestParamPtrs contains pointers to parameters of the Output.Request type OutputMockRequestParamPtrs struct { - data **mm_contracts.ReqestData + data **mm_contracts.RequestData } // OutputMockRequestOrigins contains origins of expectations of the Output.Request @@ -3059,7 +3059,7 @@ func (mmRequest *mOutputMockRequest) Optional() *mOutputMockRequest { } // Expect sets up expected params for Output.Request -func (mmRequest *mOutputMockRequest) Expect(data *mm_contracts.ReqestData) *mOutputMockRequest { +func (mmRequest *mOutputMockRequest) Expect(data *mm_contracts.RequestData) *mOutputMockRequest { if mmRequest.mock.funcRequest != nil { mmRequest.mock.t.Fatalf("OutputMock.Request mock is already set by Set") } @@ -3084,7 +3084,7 @@ func (mmRequest *mOutputMockRequest) Expect(data *mm_contracts.ReqestData) *mOut } // ExpectDataParam1 sets up expected param data for Output.Request -func (mmRequest *mOutputMockRequest) ExpectDataParam1(data *mm_contracts.ReqestData) *mOutputMockRequest { +func (mmRequest *mOutputMockRequest) ExpectDataParam1(data *mm_contracts.RequestData) *mOutputMockRequest { if mmRequest.mock.funcRequest != nil { mmRequest.mock.t.Fatalf("OutputMock.Request mock is already set by Set") } @@ -3107,7 +3107,7 @@ func (mmRequest *mOutputMockRequest) ExpectDataParam1(data *mm_contracts.ReqestD } // Inspect accepts an inspector function that has same arguments as the Output.Request -func (mmRequest *mOutputMockRequest) Inspect(f func(data *mm_contracts.ReqestData)) *mOutputMockRequest { +func (mmRequest *mOutputMockRequest) Inspect(f func(data *mm_contracts.RequestData)) *mOutputMockRequest { if mmRequest.mock.inspectFuncRequest != nil { mmRequest.mock.t.Fatalf("Inspect function is already set for OutputMock.Request") } @@ -3132,7 +3132,7 @@ func (mmRequest *mOutputMockRequest) Return() *OutputMock { } // Set uses given function f to mock the Output.Request method -func (mmRequest *mOutputMockRequest) Set(f func(data *mm_contracts.ReqestData)) *OutputMock { +func (mmRequest *mOutputMockRequest) Set(f func(data *mm_contracts.RequestData)) *OutputMock { if mmRequest.defaultExpectation != nil { mmRequest.mock.t.Fatalf("Default expectation is already set for the Output.Request method") } @@ -3148,7 +3148,7 @@ func (mmRequest *mOutputMockRequest) Set(f func(data *mm_contracts.ReqestData)) // When sets expectation for the Output.Request which will trigger the result defined by the following // Then helper -func (mmRequest *mOutputMockRequest) When(data *mm_contracts.ReqestData) *OutputMockRequestExpectation { +func (mmRequest *mOutputMockRequest) When(data *mm_contracts.RequestData) *OutputMockRequestExpectation { if mmRequest.mock.funcRequest != nil { mmRequest.mock.t.Fatalf("OutputMock.Request mock is already set by Set") } @@ -3190,7 +3190,7 @@ func (mmRequest *mOutputMockRequest) invocationsDone() bool { } // Request implements mm_contracts.Output -func (mmRequest *OutputMock) Request(data *mm_contracts.ReqestData) { +func (mmRequest *OutputMock) Request(data *mm_contracts.RequestData) { mm_atomic.AddUint64(&mmRequest.beforeRequestCounter, 1) defer mm_atomic.AddUint64(&mmRequest.afterRequestCounter, 1)