diff --git a/.golangci.yml b/.golangci.yml index f2e97ed0..14206026 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -29,6 +29,10 @@ linters: - ok - fs - ca + tagliatelle: + case: + rules: + yaml: kebab exclusions: generated: lax presets: diff --git a/go.mod b/go.mod index bc8f4704..ff023bcf 100644 --- a/go.mod +++ b/go.mod @@ -15,11 +15,9 @@ require ( github.com/gorilla/mux v1.8.1 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-version v1.9.0 - github.com/mitchellh/mapstructure v1.5.0 github.com/samber/lo v1.53.0 github.com/spf13/afero v1.15.0 github.com/spf13/pflag v1.0.10 - github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/xeipuuv/gojsonschema v1.2.0 github.com/yuin/gopher-lua v1.1.2 @@ -38,9 +36,7 @@ require ( github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/gkampitakis/ciinfo v0.3.4 // indirect - github.com/go-viper/mapstructure/v2 v2.5.0 // indirect github.com/goccy/go-yaml v1.19.2 // indirect - github.com/google/go-cmp v0.7.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect @@ -49,7 +45,6 @@ require ( github.com/mattn/go-runewidth v0.0.23 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect - github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/sergi/go-diff v1.4.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect @@ -58,18 +53,14 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/sync v0.20.0 // indirect ) require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.10.0 - github.com/pelletier/go-toml/v2 v2.3.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/spf13/cast v1.10.0 // indirect - github.com/subosito/gotenv v1.6.0 // indirect golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect golang.org/x/sys v0.43.0 // indirect golang.org/x/text v0.36.0 // indirect diff --git a/go.sum b/go.sum index 4dbd6c64..52c7cd63 100644 --- a/go.sum +++ b/go.sum @@ -39,8 +39,6 @@ github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa5 github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.10.0 h1:Xx/5Ydg9CeBDX/wi4VJqStNtohYjitZhhlHt4h3St1M= github.com/fsnotify/fsnotify v1.10.0/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo= github.com/gkampitakis/ciinfo v0.3.4 h1:5eBSibVuSMbb/H6Elc0IIEFbkzCJi3lm94n0+U7Z0KY= @@ -49,14 +47,10 @@ github.com/gkampitakis/go-snaps v0.5.21 h1:SvhSFeZviQXwlT+dnGyAIATVehkhqRVW6qfQZ github.com/gkampitakis/go-snaps v0.5.21/go.mod h1:gC3YqxQTPyIXvQrw/Vpt3a8VqR1MO8sVpZFWN4DGwNs= github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a h1:v6zMvHuY9yue4+QkG/HQ/W67wvtQmWJ4SDo9aK/GIno= github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a/go.mod h1:I79BieaU4fxrw4LMXby6q5OS9XnoR9UIKLOzDFjUmuw= -github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= -github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/gojuno/minimock/v3 v3.4.7 h1:vhE5zpniyPDRT0DXd5s3DbtZJVlcbmC5k80izYtj9lY= github.com/gojuno/minimock/v3 v3.4.7/go.mod h1:QxJk4mdPrVyYUmEZGc2yD2NONpqM/j4dWhsy9twjFHg= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -79,12 +73,8 @@ github.com/maruel/natural v1.3.0 h1:VsmCsBmEyrR46RomtgHs5hbKADGRVtliHTyCOLFBpsg= github.com/maruel/natural v1.3.0/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= -github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= -github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= -github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc= -github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= @@ -94,27 +84,19 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= -github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= -github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= github.com/samber/lo v1.53.0 h1:t975lj2py4kJPQ6haz1QMgtId2gtmfktACxIXArw3HM= github.com/samber/lo v1.53.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= -github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= -github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= -github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= -github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -137,8 +119,6 @@ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavM github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/gopher-lua v1.1.2 h1:yF/FjE3hD65tBbt0VXLE13HWS9h34fdzJmrWRXwobGA= github.com/yuin/gopher-lua v1.1.2/go.mod h1:7aRmXIWl37SqRf0koeyylBEzJ+aPt8A+mmkQ4f1ntR8= -go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= -go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= diff --git a/internal/config/cache_config.go b/internal/config/cache_config.go index aac07294..1efbfa72 100644 --- a/internal/config/cache_config.go +++ b/internal/config/cache_config.go @@ -1,9 +1,17 @@ package config import ( + "errors" + "fmt" + "strings" "time" + + "gopkg.in/yaml.v3" ) +// ErrInvalidCacheConfig is returned when the cache-config YAML value is not a mapping. +var ErrInvalidCacheConfig = errors.New("expected a mapping for cache-config") + type CacheGlobs []string func (g CacheGlobs) Clone() CacheGlobs { @@ -18,9 +26,9 @@ func (g CacheGlobs) Clone() CacheGlobs { } type CacheConfig struct { - ExpirationTime time.Duration `mapstructure:"expiration-time"` - MaxSize int64 `mapstructure:"max-size"` - Methods []string `mapstructure:"methods"` + ExpirationTime time.Duration `yaml:"-"` + MaxSize int64 `yaml:"max-size"` + Methods []string `yaml:"methods"` } func (c *CacheConfig) Clone() *CacheConfig { @@ -35,3 +43,41 @@ func (c *CacheConfig) Clone() *CacheConfig { Methods: methods, } } + +// UnmarshalYAML implements custom decoding so that the "expiration-time" field +// can be expressed as a human-readable duration string (e.g. "30m", "1h"). +// Other fields are decoded by the standard yaml.v3 machinery. +// Only fields present in the YAML node are updated; existing values (defaults) +// are preserved for absent keys. +func (c *CacheConfig) UnmarshalYAML(value *yaml.Node) error { + if value.Kind != yaml.MappingNode { + return ErrInvalidCacheConfig + } + + for i := 0; i+1 < len(value.Content); i += 2 { + keyNode := value.Content[i] + valNode := value.Content[i+1] + + switch keyNode.Value { + case "expiration-time": + dur, err := time.ParseDuration(strings.ReplaceAll(valNode.Value, " ", "")) + if err != nil { + return fmt.Errorf("invalid expiration-time %q: %w", valNode.Value, err) + } + + c.ExpirationTime = dur + case "max-size": + err := valNode.Decode(&c.MaxSize) + if err != nil { + return err + } + case "methods": + err := valNode.Decode(&c.Methods) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/config/cache_config_test.go b/internal/config/cache_config_test.go index 632902bc..30df8cda 100644 --- a/internal/config/cache_config_test.go +++ b/internal/config/cache_config_test.go @@ -7,8 +7,69 @@ import ( "github.com/evg4b/uncors/internal/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) +func TestCacheConfigUnmarshalYAML(t *testing.T) { + t.Run("decodes all fields", func(t *testing.T) { + const input = ` +expiration-time: 30m +max-size: 52428800 +methods: + - GET + - POST +` + + var actual config.CacheConfig + + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) + assert.Equal(t, config.CacheConfig{ + ExpirationTime: 30 * time.Minute, + MaxSize: 52428800, + Methods: []string{http.MethodGet, http.MethodPost}, + }, actual) + }) + + t.Run("parses expiration-time with embedded spaces", func(t *testing.T) { + const input = `expiration-time: "1h 30m"` + + var actual config.CacheConfig + + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) + assert.Equal(t, 90*time.Minute, actual.ExpirationTime) + }) + + t.Run("absent fields keep zero values", func(t *testing.T) { + const input = `max-size: 1024` + + var actual config.CacheConfig + + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) + assert.Equal(t, int64(1024), actual.MaxSize) + assert.Zero(t, actual.ExpirationTime) + assert.Nil(t, actual.Methods) + }) + + t.Run("returns ErrInvalidCacheConfig for non-mapping node", func(t *testing.T) { + const input = `- item1` + + var actual config.CacheConfig + + err := yaml.Unmarshal([]byte(input), &actual) + + assert.ErrorIs(t, err, config.ErrInvalidCacheConfig) + }) + + t.Run("returns error for invalid expiration-time", func(t *testing.T) { + const input = `expiration-time: not-a-duration` + + var actual config.CacheConfig + + assert.Error(t, yaml.Unmarshal([]byte(input), &actual)) + }) +} + func TestCacheGlobsClone(t *testing.T) { globs := config.CacheGlobs{ "/api/**", diff --git a/internal/config/config.go b/internal/config/config.go index 8a590ee0..1608fcde 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,79 +3,86 @@ package config import ( "fmt" - "github.com/evg4b/uncors/internal/helpers" - "github.com/mitchellh/mapstructure" + "github.com/spf13/afero" "github.com/spf13/pflag" - "github.com/spf13/viper" + "gopkg.in/yaml.v3" ) -var flags *pflag.FlagSet - +// UncorsConfig is the root configuration for the uncors proxy. type UncorsConfig struct { - Mappings Mappings `mapstructure:"mappings"` - Proxy string `mapstructure:"proxy"` - Debug bool `mapstructure:"debug"` - CacheConfig CacheConfig `mapstructure:"cache-config"` - Interactive bool `mapstructure:"interactive"` + Mappings Mappings `yaml:"mappings"` + Proxy string `yaml:"proxy"` + Debug bool `yaml:"debug"` + CacheConfig CacheConfig `yaml:"cache-config"` + Interactive bool `yaml:"-"` } -func LoadConfiguration(viperInstance *viper.Viper, args []string) *UncorsConfig { - defineFlags() - helpers.AssertIsDefined(flags) +// LoadConfiguration parses CLI arguments and optionally reads a YAML config file. +// CLI flags take precedence over config file values. +// Returns the loaded config, the active config file path (empty if none), and any error. +func LoadConfiguration(fs afero.Fs, args []string) (*UncorsConfig, string, error) { + flags := defineFlags() err := flags.Parse(args) if err != nil { - panic(fmt.Errorf("failed parsing flags: %w", err)) + return nil, "", fmt.Errorf("failed parsing flags: %w", err) } - err = viperInstance.BindPFlags(flags) - if err != nil { - panic(fmt.Errorf("failed binding flags: %w", err)) - } + cfg := defaultConfig() + configPath, _ := flags.GetString("config") - configuration := &UncorsConfig{ - Mappings: []Mapping{}, + if configPath != "" { + readErr := readYAMLFile(fs, cfg, configPath) + if readErr != nil { + return nil, "", readErr + } } - if configPath := viperInstance.GetString("config"); len(configPath) > 0 { - viperInstance.SetConfigFile(configPath) - - err := viperInstance.ReadInConfig() - if err != nil { - panic(fmt.Errorf("failed to read config file '%s': %w", configPath, err)) - } + err = applyFlagOverrides(cfg, flags) + if err != nil { + return nil, "", err } - configOption := viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( - mapstructure.StringToSliceHookFunc(","), - StringToTimeDurationHookFunc(), - URLMappingHookFunc(), - )) + cfg.Mappings = NormaliseMappings(cfg.Mappings) - setDefaultValues(viperInstance) + return cfg, configPath, nil +} - err = viperInstance.Unmarshal(configuration, configOption) +// readYAMLFile opens a YAML config file and decodes it directly into cfg, +// preserving any existing default values for keys absent in the file. +func readYAMLFile(fs afero.Fs, cfg *UncorsConfig, path string) error { + file, err := fs.Open(path) if err != nil { - panic(fmt.Errorf("failed parsing config: %w", err)) + return fmt.Errorf("failed to read config file '%s': %w", path, err) } - err = readURLMapping(viperInstance, configuration) + defer file.Close() + + err = yaml.NewDecoder(file).Decode(cfg) if err != nil { - panic(err) + return fmt.Errorf("failed to read config file '%s': While parsing config: %w", path, err) } - configuration.Mappings = NormaliseMappings(configuration.Mappings) - - return configuration + return nil } -func defineFlags() { - flags = pflag.NewFlagSet("uncors", pflag.ContinueOnError) - flags.Usage = pflag.Usage - flags.StringSliceP("to", "t", []string{}, "Target host with protocol for the resource to be proxied") - flags.StringSliceP("from", "f", []string{}, "Local host with protocol for the resource from which proxying will take place") //nolint: lll - flags.String("proxy", "", "HTTP/HTTPS proxy for requests to the real server (uses system proxy by default)") - flags.Bool("debug", false, "Show debug output") - flags.StringP("config", "c", "", "Path to the configuration file") - flags.Bool("interactive", true, "") +// applyFlagOverrides applies CLI flag values to cfg, overriding any config file values. +// Only flags explicitly set on the command line are applied. +func applyFlagOverrides(cfg *UncorsConfig, flags *pflag.FlagSet) error { + if flags.Changed("proxy") { + cfg.Proxy, _ = flags.GetString("proxy") + } + + if flags.Changed("debug") { + cfg.Debug, _ = flags.GetBool("debug") + } + + if flags.Changed("interactive") { + cfg.Interactive, _ = flags.GetBool("interactive") + } + + from, _ := flags.GetStringSlice("from") + to, _ := flags.GetStringSlice("to") + + return mergeURLMappings(cfg, from, to) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index e10b62ec..b8233886 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -9,8 +9,9 @@ import ( "github.com/evg4b/uncors/testing/hosts" "github.com/evg4b/uncors/testing/testutils" "github.com/evg4b/uncors/testing/testutils/params" - "github.com/spf13/viper" + "github.com/spf13/afero" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // clearMappingsCache clears the URL cache in all mappings for testing purposes. @@ -20,7 +21,7 @@ func clearMappingsCache(cfg *config.UncorsConfig) { } } -const acceptEncoding = "accept-encoding" +const acceptEncoding = "Accept-Encoding" const ( corruptedConfigPath = "/corrupted-config.yaml" @@ -78,13 +79,19 @@ mappings: ` ) -func TestLoadConfiguration(t *testing.T) { - fs := testutils.FsFromMap(t, map[string]string{ +func makeTestFs(t *testing.T) afero.Fs { + t.Helper() + + return testutils.FsFromMap(t, map[string]string{ corruptedConfigPath: corruptedConfig, fullConfigPath: fullConfig, incorrectConfigPath: incorrectConfig, minimalConfigPath: minimalConfig, }) +} + +func TestLoadConfiguration(t *testing.T) { + fs := makeTestFs(t) t.Run("correctly parse config", func(t *testing.T) { tests := []struct { @@ -203,110 +210,149 @@ func TestLoadConfiguration(t *testing.T) { Interactive: false, }, }, + { + name: "CLI proxy and debug flags override config file values", + args: []string{ + params.Config, fullConfigPath, + "--proxy", "newproxy:9999", + "--debug=false", + }, + expected: &config.UncorsConfig{ + Mappings: config.Mappings{ + {From: hosts.Localhost.HTTPPort(8080), To: hosts.Github.HTTPS()}, + { + From: hosts.Localhost2.HTTPPort(8080), + To: hosts.Stackoverflow.HTTPS(), + Mocks: config.Mocks{ + { + Matcher: config.RequestMatcher{ + Path: "/demo", + Method: "POST", + Queries: map[string]string{"foo": "bar"}, + Headers: map[string]string{acceptEncoding: "deflate"}, + }, + Response: config.Response{ + Code: 201, + Headers: map[string]string{acceptEncoding: "deflate"}, + Raw: "demo", + File: "/demo.txt", + }, + }, + }, + }, + }, + Proxy: "newproxy:9999", + Debug: false, + CacheConfig: config.CacheConfig{ + ExpirationTime: time.Hour, MaxSize: 52428800, + Methods: []string{http.MethodGet, http.MethodPost}, + }, + Interactive: true, + }, + }, + { + name: "CLI from/to updates existing mapping from config file", + args: []string{ + params.Config, minimalConfigPath, + params.From, hosts.Localhost.HTTPPort(8080), params.To, hosts.Stackoverflow.HTTPS(), + }, + expected: &config.UncorsConfig{ + Mappings: config.Mappings{ + {From: hosts.Localhost.HTTPPort(8080), To: hosts.Stackoverflow.HTTPS()}, + }, + CacheConfig: config.CacheConfig{ + ExpirationTime: config.DefaultExpirationTime, + MaxSize: config.DefaultMaxSize, + Methods: []string{http.MethodGet}, + }, + Interactive: true, + }, + }, } for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { - viper.Reset() - - viperInstance := viper.New() - viperInstance.SetFs(fs) + actual, _, err := config.LoadConfiguration(fs, testCase.args) + require.NoError(t, err) - uncorsConfig := config.LoadConfiguration(viperInstance, testCase.args) - clearMappingsCache(uncorsConfig) + clearMappingsCache(actual) clearMappingsCache(testCase.expected) - assert.Equal(t, testCase.expected, uncorsConfig) + assert.Equal(t, testCase.expected, actual) }) } }) + t.Run("returns config file path", func(t *testing.T) { + t.Run("empty when no config file flag", func(t *testing.T) { + _, configPath, err := config.LoadConfiguration(afero.NewMemMapFs(), []string{}) + require.NoError(t, err) + assert.Empty(t, configPath) + }) + + t.Run("returns the given config path", func(t *testing.T) { + _, configPath, err := config.LoadConfiguration(fs, []string{params.Config, minimalConfigPath}) + require.NoError(t, err) + assert.Equal(t, minimalConfigPath, configPath) + }) + }) + t.Run("parse config with error", func(t *testing.T) { tests := []struct { - name string - args []string - expected []string + name string + args []string + expectedErr string }{ { - name: "incorrect flag provided", - args: []string{ - "--incorrect-flag", - }, - expected: []string{ - "failed parsing flags: unknown flag: --incorrect-flag", - }, + name: "incorrect flag provided", + args: []string{"--incorrect-flag"}, + expectedErr: "failed parsing flags: unknown flag: --incorrect-flag", }, { - name: "return default config", - args: []string{ - params.To, hosts.Github.Host(), - }, - expected: []string{ - "`from` values are not set for every `to`", - }, + name: "to without matching from", + args: []string{params.To, hosts.Github.Host()}, + expectedErr: "`from` values are not set for every `to`", }, { - name: "count of from values great then count of to", + name: "from count exceeds to count", args: []string{ params.From, hosts.Localhost1.Host(), params.To, hosts.Github.Host(), params.From, hosts.Localhost2.Host(), }, - expected: []string{ - "`to` values are not set for every `from`", - }, + expectedErr: "`to` values are not set for every `from`", }, { - name: "count of to values great then count of from", + name: "to count exceeds from count", args: []string{ params.From, hosts.Localhost1.Host(), params.To, hosts.Github.Host(), params.To, hosts.Stackoverflow.Host(), }, - expected: []string{ - "`from` values are not set for every `to`", - }, + expectedErr: "`from` values are not set for every `to`", }, { name: "config file doesn't exist", - args: []string{ - params.Config, "/not-exist-config.yaml", - }, - expected: []string{ - "failed to read config file '/not-exist-config.yaml': open /not-exist-config.yaml: file does not exist", - }, + args: []string{params.Config, "/not-exist-config.yaml"}, + expectedErr: "failed to read config file '/not-exist-config.yaml': " + + "open /not-exist-config.yaml: file does not exist", }, { name: "config file is corrupted", - args: []string{ - params.Config, corruptedConfigPath, - }, - expected: []string{ - "failed to read config file '/corrupted-config.yaml': " + - "While parsing config: yaml: line 2: mapping values are not allowed in this context", - }, + args: []string{params.Config, corruptedConfigPath}, + expectedErr: "failed to read config file '/corrupted-config.yaml': " + + "While parsing config: yaml: line 2: mapping values are not allowed in this context", }, { name: "incorrect type in config file", - args: []string{ - params.Config, incorrectConfigPath, - }, - expected: []string{ - "failed parsing config: decoding failed due to the following error(s):\n" + - "\n'mappings[0]' unsupported operation", - }, + args: []string{params.Config, incorrectConfigPath}, + expectedErr: "failed to read config file '/incorrect-config.yaml': " + + "While parsing config: mapping shorthand value must be a string URL", }, } + for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { - for _, expected := range testCase.expected { - viper.Reset() - - viperInstance := viper.New() - viperInstance.SetFs(fs) - - assert.PanicsWithError(t, expected, func() { - config.LoadConfiguration(viperInstance, testCase.args) - }) - } + _, _, err := config.LoadConfiguration(fs, testCase.args) + assert.EqualError(t, err, testCase.expectedErr) }) } }) diff --git a/internal/config/default.go b/internal/config/default.go index 11753dde..9ae177b0 100644 --- a/internal/config/default.go +++ b/internal/config/default.go @@ -3,8 +3,6 @@ package config import ( "net/http" "time" - - "github.com/spf13/viper" ) const ( @@ -14,9 +12,15 @@ const ( DefaultMaxSize = 100 * 1024 * 1024 // 100 MB ) -func setDefaultValues(instance *viper.Viper) { - instance.SetDefault("cache-config.expiration-time", DefaultExpirationTime) - instance.SetDefault("cache-config.max-size", DefaultMaxSize) - instance.SetDefault("cache-config.methods", []string{http.MethodGet}) - instance.SetDefault("interactive", true) +// defaultConfig returns a new UncorsConfig with all default values applied. +func defaultConfig() *UncorsConfig { + return &UncorsConfig{ + Mappings: Mappings{}, + CacheConfig: CacheConfig{ + ExpirationTime: DefaultExpirationTime, + MaxSize: DefaultMaxSize, + Methods: []string{http.MethodGet}, + }, + Interactive: true, + } } diff --git a/internal/config/flags.go b/internal/config/flags.go new file mode 100644 index 00000000..6da8727a --- /dev/null +++ b/internal/config/flags.go @@ -0,0 +1,16 @@ +package config + +import "github.com/spf13/pflag" + +func defineFlags() *pflag.FlagSet { + flags := pflag.NewFlagSet("uncors", pflag.ContinueOnError) + flags.Usage = pflag.Usage + flags.StringSliceP("to", "t", []string{}, "Target host with protocol for the resource to be proxied") + flags.StringSliceP("from", "f", []string{}, "Local host with protocol for the resource from which proxying will take place") //nolint: lll + flags.String("proxy", "", "HTTP/HTTPS proxy for requests to the real server (uses system proxy by default)") + flags.Bool("debug", false, "Show debug output") + flags.StringP("config", "c", "", "Path to the configuration file") + flags.Bool("interactive", true, "") + + return flags +} diff --git a/internal/config/har.go b/internal/config/har.go index 6f89b225..1d3e1bd3 100644 --- a/internal/config/har.go +++ b/internal/config/har.go @@ -1,47 +1,39 @@ package config -import ( - "reflect" - - "github.com/mitchellh/mapstructure" -) +import "gopkg.in/yaml.v3" // HARConfig defines settings for the HAR (HTTP Archive) collector middleware. // When File is non-empty, all requests/responses passing through the proxy // for this mapping will be recorded to the specified HAR file. type HARConfig struct { - File string `mapstructure:"file"` - CaptureSecureHeaders bool `mapstructure:"capture-secure-headers"` + File string `yaml:"file"` + CaptureSecureHeaders bool `yaml:"capture-secure-headers"` } -func (h HARConfig) Enabled() bool { +func (h *HARConfig) Enabled() bool { return h.File != "" } -func (h HARConfig) Clone() HARConfig { +func (h *HARConfig) Clone() HARConfig { return HARConfig{ File: h.File, CaptureSecureHeaders: h.CaptureSecureHeaders, } } -var harConfigType = reflect.TypeFor[HARConfig]() - -// HARConfigHookFunc returns a mapstructure decode hook that allows HARConfig -// to be specified as a plain string in YAML/config files. +// UnmarshalYAML allows HARConfig to be specified as a plain string (file path) +// or as a full mapping. // // Short form: har: ./recordings/api.har // Full form: har: { file: ./recordings/api.har, capture-secure-headers: true }. -func HARConfigHookFunc() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, rawData any) (any, error) { - if t != harConfigType || f.Kind() != reflect.String { - return rawData, nil - } - - if file, ok := rawData.(string); ok { - return HARConfig{File: file}, nil - } +func (h *HARConfig) UnmarshalYAML(value *yaml.Node) error { + if value.Kind == yaml.ScalarNode { + h.File = value.Value - return rawData, nil + return nil } + + type harConfigAlias HARConfig + + return value.Decode((*harConfigAlias)(h)) } diff --git a/internal/config/har_test.go b/internal/config/har_test.go index aee8a88f..9ad961e9 100644 --- a/internal/config/har_test.go +++ b/internal/config/har_test.go @@ -4,39 +4,40 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/testing/testutils" - "github.com/mitchellh/mapstructure" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) -func TestHARConfigHookFunc(t *testing.T) { - decode := func(t *testing.T, raw any) config.HARConfig { - t.Helper() - - var out config.HARConfig - - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - Result: &out, - DecodeHook: config.HARConfigHookFunc(), - }) - require.NoError(t, err) - require.NoError(t, decoder.Decode(raw)) +func TestHARConfigEnabled(t *testing.T) { + t.Run("returns true when File is set", func(t *testing.T) { + cfg := config.HARConfig{File: "./recordings/api.har"} + assert.True(t, cfg.Enabled()) + }) - return out - } + t.Run("returns false when File is empty", func(t *testing.T) { + cfg := config.HARConfig{} + assert.False(t, cfg.Enabled()) + }) +} +func TestHARConfigUnmarshalYAML(t *testing.T) { t.Run("string shorthand sets File", func(t *testing.T) { - cfg := decode(t, "./recordings/api.har") + var cfg config.HARConfig + + require.NoError(t, yaml.Unmarshal([]byte(`"./recordings/api.har"`), &cfg)) assert.Equal(t, config.HARConfig{File: "./recordings/api.har"}, cfg) }) t.Run("map form decoded normally", func(t *testing.T) { - cfg := decode(t, map[string]any{ - "file": "./out.har", - "capture-secure-headers": true, - }) + const input = ` +file: ./out.har +capture-secure-headers: true +` + + var cfg config.HARConfig + + require.NoError(t, yaml.Unmarshal([]byte(input), &cfg)) assert.Equal(t, config.HARConfig{ File: "./out.har", CaptureSecureHeaders: true, @@ -45,23 +46,15 @@ func TestHARConfigHookFunc(t *testing.T) { } func TestHARShorthandInMapping(t *testing.T) { - const configFile = "config.yaml" - - const yaml = ` + const input = ` from: http://localhost:3000 to: https://api.example.com har: ./recordings/api.har ` - viperCfg := viper.New() - viperCfg.SetFs(testutils.FsFromMap(t, map[string]string{configFile: yaml})) - viperCfg.SetConfigFile(configFile) - require.NoError(t, viperCfg.ReadInConfig()) + var actual config.Mapping - actual := config.Mapping{} - require.NoError(t, viperCfg.Unmarshal(&actual, viper.DecodeHook( - config.URLMappingHookFunc(), - ))) + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) assert.Equal(t, "./recordings/api.har", actual.HAR.File) assert.False(t, actual.HAR.CaptureSecureHeaders) diff --git a/internal/config/helpers.go b/internal/config/helpers.go index 808fd40d..33c459a1 100644 --- a/internal/config/helpers.go +++ b/internal/config/helpers.go @@ -7,9 +7,6 @@ import ( "net/url" "strconv" "strings" - - "github.com/mitchellh/mapstructure" - "github.com/spf13/viper" ) var ( @@ -17,9 +14,10 @@ var ( ErrNoFromPair = errors.New("`from` values are not set for every `to`") ) -func readURLMapping(config *viper.Viper, configuration *UncorsConfig) error { - from, to := config.GetStringSlice("from"), config.GetStringSlice("to") - +// mergeURLMappings merges from/to CLI pairs into cfg.Mappings. +// If a from URL already exists in the mappings, its to value is updated. +// Otherwise a new mapping entry is appended. +func mergeURLMappings(cfg *UncorsConfig, from, to []string) error { if len(from) > len(to) { return ErrNoToPair } @@ -31,9 +29,9 @@ func readURLMapping(config *viper.Viper, configuration *UncorsConfig) error { for index, key := range from { found := false - for i := range configuration.Mappings { - if strings.EqualFold(configuration.Mappings[i].From, key) { - configuration.Mappings[i].To = to[index] + for i := range cfg.Mappings { + if strings.EqualFold(cfg.Mappings[i].From, key) { + cfg.Mappings[i].To = to[index] found = true break @@ -41,7 +39,7 @@ func readURLMapping(config *viper.Viper, configuration *UncorsConfig) error { } if !found { - configuration.Mappings = append(configuration.Mappings, Mapping{ + cfg.Mappings = append(cfg.Mappings, Mapping{ From: key, To: to[index], }) @@ -51,33 +49,13 @@ func readURLMapping(config *viper.Viper, configuration *UncorsConfig) error { return nil } -func decodeConfig[T any](data any, mapping *T, decodeFuncs ...mapstructure.DecodeHookFunc) error { - hook := mapstructure.ComposeDecodeHookFunc( - StringToTimeDurationHookFunc(), - mapstructure.StringToSliceHookFunc(","), - mapstructure.ComposeDecodeHookFunc(decodeFuncs...), - ) - - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - Result: mapping, - DecodeHook: hook, - ErrorUnused: true, - IgnoreUntaggedFields: true, - }) - if err != nil { - return err - } - - err = decoder.Decode(data) - - return err -} - const ( httpScheme = "http" httpsScheme = "https" ) +// NormaliseMappings normalises the From URL in each mapping: adds the default +// scheme (http) if absent and removes the port when it equals the scheme default. func NormaliseMappings(mappings Mappings) Mappings { processedMappings := make(Mappings, 0, len(mappings)) @@ -92,7 +70,6 @@ func NormaliseMappings(mappings Mappings) Mappings { panic(fmt.Errorf("failed to parse source url: %w", err)) } - // Normalize the mapping with port from URL normalizedMapping := mapping.Clone() normalizedMapping.From = normalizeURL(*sourceURL, host, portStr) processedMappings = append(processedMappings, normalizedMapping) @@ -117,7 +94,6 @@ func normalizeURL(parsedURL url.URL, host, portStr string) string { panic(fmt.Errorf("invalid port number: %w", err)) } } else { - // Use default port based on scheme if scheme == httpsScheme { port = defaultHTTPSPort } else { @@ -127,7 +103,6 @@ func normalizeURL(parsedURL url.URL, host, portStr string) string { parsedURL.Scheme = scheme - // Only include port in host if it's not the default port for the scheme if !isDefaultPort(scheme, port) { parsedURL.Host = net.JoinHostPort(host, strconv.Itoa(port)) } else { diff --git a/internal/config/mapping.go b/internal/config/mapping.go index d1900e0f..2085e644 100644 --- a/internal/config/mapping.go +++ b/internal/config/mapping.go @@ -3,28 +3,66 @@ package config import ( "errors" "net/url" - "reflect" "github.com/evg4b/uncors/internal/urlparser" - "github.com/mitchellh/mapstructure" - "github.com/samber/lo" + "gopkg.in/yaml.v3" ) +// ErrMappingShorthandValue is returned when a URL shorthand mapping has a +// non-string value (e.g. "http://localhost: 123" instead of a URL string). +var ErrMappingShorthandValue = errors.New("mapping shorthand value must be a string URL") + type Mapping struct { - From string `mapstructure:"from"` - To string `mapstructure:"to"` - Statics StaticDirectories `mapstructure:"statics"` - Mocks Mocks `mapstructure:"mocks"` - Scripts Scripts `mapstructure:"scripts"` - Cache CacheGlobs `mapstructure:"cache"` - Rewrites RewriteOptions `mapstructure:"rewrites"` - OptionsHandling OptionsHandling `mapstructure:"options-handling"` - HAR HARConfig `mapstructure:"har"` + From string `yaml:"from"` + To string `yaml:"to"` + Statics StaticDirectories `yaml:"statics"` + Mocks Mocks `yaml:"mocks"` + Scripts Scripts `yaml:"scripts"` + Cache CacheGlobs `yaml:"cache"` + Rewrites RewriteOptions `yaml:"rewrites"` + OptionsHandling OptionsHandling `yaml:"options-handling"` + HAR HARConfig `yaml:"har"` // Cached parsed URL and its components (not serialized) - fromURL *url.URL `json:"-" mapstructure:"-" yaml:"-"` - fromHost string `json:"-" mapstructure:"-" yaml:"-"` - fromPort string `json:"-" mapstructure:"-" yaml:"-"` + fromURL *url.URL `yaml:"-"` + fromHost string `yaml:"-"` + fromPort string `yaml:"-"` +} + +// knownMappingFields is the set of yaml keys that belong to a full Mapping +// object. Any single-key YAML map whose key is NOT in this set is interpreted +// as the shorthand "from: to" form. +var knownMappingFields = map[string]bool{ + "from": true, "to": true, "statics": true, "mocks": true, + "scripts": true, "cache": true, "rewrites": true, + "options-handling": true, "har": true, +} + +// UnmarshalYAML decodes a Mapping from YAML. It recognises two forms: +// +// Shorthand — a single-key mapping whose key is not a known field name: +// +// http://localhost:8080: https://example.com +// +// Full form — a standard YAML mapping with "from", "to", and optional fields. +func (m *Mapping) UnmarshalYAML(value *yaml.Node) error { + if value.Kind == yaml.MappingNode && len(value.Content) == 2 { + key := value.Content[0].Value + if !knownMappingFields[key] { + if value.Content[1].Tag != "!!str" { + return ErrMappingShorthandValue + } + + m.From = key + m.To = value.Content[1].Value + + return nil + } + } + + type mappingAlias Mapping + + return value.Decode((*mappingAlias)(m)) } func (m *Mapping) Clone() Mapping { @@ -38,14 +76,13 @@ func (m *Mapping) Clone() Mapping { Rewrites: m.Rewrites.Clone(), OptionsHandling: m.OptionsHandling.Clone(), HAR: m.HAR.Clone(), - fromURL: m.fromURL, // Share cached URL + fromURL: m.fromURL, fromHost: m.fromHost, fromPort: m.fromPort, } } // GetFromURL returns the parsed URL, caching it on first access. -// This method performs lazy parsing to avoid redundant URL parsing operations. func (m *Mapping) GetFromURL() (*url.URL, error) { if m.fromURL == nil { parsedURL, err := urlparser.Parse(m.From) @@ -60,7 +97,6 @@ func (m *Mapping) GetFromURL() (*url.URL, error) { } // GetFromHostPort returns the host and port from the From URL, caching them on first access. -// This method combines URL parsing and host/port splitting to avoid redundant operations. func (m *Mapping) GetFromHostPort() (string, string, error) { if m.fromHost == "" && m.fromPort == "" { uri, err := m.GetFromURL() @@ -83,52 +119,3 @@ func (m *Mapping) ClearCache() { m.fromHost = "" m.fromPort = "" } - -var ( - mappingType = reflect.TypeFor[Mapping]() - mappingFields = getTagValues(mappingType, "mapstructure") -) - -func URLMappingHookFunc() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, rawData any) (any, error) { - if t != mappingType || f.Kind() != reflect.Map { - return rawData, nil - } - - if data, ok := rawData.(map[string]any); ok { - availableFields, _ := lo.Difference(lo.Keys(data), mappingFields) - - if len(data) == 1 && len(availableFields) == 1 { - from := lo.FirstOrEmpty(availableFields) - if to, ok := data[from].(string); ok { - return Mapping{ - From: from, - To: to, - }, nil - } - - return nil, errors.ErrUnsupported - } - - mapping := Mapping{} - err := decodeConfig( - data, - &mapping, - StaticDirMappingHookFunc(), - HARConfigHookFunc(), - ) - - return mapping, err - } - - return rawData, nil - } -} - -func getTagValues(typeValue reflect.Type, tag string) []string { - fields := reflect.VisibleFields(typeValue) - - return lo.FilterMap(fields, func(field reflect.StructField, _ int) (string, bool) { - return field.Tag.Lookup(tag) - }) -} diff --git a/internal/config/mapping_test.go b/internal/config/mapping_test.go index aa831596..55da6094 100644 --- a/internal/config/mapping_test.go +++ b/internal/config/mapping_test.go @@ -5,60 +5,68 @@ import ( "github.com/evg4b/uncors/internal/config" "github.com/evg4b/uncors/testing/hosts" - "github.com/evg4b/uncors/testing/testutils" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) var localhostSecure = "https://localhost:9090" -func TestURLMappingHookFunc(t *testing.T) { - const configFile = "config.yaml" - +func TestMappingUnmarshalYAML(t *testing.T) { t.Run("positive cases", func(t *testing.T) { tests := []struct { name string - config string + input string expected config.Mapping }{ { - name: "simple key-value mapping", - config: "http://localhost:4200: https://github.com", + name: "simple key-value shorthand", + input: "http://localhost:4200: https://github.com", expected: config.Mapping{ From: hosts.Localhost.HTTPPort(4200), To: hosts.Github.HTTPS(), }, }, { - name: "full object mapping", - config: "{ from: http://localhost:3000, to: https://api.github.com }", + name: "full object mapping", + input: "{ from: http://localhost:3000, to: https://api.github.com }", expected: config.Mapping{ From: hosts.Localhost.HTTPPort(3000), To: hosts.APIGithub.HTTPS(), }, }, + { + name: "mapping with HAR shorthand", + input: ` +from: http://localhost:3000 +to: https://api.example.com +har: ./recordings/api.har +`, + expected: config.Mapping{ + From: hosts.Localhost.HTTPPort(3000), + To: "https://api.example.com", + HAR: config.HARConfig{File: "./recordings/api.har"}, + }, + }, } + for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { - viperInstance := viper.GetViper() - viperInstance.SetFs(testutils.FsFromMap(t, map[string]string{ - configFile: testCase.config, - })) - viperInstance.SetConfigFile(configFile) - err := viperInstance.ReadInConfig() - testutils.CheckNoError(t, err) - - actual := config.Mapping{} - - err = viperInstance.Unmarshal(&actual, viper.DecodeHook( - config.URLMappingHookFunc(), - )) - testutils.CheckNoError(t, err) - + var actual config.Mapping + require.NoError(t, yaml.Unmarshal([]byte(testCase.input), &actual)) assert.Equal(t, testCase.expected, actual) }) } }) + + t.Run("error cases", func(t *testing.T) { + t.Run("shorthand with non-string value", func(t *testing.T) { + var actual config.Mapping + + err := yaml.Unmarshal([]byte("http://localhost: 123"), &actual) + assert.Error(t, err) + }) + }) } func TestURLMappingClone(t *testing.T) { @@ -94,6 +102,7 @@ func TestURLMappingClone(t *testing.T) { }, }, } + for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { actual := testCase.expected.Clone() diff --git a/internal/config/mock.go b/internal/config/mock.go index 629813d5..4689c485 100644 --- a/internal/config/mock.go +++ b/internal/config/mock.go @@ -7,8 +7,8 @@ import ( ) type Mock struct { - Matcher RequestMatcher `mapstructure:",squash"` - Response Response `mapstructure:"response"` + Matcher RequestMatcher `yaml:",inline"` + Response Response `yaml:"response"` } func (m *Mock) Clone() Mock { diff --git a/internal/config/options_handling.go b/internal/config/options_handling.go index ee442bed..b2a3ae71 100644 --- a/internal/config/options_handling.go +++ b/internal/config/options_handling.go @@ -3,9 +3,9 @@ package config import "github.com/evg4b/uncors/internal/helpers" type OptionsHandling struct { - Disabled bool `mapstructure:"disabled"` - Headers map[string]string `mapstructure:"headers"` - Code int `mapstructure:"code"` + Disabled bool `yaml:"disabled"` + Headers map[string]string `yaml:"headers"` + Code int `yaml:"code"` } func (o *OptionsHandling) Clone() OptionsHandling { diff --git a/internal/config/request_matcher.go b/internal/config/request_matcher.go index f2c50380..cd349362 100644 --- a/internal/config/request_matcher.go +++ b/internal/config/request_matcher.go @@ -3,10 +3,10 @@ package config import "github.com/evg4b/uncors/internal/helpers" type RequestMatcher struct { - Path string `mapstructure:"path"` - Method string `mapstructure:"method"` - Queries map[string]string `mapstructure:"queries"` - Headers map[string]string `mapstructure:"headers"` + Path string `yaml:"path"` + Method string `yaml:"method"` + Queries map[string]string `yaml:"queries"` + Headers map[string]string `yaml:"headers"` } func (r *RequestMatcher) Clone() RequestMatcher { diff --git a/internal/config/response.go b/internal/config/response.go index 2836f81b..08eb84cb 100644 --- a/internal/config/response.go +++ b/internal/config/response.go @@ -1,17 +1,20 @@ package config import ( + "fmt" + "strings" "time" "github.com/evg4b/uncors/internal/helpers" + "gopkg.in/yaml.v3" ) type Response struct { - Code int `mapstructure:"code"` - Headers map[string]string `mapstructure:"headers"` - Delay time.Duration `mapstructure:"delay"` - Raw string `mapstructure:"raw"` - File string `mapstructure:"file"` + Code int `yaml:"code"` + Headers map[string]string `yaml:"headers"` + Delay time.Duration `yaml:"-"` + Raw string `yaml:"raw"` + File string `yaml:"file"` } func (r *Response) Clone() Response { @@ -31,3 +34,41 @@ func (r *Response) IsRaw() bool { func (r *Response) IsFile() bool { return len(r.File) > 0 } + +// UnmarshalYAML implements custom decoding so that the "delay" field can be +// expressed as a human-readable duration string (e.g. "200ms", "1s 500ms"). +// All other fields are decoded by the standard yaml.v3 machinery. +func (r *Response) UnmarshalYAML(value *yaml.Node) error { + type responseRaw struct { + Code int `yaml:"code"` + Headers map[string]string `yaml:"headers"` + Delay string `yaml:"delay"` + Raw string `yaml:"raw"` + File string `yaml:"file"` + } + + var raw responseRaw + + err := value.Decode(&raw) + if err != nil { + return err + } + + r.Code = raw.Code + r.Headers = raw.Headers + r.Raw = raw.Raw + r.File = raw.File + + if raw.Delay == "" { + return nil + } + + dur, err := time.ParseDuration(strings.ReplaceAll(raw.Delay, " ", "")) + if err != nil { + return fmt.Errorf("invalid delay %q: %w", raw.Delay, err) + } + + r.Delay = dur + + return nil +} diff --git a/internal/config/response_test.go b/internal/config/response_test.go index b877ca4b..6d895ee1 100644 --- a/internal/config/response_test.go +++ b/internal/config/response_test.go @@ -8,8 +8,64 @@ import ( "github.com/evg4b/uncors/internal/config" "github.com/go-http-utils/headers" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) +func TestResponseUnmarshalYAML(t *testing.T) { + t.Run("decodes all fields", func(t *testing.T) { + const input = ` +code: 200 +headers: + Content-Type: application/json + X-Custom: value +delay: 200ms +raw: '{"ok":true}' +file: ./body.json +` + + var actual config.Response + + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) + assert.Equal(t, config.Response{ + Code: 200, + Headers: map[string]string{ + "Content-Type": "application/json", + "X-Custom": "value", + }, + Delay: 200 * time.Millisecond, + Raw: `{"ok":true}`, + File: "./body.json", + }, actual) + }) + + t.Run("zero delay when field is absent", func(t *testing.T) { + const input = `code: 204` + + var actual config.Response + + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) + assert.Zero(t, actual.Delay) + }) + + t.Run("parses delay with embedded spaces", func(t *testing.T) { + const input = `delay: "1s 500ms"` + + var actual config.Response + + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) + assert.Equal(t, 1500*time.Millisecond, actual.Delay) + }) + + t.Run("returns error for invalid delay", func(t *testing.T) { + const input = `delay: not-a-duration` + + var actual config.Response + + assert.Error(t, yaml.Unmarshal([]byte(input), &actual)) + }) +} + func TestResponseClone(t *testing.T) { response := config.Response{ Code: http.StatusOK, diff --git a/internal/config/rewrite.go b/internal/config/rewrite.go index 6e0305f3..c86602f3 100644 --- a/internal/config/rewrite.go +++ b/internal/config/rewrite.go @@ -1,9 +1,9 @@ package config type RewritingOption struct { - From string `mapstructure:"from"` - To string `mapstructure:"to"` - Host string `mapstructure:"host"` + From string `yaml:"from"` + To string `yaml:"to"` + Host string `yaml:"host"` } func (r RewritingOption) Clone() RewritingOption { diff --git a/internal/config/script.go b/internal/config/script.go index 01654da3..1f8ba27d 100644 --- a/internal/config/script.go +++ b/internal/config/script.go @@ -7,9 +7,9 @@ import ( ) type Script struct { - Matcher RequestMatcher `mapstructure:",squash"` - Script string `mapstructure:"script"` - File string `mapstructure:"file"` + Matcher RequestMatcher `yaml:",inline"` + Script string `yaml:"script"` + File string `yaml:"file"` } func (s *Script) Clone() Script { diff --git a/internal/config/static.go b/internal/config/static.go index 1d707558..92b54b25 100644 --- a/internal/config/static.go +++ b/internal/config/static.go @@ -2,16 +2,15 @@ package config import ( "fmt" - "reflect" - "github.com/mitchellh/mapstructure" "github.com/samber/lo" + "gopkg.in/yaml.v3" ) type StaticDirectory struct { - Path string `mapstructure:"path"` - Dir string `mapstructure:"dir"` - Index string `mapstructure:"index"` + Path string `yaml:"path"` + Dir string `yaml:"dir"` + Index string `yaml:"index"` } func (s *StaticDirectory) Clone() StaticDirectory { @@ -28,52 +27,57 @@ func (s *StaticDirectory) String() string { type StaticDirectories []StaticDirectory -func (s StaticDirectories) Clone() StaticDirectories { - if s == nil { +func (s *StaticDirectories) Clone() StaticDirectories { + if s == nil || *s == nil { return nil } - return lo.Map(s, func(item StaticDirectory, _ int) StaticDirectory { + return lo.Map(*s, func(item StaticDirectory, _ int) StaticDirectory { return item.Clone() }) } -var staticDirMappingsType = reflect.TypeFor[StaticDirectories]() - -func StaticDirMappingHookFunc() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, rawData any) (any, error) { - if t != staticDirMappingsType || f.Kind() != reflect.Map { - return rawData, nil - } - - mappingsDefs, ok := rawData.(map[string]any) - if !ok { - return rawData, nil - } - - var mappings StaticDirectories - - for path, mappingDef := range mappingsDefs { - if def, ok := mappingDef.(string); ok { - mappings = append(mappings, StaticDirectory{ - Path: path, - Dir: def, - }) - - continue +// UnmarshalYAML allows StaticDirectories to be specified as a YAML mapping +// (shorthand: path → dir or path → {dir, index}) as well as a sequence of +// full StaticDirectory objects. +// +// Map form: +// +// statics: +// /path: /static-dir +// /other: { dir: /other-dir, index: index.html } +// +// Sequence form: +// +// statics: +// - path: /path +// dir: /static-dir +func (s *StaticDirectories) UnmarshalYAML(value *yaml.Node) error { + if value.Kind == yaml.MappingNode { + for i := 0; i+1 < len(value.Content); i += 2 { + path := value.Content[i].Value + valNode := value.Content[i+1] + + var staticDir StaticDirectory + + if valNode.Kind == yaml.ScalarNode { + staticDir = StaticDirectory{Path: path, Dir: valNode.Value} + } else { + err := valNode.Decode(&staticDir) + if err != nil { + return err + } + + staticDir.Path = path // map key always wins over any inline path field } - mapping := StaticDirectory{} - - err := decodeConfig(mappingDef, &mapping) - if err != nil { - return nil, err - } - - mapping.Path = path - mappings = append(mappings, mapping) + *s = append(*s, staticDir) } - return mappings, nil + return nil } + + type staticDirectoriesAlias StaticDirectories + + return value.Decode((*staticDirectoriesAlias)(s)) } diff --git a/internal/config/static_test.go b/internal/config/static_test.go index 0b9bf41c..8718c419 100644 --- a/internal/config/static_test.go +++ b/internal/config/static_test.go @@ -4,9 +4,9 @@ import ( "testing" "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/testing/testutils" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) const ( @@ -20,87 +20,106 @@ const ( indexHTML = "index.html" ) -func TestStaticDirMappingHookFunc(t *testing.T) { - const configFile = "config.yaml" - +func TestStaticDirectoriesUnmarshalYAML(t *testing.T) { type testType struct { - Statics config.StaticDirectories `mapstructure:"statics"` + Statics config.StaticDirectories `yaml:"statics"` } - tests := []struct { - name string - config string - expected config.StaticDirectories - }{ - { - name: "decode plan mapping", - config: ` + t.Run("map form", func(t *testing.T) { + tests := []struct { + name string + input string + expected config.StaticDirectories + }{ + { + name: "plain map shorthand", + input: ` statics: /path: /static-dir /another-path: /another-static-dir `, - expected: config.StaticDirectories{ - {Path: anotherPath, Dir: anotherStaticDir}, - {Path: path, Dir: staticDir}, + expected: config.StaticDirectories{ + {Path: path, Dir: staticDir}, + {Path: anotherPath, Dir: anotherStaticDir}, + }, }, - }, - { - name: "decode object mappings", - config: ` + { + name: "object map without index", + input: ` statics: /path: { dir: /static-dir } /another-path: { dir: /another-static-dir } `, - expected: config.StaticDirectories{ - {Path: path, Dir: staticDir}, - {Path: anotherPath, Dir: anotherStaticDir}, + expected: config.StaticDirectories{ + {Path: path, Dir: staticDir}, + {Path: anotherPath, Dir: anotherStaticDir}, + }, }, - }, - { - name: "decode object mappings with index", - config: ` + { + name: "object map with index", + input: ` statics: /path: { dir: /static-dir, index: index.html } /another-path: { dir: /another-static-dir, index: default.html } `, - expected: config.StaticDirectories{ - {Path: path, Dir: staticDir, Index: indexHTML}, - {Path: anotherPath, Dir: anotherStaticDir, Index: "default.html"}, + expected: config.StaticDirectories{ + {Path: path, Dir: staticDir, Index: indexHTML}, + {Path: anotherPath, Dir: anotherStaticDir, Index: "default.html"}, + }, }, - }, - { - name: "decode mixed mappings with index", - config: ` + { + name: "mixed map", + input: ` statics: /path: { dir: /static-dir, index: index.html } /another-path: /another-static-dir `, - expected: config.StaticDirectories{ - {Path: path, Dir: staticDir, Index: indexHTML}, - {Path: anotherPath, Dir: anotherStaticDir}, + expected: config.StaticDirectories{ + {Path: path, Dir: staticDir, Index: indexHTML}, + {Path: anotherPath, Dir: anotherStaticDir}, + }, }, - }, - } - for _, testCase := range tests { - t.Run(testCase.name, func(t *testing.T) { - viperInstance := viper.GetViper() - viperInstance.SetFs(testutils.FsFromMap(t, map[string]string{ - configFile: testCase.config, - })) - viperInstance.SetConfigFile(configFile) - err := viperInstance.ReadInConfig() - testutils.CheckNoError(t, err) - - actual := testType{} - - err = viperInstance.Unmarshal(&actual, viper.DecodeHook( - config.StaticDirMappingHookFunc(), - )) - testutils.CheckNoError(t, err) - - assert.ElementsMatch(t, actual.Statics, testCase.expected) - }) - } + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + var actual testType + + require.NoError(t, yaml.Unmarshal([]byte(testCase.input), &actual)) + assert.ElementsMatch(t, testCase.expected, actual.Statics) + }) + } + }) + + t.Run("object map with invalid field type returns error", func(t *testing.T) { + const input = ` +statics: + /path: [a, b, c] +` + + var actual testType + + assert.Error(t, yaml.Unmarshal([]byte(input), &actual)) + }) + + t.Run("sequence form", func(t *testing.T) { + const input = ` +statics: + - path: /path + dir: /static-dir + - path: /another-path + dir: /another-static-dir + index: index.html +` + + var actual testType + + require.NoError(t, yaml.Unmarshal([]byte(input), &actual)) + assert.Equal(t, config.StaticDirectories{ + {Path: path, Dir: staticDir}, + {Path: anotherPath, Dir: anotherStaticDir, Index: indexHTML}, + }, actual.Statics) + }) } func TestStaticDirMappingClone(t *testing.T) { @@ -134,6 +153,7 @@ func TestStaticDirMappingClone(t *testing.T) { }, }, } + for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { actual := testCase.expected.Clone() diff --git a/internal/config/time_decode_hook.go b/internal/config/time_decode_hook.go deleted file mode 100644 index 4cdb1da7..00000000 --- a/internal/config/time_decode_hook.go +++ /dev/null @@ -1,24 +0,0 @@ -package config - -import ( - "errors" - "reflect" - "strings" - "time" - - "github.com/mitchellh/mapstructure" -) - -func StringToTimeDurationHookFunc() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data any) (any, error) { - if f.Kind() != reflect.String || t != reflect.TypeFor[time.Duration]() { - return data, nil - } - - if value, ok := data.(string); ok { - return time.ParseDuration(strings.ReplaceAll(value, " ", "")) - } - - return nil, errors.ErrUnsupported - } -} diff --git a/internal/config/time_decode_hook_test.go b/internal/config/time_decode_hook_test.go index 21d75f21..e15233e3 100644 --- a/internal/config/time_decode_hook_test.go +++ b/internal/config/time_decode_hook_test.go @@ -5,101 +5,133 @@ import ( "time" "github.com/evg4b/uncors/internal/config" - "github.com/evg4b/uncors/testing/testutils" - "github.com/mitchellh/mapstructure" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) -func TestStringToTimeDurationHookFunc(t *testing.T) { - const key = "duration" - - viperInstance := viper.New() - configOption := viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( - config.StringToTimeDurationHookFunc(), - mapstructure.OrComposeDecodeHookFunc( - mapstructure.StringToSliceHookFunc(","), - mapstructure.StringToSliceHookFunc(", "), - ), - )) - - t.Run("correct parse different formats", func(t *testing.T) { +func TestCacheConfigDurationUnmarshal(t *testing.T) { + t.Run("parses valid duration strings", func(t *testing.T) { tests := []struct { name string - value string + input string expected time.Duration }{ - { - name: "duration with spaces", - value: "1m 4s", - expected: 1*time.Minute + 4*time.Second, - }, { name: "duration without spaces", - value: "3h6m13s", + input: "expiration-time: 3h6m13s", expected: 3*time.Hour + 6*time.Minute + 13*time.Second, }, + { + name: "duration with spaces", + input: "expiration-time: \"1m 4s\"", + expected: 1*time.Minute + 4*time.Second, + }, { name: "duration with mixed spaces", - value: "1h 3m59s 40ms", + input: "expiration-time: \"1h 3m59s 40ms\"", expected: 1*time.Hour + 3*time.Minute + 59*time.Second + 40*time.Millisecond, }, } for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { - viperInstance.Set(key, testCase.value) - - durationValue := time.Duration(0) - err := viperInstance.UnmarshalKey(key, &durationValue, configOption) - testutils.CheckNoError(t, err) - - assert.Equal(t, testCase.expected, durationValue) + cfg := config.CacheConfig{ExpirationTime: config.DefaultExpirationTime} + require.NoError(t, yaml.Unmarshal([]byte(testCase.input), &cfg)) + assert.Equal(t, testCase.expected, cfg.ExpirationTime) }) } }) - t.Run("doesnt not affected other type parses", func(t *testing.T) { - t.Run("string to string", func(t *testing.T) { - viperInstance.Set(key, "value") + t.Run("preserves defaults for absent fields", func(t *testing.T) { + cfg := config.CacheConfig{ + ExpirationTime: config.DefaultExpirationTime, + MaxSize: config.DefaultMaxSize, + Methods: []string{"GET"}, + } + + require.NoError(t, yaml.Unmarshal([]byte("max-size: 1048576"), &cfg)) + assert.Equal(t, config.DefaultExpirationTime, cfg.ExpirationTime) + assert.Equal(t, int64(1048576), cfg.MaxSize) + }) + + t.Run("returns error for non-mapping input", func(t *testing.T) { + var cfg config.CacheConfig + + err := yaml.Unmarshal([]byte("just-a-string"), &cfg) + assert.ErrorIs(t, err, config.ErrInvalidCacheConfig) + }) + + t.Run("returns error for invalid duration string", func(t *testing.T) { + var cfg config.CacheConfig + + err := yaml.Unmarshal([]byte("expiration-time: notaduration"), &cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid expiration-time") + }) - stringValue := "" - err := viperInstance.UnmarshalKey(key, &stringValue, configOption) - testutils.CheckNoError(t, err) + t.Run("returns error when max-size is not a number", func(t *testing.T) { + var cfg config.CacheConfig - assert.Equal(t, "value", stringValue) - }) + err := yaml.Unmarshal([]byte("max-size: [a, b, c]"), &cfg) + require.Error(t, err) + }) - t.Run("string to []string", func(t *testing.T) { - viperInstance.Set(key, "value,value2") + t.Run("returns error when methods is not a sequence", func(t *testing.T) { + var cfg config.CacheConfig - var stringValue []string + err := yaml.Unmarshal([]byte("methods: {key: value}"), &cfg) + require.Error(t, err) + }) +} - err := viperInstance.UnmarshalKey(key, &stringValue, configOption) - testutils.CheckNoError(t, err) +func TestResponseDelayUnmarshal(t *testing.T) { + t.Run("parses valid delay strings", func(t *testing.T) { + tests := []struct { + name string + input string + expected time.Duration + }{ + { + name: "millisecond delay", + input: "delay: 200ms", + expected: 200 * time.Millisecond, + }, + { + name: "delay with spaces", + input: "delay: \"1s 500ms\"", + expected: 1*time.Second + 500*time.Millisecond, + }, + } - assert.Equal(t, []string{"value", "value2"}, stringValue) - }) + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + var resp config.Response + require.NoError(t, yaml.Unmarshal([]byte(testCase.input), &resp)) + assert.Equal(t, testCase.expected, resp.Delay) + }) + } + }) - t.Run("number to string", func(t *testing.T) { - viperInstance.Set(key, 11) + t.Run("returns error for invalid delay string", func(t *testing.T) { + var resp config.Response - stringValue := "" - err := viperInstance.UnmarshalKey(key, &stringValue, configOption) - testutils.CheckNoError(t, err) + err := yaml.Unmarshal([]byte("delay: notaduration"), &resp) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid delay") + }) - assert.Equal(t, "11", stringValue) - }) + t.Run("zero delay when field absent", func(t *testing.T) { + var resp config.Response - t.Run("number to duration", func(t *testing.T) { - const expected = 14 * time.Minute - viperInstance.Set(key, int(expected)) + require.NoError(t, yaml.Unmarshal([]byte("code: 200"), &resp)) + assert.Zero(t, resp.Delay) + }) - durationValue := time.Nanosecond - err := viperInstance.UnmarshalKey(key, &durationValue, configOption) - testutils.CheckNoError(t, err) + t.Run("returns error when response is not a mapping", func(t *testing.T) { + var resp config.Response - assert.Equal(t, expected, durationValue) - }) + err := yaml.Unmarshal([]byte("[200, 404]"), &resp) + require.Error(t, err) }) } diff --git a/internal/config/watcher.go b/internal/config/watcher.go new file mode 100644 index 00000000..9b9aacc5 --- /dev/null +++ b/internal/config/watcher.go @@ -0,0 +1,93 @@ +package config + +import ( + "fmt" + "log" + "time" + + "github.com/fsnotify/fsnotify" +) + +// debounceDelay is the wait time after the last file event before calling onChange. +// This prevents multiple rapid callbacks when editors write files in stages. +const debounceDelay = 10 * time.Millisecond + +// Watcher monitors a configuration file for changes and invokes a callback +// whenever the file is written or recreated. It uses a short debounce window to +// coalesce bursts of filesystem events that editors typically produce on save. +type Watcher struct { + fsWatcher *fsnotify.Watcher + onChange func() + done chan struct{} +} + +// NewWatcher creates a Watcher that monitors the given file path. +// onChange is called (after debouncing) on every write or create event. +// The returned watcher is already running; call Close to stop it. +func NewWatcher(filePath string, onChange func()) (*Watcher, error) { + fsWatcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("failed to create file watcher: %w", err) + } + + err = fsWatcher.Add(filePath) + if err != nil { + _ = fsWatcher.Close() + + return nil, fmt.Errorf("failed to watch config file '%s': %w", filePath, err) + } + + watcher := &Watcher{ + fsWatcher: fsWatcher, + onChange: onChange, + done: make(chan struct{}), + } + + go watcher.run() + + return watcher, nil +} + +// Close stops the watcher and releases all associated resources. +func (cw *Watcher) Close() error { + close(cw.done) + + return cw.fsWatcher.Close() +} + +func (cw *Watcher) run() { + var debounce *time.Timer + + stopDebounce := func() { + if debounce != nil { + debounce.Stop() + } + } + + for { + select { + case <-cw.done: + stopDebounce() + + return + + case event, ok := <-cw.fsWatcher.Events: + if !ok { + return + } + + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + stopDebounce() + + debounce = time.AfterFunc(debounceDelay, cw.onChange) + } + + case err, ok := <-cw.fsWatcher.Errors: + if !ok { + return + } + + log.Printf("config watcher error: %v", err) + } + } +} diff --git a/internal/config/watcher_internal_test.go b/internal/config/watcher_internal_test.go new file mode 100644 index 00000000..c8c49377 --- /dev/null +++ b/internal/config/watcher_internal_test.go @@ -0,0 +1,125 @@ +package config + +import ( + "fmt" + "testing" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/stretchr/testify/require" +) + +var errSyntheticWatcher = fmt.Errorf("synthetic test error") + +// newIsolatedWatcher creates a Watcher whose run goroutine uses custom channels +// that are owned entirely by the test. Both fsnotify.Watcher.Events and +// fsnotify.Watcher.Errors are replaced with channels we control so that the +// fsnotify kqueue backend goroutine (which holds its own copies of the original +// channels) never writes to our replacements. This avoids any data race between +// the test and the kqueue backend. +// +// The underlying fsnotify watcher must be closed by the caller after run exits +// to release the backend goroutine cleanly. +func newIsolatedWatcher(t *testing.T) (*Watcher, chan fsnotify.Event, chan error) { + t.Helper() + + fsW, err := fsnotify.NewWatcher() + require.NoError(t, err) + + // Replace the channels the Watcher struct exposes. The kqueue backend holds + // its own references to the original channels and will never touch these. + eventsCh := make(chan fsnotify.Event) + errsCh := make(chan error) + fsW.Events = eventsCh + fsW.Errors = errsCh + + watcher := &Watcher{ + fsWatcher: fsW, + onChange: func() {}, + done: make(chan struct{}), + } + + return watcher, eventsCh, errsCh +} + +// runAndWait starts watcher.run in a goroutine and returns a channel that is +// closed when run returns. +func runAndWait(watcher *Watcher) <-chan struct{} { + exited := make(chan struct{}) + + go func() { + defer close(exited) + + watcher.run() + }() + + return exited +} + +// TestWatcherRunEventsNotOk covers the early return in run() when the Events +// channel is closed with ok=false (lines 75-77 in watcher.go). +func TestWatcherRunEventsNotOk(t *testing.T) { + watcher, events, _ := newIsolatedWatcher(t) + + exited := runAndWait(watcher) + + // Closing the channel makes the Events select case fire with ok=false, + // which triggers the return on lines 75-77. + close(events) + + select { + case <-exited: + case <-time.After(200 * time.Millisecond): + t.Fatal("run goroutine did not exit after Events channel was closed") + } + + // Close the underlying fsnotify watcher so its backend goroutine can exit. + // It closes its own (original) Events and Errors channels, not ours. + _ = watcher.fsWatcher.Close() +} + +// TestWatcherRunErrorsNotOk covers the return in run() when the Errors channel +// is closed with ok=false (lines 86-88 in watcher.go). +func TestWatcherRunErrorsNotOk(t *testing.T) { + watcher, _, errs := newIsolatedWatcher(t) + + exited := runAndWait(watcher) + + // Closing the channel makes the Errors select case fire with ok=false, + // which triggers the return on lines 86-88. + close(errs) + + select { + case <-exited: + case <-time.After(200 * time.Millisecond): + t.Fatal("run goroutine did not exit after Errors channel was closed") + } + + // Close the underlying fsnotify watcher so its backend goroutine can exit. + // It closes its own (original) Events and Errors channels, not ours. + _ = watcher.fsWatcher.Close() +} + +// TestWatcherRunErrorPath covers the log.Printf branch in run() when an error +// arrives from the backend with ok=true (line 90 in watcher.go). +func TestWatcherRunErrorPath(t *testing.T) { + watcher, _, errs := newIsolatedWatcher(t) + + exited := runAndWait(watcher) + + // Sending to the unbuffered errs channel blocks until run's select receives + // it. Because the channel is open, ok=true and the error is logged (line 90). + errs <- errSyntheticWatcher + + // Signal run to stop and wait for it to exit cleanly. + close(watcher.done) + + select { + case <-exited: + case <-time.After(200 * time.Millisecond): + t.Fatal("run goroutine did not exit after done was closed") + } + + // Close the underlying fsnotify watcher so its backend goroutine can exit. + _ = watcher.fsWatcher.Close() +} diff --git a/internal/config/watcher_test.go b/internal/config/watcher_test.go new file mode 100644 index 00000000..94ae5bb3 --- /dev/null +++ b/internal/config/watcher_test.go @@ -0,0 +1,117 @@ +package config_test + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/evg4b/uncors/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const watcherTimeout = 500 * time.Millisecond + +// waitForCall blocks until fn fires or the timeout elapses, returning true on success. +func waitForCall(ch <-chan struct{}, timeout time.Duration) bool { + select { + case <-ch: + return true + case <-time.After(timeout): + return false + } +} + +func TestNewConfigWatcher(t *testing.T) { + t.Run("returns error for non-existent file", func(t *testing.T) { + _, err := config.NewWatcher("/no/such/file.yaml", func() {}) + assert.Error(t, err) + }) + + t.Run("invokes onChange on file write", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + require.NoError(t, os.WriteFile(configFile, []byte("proxy: \"\""), 0o600)) + + called := make(chan struct{}, 1) + + watcher, err := config.NewWatcher(configFile, func() { + select { + case called <- struct{}{}: + default: + } + }) + require.NoError(t, err) + + defer watcher.Close() + + require.NoError(t, os.WriteFile(configFile, []byte("proxy: localhost:8080"), 0o600)) + assert.True(t, waitForCall(called, watcherTimeout), "onChange was not called after file write") + }) + + t.Run("does not invoke onChange after Close", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + require.NoError(t, os.WriteFile(configFile, []byte("proxy: \"\""), 0o600)) + + called := make(chan struct{}, 1) + + watcher, err := config.NewWatcher(configFile, func() { + select { + case called <- struct{}{}: + default: + } + }) + require.NoError(t, err) + + require.NoError(t, watcher.Close()) + + require.NoError(t, os.WriteFile(configFile, []byte("proxy: changed"), 0o600)) + assert.False(t, waitForCall(called, 100*time.Millisecond), "onChange was called after Close") + }) + + t.Run("debounces rapid successive writes", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + require.NoError(t, os.WriteFile(configFile, []byte("proxy: \"\""), 0o600)) + + callCount := 0 + called := make(chan struct{}, 10) + + watcher, err := config.NewWatcher(configFile, func() { + callCount++ + + called <- struct{}{} + }) + require.NoError(t, err) + + defer watcher.Close() + + // Write multiple times in quick succession. + for i := range 5 { + require.NoError(t, os.WriteFile(configFile, []byte("proxy: change"), 0o600)) + + _ = i + } + + // Wait for the first (and hopefully only) callback. + assert.True(t, waitForCall(called, watcherTimeout), "onChange was never called") + + // Give any extra calls a chance to arrive. + time.Sleep(50 * time.Millisecond) + + assert.LessOrEqual(t, callCount, 3, "too many onChange calls for rapid writes (expected debouncing)") + }) + + t.Run("Close returns nil on first call", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + require.NoError(t, os.WriteFile(configFile, []byte(""), 0o600)) + + watcher, err := config.NewWatcher(configFile, func() {}) + require.NoError(t, err) + + assert.NoError(t, watcher.Close()) + }) +} diff --git a/internal/uncors_app/app.go b/internal/uncors_app/app.go index 6c6bf5bc..92e7f8a1 100644 --- a/internal/uncors_app/app.go +++ b/internal/uncors_app/app.go @@ -15,9 +15,7 @@ import ( "github.com/evg4b/uncors/internal/server" "github.com/evg4b/uncors/internal/uncors" "github.com/evg4b/uncors/internal/version" - "github.com/fsnotify/fsnotify" "github.com/spf13/afero" - "github.com/spf13/viper" ) const ( @@ -43,7 +41,9 @@ type uncorsApp struct { cfg *config.UncorsConfig loadConfig func() *config.UncorsConfig - viper *viper.Viper + configPath string + + watcher *config.Watcher termHeight int termWidth int @@ -65,10 +65,13 @@ type appUpdateMsg interface { update(app *uncorsApp) tea.Cmd } +// NewUncorsApp creates the interactive TUI model. configPath is the active +// config file path (empty string if no config file is used); when non-empty +// the app watches it for changes and auto-restarts the proxy on every save. func NewUncorsApp( ver string, fs afero.Fs, - viperInstance *viper.Viper, + configPath string, cfg *config.UncorsConfig, loadConfig func() *config.UncorsConfig, ) tea.Model { @@ -93,7 +96,7 @@ func NewUncorsApp( cancel: cancel, cfg: cfg, loadConfig: loadConfig, - viper: viperInstance, + configPath: configPath, historyWidget: historyWidget, trackerWidget: NewTrackerWidget(), helpWidget: NewHelpWidget(keys), @@ -263,19 +266,25 @@ func (msg shutdownMsg) update(app *uncorsApp) tea.Cmd { } func (m *uncorsApp) handleServerStarted() tea.Cmd { - m.viper.OnConfigChange(func(_ fsnotify.Event) { - defer helpers.PanicInterceptor(func(value any) { - m.output.Errorf("Config reloading error: %v", value) - }) + if m.configPath != "" { + watcher, err := config.NewWatcher(m.configPath, func() { + defer helpers.PanicInterceptor(func(value any) { + m.output.Errorf("Config reloading error: %v", value) + }) - newCfg := m.loadConfig() + newCfg := m.loadConfig() - err := m.app.Restart(m.appContext(), newCfg) + err := m.app.Restart(m.appContext(), newCfg) + if err != nil { + m.output.Errorf("Failed to restart server: %v", err) + } + }) if err != nil { - m.output.Errorf("Failed to restart server: %v", err) + m.output.Errorf("Failed to watch config file: %v", err) + } else { + m.watcher = watcher } - }) - m.viper.WatchConfig() + } return m.versionCheckCmd() } @@ -306,6 +315,10 @@ func (m *uncorsApp) handleRestart() { func (m *uncorsApp) handleShutdown() tea.Cmd { log.Println("Handling shutdown") + if m.watcher != nil { + _ = m.watcher.Close() + } + _ = m.historyWidget.Close() return tea.Quit diff --git a/internal/uncors_app/app_internal_test.go b/internal/uncors_app/app_internal_test.go index e31c8e6a..3d631003 100644 --- a/internal/uncors_app/app_internal_test.go +++ b/internal/uncors_app/app_internal_test.go @@ -3,16 +3,16 @@ package uncorsapp import ( "errors" "net/url" + "os" "testing" "time" "charm.land/bubbles/v2/spinner" tea "charm.land/bubbletea/v2" "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/server" - "github.com/fsnotify/fsnotify" "github.com/spf13/afero" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -23,7 +23,6 @@ func newTestApp(t *testing.T) (*uncorsApp, *int) { t.Helper() fs := afero.NewMemMapFs() - viperInstance := viper.New() uncorsConfig := &config.UncorsConfig{ Mappings: config.Mappings{}, } @@ -32,7 +31,7 @@ func newTestApp(t *testing.T) (*uncorsApp, *int) { model := NewUncorsApp( "test-version", fs, - viperInstance, + "", // no config file — watcher is not created uncorsConfig, func() *config.UncorsConfig { loadCalls++ @@ -163,7 +162,6 @@ func TestUncorsAppCommandFactoriesAndChannels(t *testing.T) { cmd := app.handleServerStarted() require.NotNil(t, cmd) - app.viper.OnConfigChange(func(_ fsnotify.Event) {}) msg = app.restartCmd()() assert.Equal(t, restartMsg{}, msg) @@ -294,3 +292,166 @@ func TestUncorsAppServerErrorRestartShutdownAndFormatting(t *testing.T) { _ = app.app.Close() }) } + +func TestServerStartedMsgUpdate(t *testing.T) { + app, _ := newTestApp(t) + defer cleanupTestApp(t, app) + + model, cmd := app.Update(serverStartedMsg{}) + + require.Same(t, app, model) + require.NotNil(t, cmd) +} + +func TestHandleServerStartedWithConfigPath(t *testing.T) { + t.Run("creates watcher when config file exists", func(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "uncors-*.yaml") + require.NoError(t, err) + + _ = tmpFile.Close() + + fs := afero.NewMemMapFs() + cfg := &config.UncorsConfig{Mappings: config.Mappings{}} + + model := NewUncorsApp("v1", fs, tmpFile.Name(), cfg, func() *config.UncorsConfig { return cfg }) + app, ok := model.(*uncorsApp) + require.True(t, ok) + + defer func() { + app.cancel() + _ = app.app.Close() + + if app.historyWidget != nil && app.historyWidget.hist != nil { + _ = app.historyWidget.hist.Close() + } + }() + + cmd := app.handleServerStarted() + + require.NotNil(t, cmd) + require.NotNil(t, app.watcher) + + _ = app.watcher.Close() + }) + + t.Run("logs error when config file does not exist", func(t *testing.T) { + fs := afero.NewMemMapFs() + cfg := &config.UncorsConfig{Mappings: config.Mappings{}} + + model := NewUncorsApp("v1", fs, "/nonexistent/path/config.yaml", cfg, func() *config.UncorsConfig { return cfg }) + app, ok := model.(*uncorsApp) + require.True(t, ok) + + defer func() { + app.cancel() + _ = app.app.Close() + + if app.historyWidget != nil && app.historyWidget.hist != nil { + _ = app.historyWidget.hist.Close() + } + }() + + cmd := app.handleServerStarted() + + require.NotNil(t, cmd) + assert.Nil(t, app.watcher) + }) +} + +func TestHandleRequestEventWithData(t *testing.T) { + requestURL, err := url.Parse("https://example.com/api") + require.NoError(t, err) + + data := &contracts.ReqestData{Method: "GET", URL: requestURL, Code: 200} + + t.Run("outputs request without prefix", func(t *testing.T) { + app, _ := newTestApp(t) + defer cleanupTestApp(t, app) + + app.handleRequestEvent(requestEventMsg{Done: true, Data: data}) + }) + + t.Run("outputs request with prefix", func(t *testing.T) { + app, _ := newTestApp(t) + defer cleanupTestApp(t, app) + + app.handleRequestEvent(requestEventMsg{Done: true, Data: data, Prefix: "api"}) + }) +} + +func TestHandleServerStartedCallbackOnFileChange(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "uncors-*.yaml") + require.NoError(t, err) + + _ = tmpFile.Close() + + fs := afero.NewMemMapFs() + cfg := &config.UncorsConfig{Mappings: config.Mappings{}} + + called := make(chan struct{}, 1) + + model := NewUncorsApp("v1", fs, tmpFile.Name(), cfg, func() *config.UncorsConfig { + select { + case called <- struct{}{}: + default: + } + + return cfg + }) + + app, ok := model.(*uncorsApp) + require.True(t, ok) + + defer func() { + // Cancel context first so any in-flight Restart fails fast. + // We deliberately skip app.app.Close() here: closeAll() writes + // app.closers concurrently with the Restart goroutine's read of + // app.closers, which would be a data race. + app.cancel() + + if app.watcher != nil { + _ = app.watcher.Close() + } + + if app.historyWidget != nil && app.historyWidget.hist != nil { + _ = app.historyWidget.hist.Close() + } + }() + + cmd := app.handleServerStarted() + + require.NotNil(t, cmd) + require.NotNil(t, app.watcher) + + require.NoError(t, os.WriteFile(tmpFile.Name(), []byte("proxy: \"\""), 0o600)) + + select { + case <-called: + case <-time.After(500 * time.Millisecond): + t.Fatal("onChange callback was not invoked within timeout") + } +} + +func TestHandleShutdownWithWatcher(t *testing.T) { + tmpFile, err := os.CreateTemp(t.TempDir(), "uncors-*.yaml") + require.NoError(t, err) + + _ = tmpFile.Close() + + watcher, err := config.NewWatcher(tmpFile.Name(), func() {}) + require.NoError(t, err) + + app, _ := newTestApp(t) + app.watcher = watcher + + cmd := app.handleShutdown() + require.NotNil(t, cmd) + assert.Equal(t, tea.Quit(), cmd()) + + app.cancel() + _ = app.app.Close() + + if app.historyWidget != nil && app.historyWidget.hist != nil { + _ = app.historyWidget.hist.Close() + } +} diff --git a/main.go b/main.go index a12b0d3c..d0b57952 100644 --- a/main.go +++ b/main.go @@ -19,14 +19,14 @@ import ( "github.com/evg4b/uncors/internal/uncors" uncorsapp "github.com/evg4b/uncors/internal/uncors_app" "github.com/evg4b/uncors/internal/version" - "github.com/fsnotify/fsnotify" "github.com/spf13/afero" "github.com/spf13/pflag" - "github.com/spf13/viper" ) var Version = "X.X.X" +const generateCertsCmd = "generate-certs" + func main() { exitCode := run() os.Exit(exitCode) @@ -42,31 +42,8 @@ func run() int { fs := afero.NewOsFs() - if len(os.Args) > 1 && os.Args[1] == "generate-certs" { - cmd := commands.NewGenerateCertsCommand( - commands.WithFs(fs), - commands.WithOutput(output), - ) - flags := pflag.NewFlagSet("generate-certs", pflag.ExitOnError) - cmd.DefineFlags(flags) - - err := flags.Parse(os.Args[2:]) - if err != nil { - output.Error(err) - log.Printf("Error: %v", err) - - return 1 - } - - err = cmd.Execute() - if err != nil { - output.Error(err) - log.Printf("Error: %v", err) - - return 1 - } - - return 0 + if len(os.Args) > 1 && os.Args[1] == generateCertsCmd { + return runGenerateCerts(fs, output) } pflag.Usage = func() { @@ -75,82 +52,164 @@ func run() int { pflag.PrintDefaults() } - viperInstance := viper.GetViper() + uncorsConfig, configPath := loadConfiguration(fs) - uncorsConfig := loadConfiguration(viperInstance, fs) + if uncorsConfig.Interactive { + return runInteractive(fs, configPath, uncorsConfig) + } - ctx := context.Background() + return runNonInteractive(context.Background(), fs, output, configPath, uncorsConfig) +} - if !uncorsConfig.Interactive { - tracker := server.NewRequestTracker() - app := uncors.CreateUncors(fs, output, Version).WithTracker(tracker) +// runGenerateCerts executes the generate-certs sub-command and returns an exit code. +func runGenerateCerts(fs afero.Fs, output *tui.CliOutput) int { + cmd := commands.NewGenerateCertsCommand( + commands.WithFs(fs), + commands.WithOutput(output), + ) - go server.RequestPrinter(tracker, output) + flags := pflag.NewFlagSet(generateCertsCmd, pflag.ContinueOnError) + cmd.DefineFlags(flags) - viperInstance.OnConfigChange(func(_ fsnotify.Event) { - defer helpers.PanicInterceptor(func(value any) { - log.Printf("Config reloading error: %v", value) - output.Errorf("Config reloading error: %v", value) - }) + err := flags.Parse(os.Args[2:]) + if err != nil { + output.Error(err) + log.Printf("Error: %v", err) - err := app.Restart(ctx, loadConfiguration(viperInstance, fs)) - if err != nil { - log.Printf("Failed to restart server: %v", err) - output.Errorf("Failed to restart server: %v", err) - } - }) - viperInstance.WatchConfig() + return 1 + } - err := app.Start(ctx, uncorsConfig) - if err != nil { - panic(err) - } + err = cmd.Execute() + if err != nil { + output.Error(err) + log.Printf("Error: %v", err) - go func() { - const checkDelay = 50 * time.Second + return 1 + } - versionChecker := version.NewVersionChecker( - version.WithOutput(output), - version.WithHTTPClient(infra.MakeHTTPClient(uncorsConfig.Proxy)), - version.WithCurrentVersion(Version), - ) + return 0 +} - time.Sleep(checkDelay) - versionChecker.CheckNewVersion(ctx) - }() +// runNonInteractive starts the proxy in non-interactive (headless) mode and +// blocks until the server shuts down. The config file is watched for changes +// when configPath is non-empty. +func runNonInteractive( + ctx context.Context, + fs afero.Fs, + output *tui.CliOutput, + configPath string, + cfg *config.UncorsConfig, +) int { + tracker := server.NewRequestTracker() + app := uncors.CreateUncors(fs, output, Version).WithTracker(tracker) + + go server.RequestPrinter(tracker, output) + + if configPath != "" { + startConfigWatcher(ctx, fs, output, configPath, app) + } - go helpers.GracefulShutdown(ctx, func(shutdownCtx context.Context) error { - log.Println("shutdown signal received") + err := app.Start(ctx, cfg) + if err != nil { + panic(err) + } - return app.Shutdown(shutdownCtx) - }) + go startVersionChecker(ctx, output, cfg.Proxy) - app.Wait() - output.Info("Server was stopped") - } else { - app := uncorsapp.NewUncorsApp( - Version, - fs, - viperInstance, - uncorsConfig, - func() *config.UncorsConfig { return loadConfiguration(viperInstance, fs) }, - ) + go helpers.GracefulShutdown(ctx, func(shutdownCtx context.Context) error { + log.Println("shutdown signal received") + + return app.Shutdown(shutdownCtx) + }) - p := tea.NewProgram(app) + app.Wait() + output.Info("Server was stopped") - _, err := p.Run() - if err != nil { - log.Fatal(err) + return 0 +} + +// startConfigWatcher begins watching the config file and restarts the proxy on +// every change. The watcher lives for the process lifetime (not closed explicitly). +func startConfigWatcher( + ctx context.Context, + fs afero.Fs, + output *tui.CliOutput, + configPath string, + app *uncors.Uncors, +) { + watcher, err := config.NewWatcher(configPath, func() { + defer helpers.PanicInterceptor(func(value any) { + log.Printf("Config reloading error: %v", value) + output.Errorf("Config reloading error: %v", value) + }) + + reloaded, _ := loadConfiguration(fs) + + restartErr := app.Restart(ctx, reloaded) + if restartErr != nil { + log.Printf("Failed to restart server: %v", restartErr) + output.Errorf("Failed to restart server: %v", restartErr) } + }) + if err != nil { + log.Printf("Failed to start config watcher: %v", err) + output.Errorf("Failed to start config watcher: %v", err) + + return + } + + // The watcher goroutine owns its lifetime; it is intentionally not closed + // here because the proxy server (app.Wait) blocks the caller for the same + // duration. The OS reclaims resources when the process exits. + _ = watcher +} + +// startVersionChecker waits for a short delay then checks for a newer release. +func startVersionChecker(ctx context.Context, output *tui.CliOutput, proxy string) { + const checkDelay = 50 * time.Second + + versionChecker := version.NewVersionChecker( + version.WithOutput(output), + version.WithHTTPClient(infra.MakeHTTPClient(proxy)), + version.WithCurrentVersion(Version), + ) + + time.Sleep(checkDelay) + versionChecker.CheckNewVersion(ctx) +} + +// runInteractive starts the proxy in interactive TUI mode. +func runInteractive(fs afero.Fs, configPath string, cfg *config.UncorsConfig) int { + app := uncorsapp.NewUncorsApp( + Version, + fs, + configPath, + cfg, + func() *config.UncorsConfig { + reloaded, _ := loadConfiguration(fs) + + return reloaded + }, + ) + + _, err := tea.NewProgram(app).Run() + if err != nil { + log.Fatal(err) } return 0 } -func loadConfiguration(viperInstance *viper.Viper, fs afero.Fs) *config.UncorsConfig { - uncorsConfig := config.LoadConfiguration(viperInstance, os.Args) +// loadConfiguration loads and validates the configuration from CLI args and the +// config file. It panics on any error so that the PanicInterceptor in run() can +// display a human-readable message and exit cleanly. +func loadConfiguration(fs afero.Fs) (*config.UncorsConfig, string) { + uncorsConfig, configPath, err := config.LoadConfiguration(fs, os.Args) + if err != nil { + panic(err) + } - err := validators.ValidateConfig(uncorsConfig, fs) + err = validators.ValidateConfig(uncorsConfig, fs) if err != nil { panic(err) } @@ -167,5 +226,5 @@ func loadConfiguration(viperInstance *viper.Viper, fs afero.Fs) *config.UncorsCo log.SetOutput(io.Discard) } - return uncorsConfig + return uncorsConfig, configPath } diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..20ccebf0 --- /dev/null +++ b/main_test.go @@ -0,0 +1,140 @@ +package main + +import ( + "context" + "io" + "os" + "path/filepath" + "testing" + + "github.com/evg4b/uncors/internal/tui" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setArgs temporarily overrides os.Args and returns a restore function. +func setArgs(args []string) func() { + old := os.Args + os.Args = args + + return func() { os.Args = old } +} + +func newTestOutput() *tui.CliOutput { + return tui.NewCliOutput(io.Discard) +} + +func TestLoadConfiguration(t *testing.T) { + t.Run("returns config for valid flags", func(t *testing.T) { + defer setArgs([]string{"uncors", "-f", "http://localhost:3000", "-t", "https://api.example.com"})() + + cfg, path := loadConfiguration(afero.NewMemMapFs()) + + require.NotNil(t, cfg) + assert.Empty(t, path) + assert.Len(t, cfg.Mappings, 1) + }) + + t.Run("panics when mappings are empty", func(t *testing.T) { + defer setArgs([]string{"uncors"})() + + assert.Panics(t, func() { + loadConfiguration(afero.NewMemMapFs()) + }) + }) + + t.Run("panics on invalid flags", func(t *testing.T) { + defer setArgs([]string{"uncors", "--no-such-flag"})() + + assert.Panics(t, func() { + loadConfiguration(afero.NewMemMapFs()) + }) + }) +} + +func TestRunGenerateCerts(t *testing.T) { + t.Run("generates certs and returns 0", func(t *testing.T) { + defer setArgs([]string{"uncors", generateCertsCmd})() + + fs := afero.NewMemMapFs() + output := newTestOutput() + + result := runGenerateCerts(fs, output) + + assert.Equal(t, 0, result) + }) + + t.Run("returns 1 when execute fails", func(t *testing.T) { + defer setArgs([]string{"uncors", generateCertsCmd})() + + // Second call on the same fs finds certs already exist → ErrCAAlreadyExists. + fs := afero.NewMemMapFs() + output := newTestOutput() + + _ = runGenerateCerts(fs, output) + result := runGenerateCerts(fs, output) + + assert.Equal(t, 1, result) + }) + + t.Run("returns 1 when flags parse fails", func(t *testing.T) { + defer setArgs([]string{"uncors", generateCertsCmd, "--no-such-flag"})() + + result := runGenerateCerts(afero.NewMemMapFs(), newTestOutput()) + + assert.Equal(t, 1, result) + }) +} + +func TestLoadConfigurationWithDebug(t *testing.T) { + t.Chdir(t.TempDir()) + + defer setArgs([]string{"uncors", "-f", "http://localhost:3000", "-t", "https://api.example.com", "--debug"})() + + cfg, _ := loadConfiguration(afero.NewMemMapFs()) + + require.NotNil(t, cfg) + assert.True(t, cfg.Debug) +} + +func TestLoadConfigurationWithConfigFile(t *testing.T) { + const cfgContent = ` +mappings: + - from: http://localhost:3000 + to: https://api.example.com +` + + defer setArgs([]string{"uncors", "--config", "/config.yaml"})() + + fs := afero.NewMemMapFs() + require.NoError(t, afero.WriteFile(fs, "/config.yaml", []byte(cfgContent), 0o600)) + + cfg, path := loadConfiguration(fs) + + require.NotNil(t, cfg) + assert.Equal(t, "/config.yaml", path) + assert.Len(t, cfg.Mappings, 1) +} + +func TestStartConfigWatcher(t *testing.T) { + t.Run("logs error for non-existent config path", func(t *testing.T) { + output := newTestOutput() + + assert.NotPanics(t, func() { + startConfigWatcher(context.Background(), afero.NewMemMapFs(), output, "/no/such/config.yaml", nil) + }) + }) + + t.Run("creates watcher for existing config file", func(t *testing.T) { + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + require.NoError(t, os.WriteFile(configFile, []byte("proxy: \"\""), 0o600)) + + output := newTestOutput() + + assert.NotPanics(t, func() { + startConfigWatcher(context.Background(), afero.NewMemMapFs(), output, configFile, nil) + }) + }) +}