diff --git a/cmd/limactl/shell.go b/cmd/limactl/shell.go index 66c56c4af13..69622b0cc72 100644 --- a/cmd/limactl/shell.go +++ b/cmd/limactl/shell.go @@ -64,6 +64,8 @@ func newShellCommand() *cobra.Command { shellCmd.Flags().Bool("reconnect", false, "Reconnect to the SSH session") shellCmd.Flags().Bool("preserve-env", false, "Propagate environment variables to the shell") shellCmd.Flags().Bool("start", false, "Start the instance if it is not already running") + shellCmd.Flags().StringSlice("allow-env", []string{}, "Comma-separated list of environment variable patterns to allow when --preserve-env is set (overrides LIMA_SHELLENV_ALLOW)") + shellCmd.Flags().StringSlice("block-env", []string{}, "Comma-separated list of environment variable patterns to allow when --preserve-env is set (overrides LIMA_SHELLENV_BLOCK)") return shellCmd } @@ -216,8 +218,16 @@ func shellAction(cmd *cobra.Command, args []string) error { if err != nil { return err } + allowListRaw, err := cmd.Flags().GetStringSlice("allow-env") + if err != nil { + return err + } + blockListRaw, err := cmd.Flags().GetStringSlice("block-env") + if err != nil { + return err + } if preserveEnv { - filteredEnv := envutil.FilterEnvironment() + filteredEnv := envutil.FilterEnvironment(allowListRaw, blockListRaw) if len(filteredEnv) > 0 { envPrefix = "env " for _, envVar := range filteredEnv { diff --git a/pkg/envutil/envutil.go b/pkg/envutil/envutil.go index d5147d5a8c7..f9284cab17c 100644 --- a/pkg/envutil/envutil.go +++ b/pkg/envutil/envutil.go @@ -56,15 +56,20 @@ func validatePattern(pattern string) error { } // getBlockList returns the list of environment variable patterns to be blocked. -func getBlockList() []string { - blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK") - if blockEnv == "" { - return defaultBlockList +func getBlockList(blockListRaw []string) []string { + var shouldAppend bool + patterns := blockListRaw + if len(patterns) == 0 { + blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK") + if blockEnv == "" { + return defaultBlockList + } + shouldAppend = strings.HasPrefix(blockEnv, "+") + patterns = parseEnvList(strings.TrimPrefix(blockEnv, "+")) + } else { + shouldAppend = strings.HasPrefix(patterns[0], "+") } - shouldAppend := strings.HasPrefix(blockEnv, "+") - patterns := parseEnvList(strings.TrimPrefix(blockEnv, "+")) - for _, pattern := range patterns { if err := validatePattern(pattern); err != nil { logrus.Fatalf("Invalid LIMA_SHELLENV_BLOCK pattern: %v", err) @@ -78,14 +83,16 @@ func getBlockList() []string { } // getAllowList returns the list of environment variable patterns to be allowed. -func getAllowList() []string { - allowEnv := os.Getenv("LIMA_SHELLENV_ALLOW") - if allowEnv == "" { - return nil +func getAllowList(allowListRaw []string) []string { + patterns := allowListRaw + if len(patterns) == 0 { + allowEnv := os.Getenv("LIMA_SHELLENV_ALLOW") + if allowEnv == "" { + return nil + } + patterns = parseEnvList(allowEnv) } - patterns := parseEnvList(allowEnv) - for _, pattern := range patterns { if err := validatePattern(pattern); err != nil { logrus.Fatalf("Invalid LIMA_SHELLENV_ALLOW pattern: %v", err) @@ -131,11 +138,11 @@ func matchesAnyPattern(name string, patterns []string) bool { // FilterEnvironment filters environment variables based on configuration from environment variables. // It returns a slice of environment variables that are not blocked by the current configuration. // The filtering is controlled by LIMA_SHELLENV_BLOCK and LIMA_SHELLENV_ALLOW environment variables. -func FilterEnvironment() []string { +func FilterEnvironment(allowListRaw, blockListRaw []string) []string { return filterEnvironmentWithLists( os.Environ(), - getAllowList(), - getBlockList(), + getAllowList(allowListRaw), + getBlockList(blockListRaw), ) } diff --git a/pkg/envutil/envutil_test.go b/pkg/envutil/envutil_test.go index 6904580fa7d..9111d6a3552 100644 --- a/pkg/envutil/envutil_test.go +++ b/pkg/envutil/envutil_test.go @@ -88,8 +88,8 @@ func TestGetBlockAndAllowLists(t *testing.T) { t.Setenv("LIMA_SHELLENV_BLOCK", "") t.Setenv("LIMA_SHELLENV_ALLOW", "") - blockList := getBlockList() - allowList := getAllowList() + blockList := getBlockList([]string{}) + allowList := getAllowList([]string{}) assert.Assert(t, isUsingDefaultBlockList()) assert.DeepEqual(t, blockList, defaultBlockList) @@ -99,7 +99,7 @@ func TestGetBlockAndAllowLists(t *testing.T) { t.Run("custom blocklist", func(t *testing.T) { t.Setenv("LIMA_SHELLENV_BLOCK", "PATH,HOME") - blockList := getBlockList() + blockList := getBlockList([]string{}) assert.Assert(t, !isUsingDefaultBlockList()) expected := []string{"PATH", "HOME"} assert.DeepEqual(t, blockList, expected) @@ -108,7 +108,7 @@ func TestGetBlockAndAllowLists(t *testing.T) { t.Run("additive blocklist", func(t *testing.T) { t.Setenv("LIMA_SHELLENV_BLOCK", "+CUSTOM_VAR") - blockList := getBlockList() + blockList := getBlockList([]string{}) assert.Assert(t, isUsingDefaultBlockList()) expected := slices.Concat(GetDefaultBlockList(), []string{"CUSTOM_VAR"}) assert.DeepEqual(t, blockList, expected) @@ -117,7 +117,7 @@ func TestGetBlockAndAllowLists(t *testing.T) { t.Run("allowlist", func(t *testing.T) { t.Setenv("LIMA_SHELLENV_ALLOW", "FOO,BAR") - allowList := getAllowList() + allowList := getAllowList([]string{}) expected := []string{"FOO", "BAR"} assert.DeepEqual(t, allowList, expected) })