diff --git a/cli.go b/cli.go index 6bce342..9f17232 100644 --- a/cli.go +++ b/cli.go @@ -16,6 +16,9 @@ Cli represents the structure of a CLI app. It should be constructed using the Ap type Cli struct { *Cmd version *cliVersion + exiter func(code int) // REVIEW: might be desirable to have other options than just callback; the most common use I can imagine, other than `os.Exit`, is to simply capture the value again, for which a callback is both overkill and high friction. + stdOut io.Writer // REVIEW: I brought this along for the ride, because it was there at package scope before, but... it's never used, afaict? + stdErr io.Writer } type cliVersion struct { @@ -34,15 +37,20 @@ name and description will be used to construct the help message for the app: */ func App(name, desc string) *Cli { - return &Cli{ - Cmd: &Cmd{ - name: name, - desc: desc, - optionsIdx: map[string]*container.Container{}, - argsIdx: map[string]*container.Container{}, - ErrorHandling: flag.ExitOnError, - }, + cli := &Cli{ + exiter: func(code int) { os.Exit(code) }, + stdOut: os.Stdout, + stdErr: os.Stderr, } + cli.Cmd = &Cmd{ + cli: cli, + name: name, + desc: desc, + optionsIdx: map[string]*container.Container{}, + argsIdx: map[string]*container.Container{}, + ErrorHandling: flag.ExitOnError, + } + return cli } /* @@ -65,6 +73,44 @@ func (cli *Cli) Version(name, version string) { cli.version = &cliVersion{version, option} } +/* +SetStdout sets the CLI's concept of what is "standard out". + +If SetStdout is not called, the default behavior is to use os.Stdout. + +This is currently unused. +*/ +func (cli *Cli) SetStdout(wr io.Writer) { + cli.stdOut = wr +} + +/* +SetStderr sets the CLI's concept of what is "standard error". + +If SetStderr is not called, the default behavior is to use os.Stderr. + +Information about parse errors is written to this stream, +as well as usage info or version info, if those are requested. +*/ +func (cli *Cli) SetStderr(wr io.Writer) { + cli.stdErr = wr +} + +/* +SetExiter sets a callback to define what happens when an exit should happen with an exit code. + +If SetExiter is not called, the default behavior is to call os.Exit +(which immediately halts the program). + +Common uses of setting a custom exit function include gathering the code instead of halting the program +(which is often useful for writing tests of CLI behavior, for example). + +SetExiter should not be used for cleanup hooks; use a Cmd.After callback for that. +*/ +func (cli *Cli) SetExiter(exiter func(code int)) { + cli.exiter = exiter +} + func (cli *Cli) parse(args []string, entry, inFlow, outFlow *flow.Step) error { // We overload Cmd.parse() and handle cases that only apply to the CLI command, like versioning // After that, we just call Cmd.parse() for the default behavior @@ -86,7 +132,7 @@ In most cases the library users won't need to call this method, unless a more complex validation is needed. */ func (cli *Cli) PrintVersion() { - fmt.Fprintln(stdErr, cli.version.version) + fmt.Fprintln(cli.stdErr, cli.version.version) } /* @@ -96,12 +142,12 @@ and to execute the matching command. In case of an incorrect usage, and depending on the configured ErrorHandling policy, it may return an error, panic or exit */ -func (cli *Cli) Run(args []string) error { +func (cli *Cli) Run(args []string) error { // REVIEW: I would actually prefer this returned `(error, int)`... but, that's a breaking change. We could also: use special error types for code; or, introduce a new function; or, do nothing, and require users to write a capturing thunk for an `exiter`. if err := cli.doInit(); err != nil { panic(err) } - inFlow := &flow.Step{Desc: "RootIn", Exiter: exiter} - outFlow := &flow.Step{Desc: "RootOut", Exiter: exiter} + inFlow := &flow.Step{Desc: "RootIn", Exiter: cli.exiter} + outFlow := &flow.Step{Desc: "RootOut", Exiter: cli.exiter} return cli.parse(args[1:], inFlow, inFlow, outFlow) } @@ -119,16 +165,10 @@ func ActionCommand(action func()) CmdInitializer { /* Exit causes the app the exit with the specified exit code while giving the After interceptors a chance to run. This should be used instead of os.Exit. + +This function is implemented using a panic; nothing will occur after it is called +(except other deferred functions, and the After intercepters). */ func Exit(code int) { panic(flow.ExitCode(code)) } - -var exiter = func(code int) { - os.Exit(code) -} - -var ( - stdOut io.Writer = os.Stdout - stdErr io.Writer = os.Stderr -) diff --git a/cli_test.go b/cli_test.go index 8e6ec49..0dce404 100644 --- a/cli_test.go +++ b/cli_test.go @@ -770,16 +770,10 @@ func TestHelpCommandSkipsValidation(t *testing.T) { app.BoolOpt("opt3", false, "opt3 desc") app.Command("command", "command desc", func(cmd *Cmd) {}) - var out, stdErr string - defer captureAndRestoreOutput(&out, &stdErr)() - - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() - - require.NoError(t, - // calling help on a command should skip validating the parents required arguments - app.Run(args)) + exitCode, _, stdErr, err := bufferizedRun(app, args) + require.NoError(t, err) // calling help on a command should skip validating the parents required arguments + require.Equal(t, exitCode, 0, "exit should have been called") require.Equal(t, ` Usage: app command @@ -811,16 +805,10 @@ command desc cmd.Command("child", "child desc", func(cmd *Cmd) {}) }) - var out, stdErr string - defer captureAndRestoreOutput(&out, &stdErr)() - - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() - - require.NoError(t, - // calling help on a command should skip validating the parents required arguments - app.Run(args)) + exitCode, _, stdErr, err := bufferizedRun(app, args) + require.NoError(t, err) // calling help on a command should skip validating the parents required arguments + require.Equal(t, exitCode, 0, "exit should have been called") require.Equal(t, ` Usage: app command child @@ -835,10 +823,6 @@ func TestHelpAndVersionWithOptionsEnd(t *testing.T) { for _, opt := range []string{"-h", "--help", "-v", "--version"} { t.Run(opt, func(t *testing.T) { t.Logf("Testing help/version with --: opt=%q", opt) - defer suppressOutput()() - - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() app := App("x", "") app.Version("v version", "1.0") @@ -852,11 +836,12 @@ func TestHelpAndVersionWithOptionsEnd(t *testing.T) { require.Equal(t, opt, *cmd) } - require.NoError(t, - app.Run([]string{"x", "--", opt})) + exitCode, _, _, err := bufferizedRun(app, + []string{"x", "--", opt}) + require.NoError(t, err) require.True(t, actionCalled, "action should have been called") - require.False(t, exitCalled, "exit should not have been called") + require.Equal(t, exitCode, -1, "exit should not have been called") }) } } @@ -884,13 +869,8 @@ func TestHelpMessage(t *testing.T) { cas := cas t.Run(cas.name, func(t *testing.T) { t.Logf("case: %+v", cas) - var out, stdErr string - defer captureAndRestoreOutput(&out, &stdErr)() defer setAndRestoreEnv(cas.env)() - exitCalled := false - defer exitShouldBeCalledWith(t, cas.exitCode, &exitCalled)() - app := App("app", "App Desc") app.Spec = "[-bdsuikqs] [BOOL1 STR1 INT3...]" @@ -952,8 +932,8 @@ func TestHelpMessage(t *testing.T) { }) t.Logf("calling app with %+v", cas.params) - require.NoError(t, - app.Run(cas.params)) + exitCode, _, stdErr, err := bufferizedRun(app, cas.params) + require.NoError(t, err) filename := fmt.Sprintf("testdata/help-output-%s.txt", cas.name) @@ -965,18 +945,13 @@ func TestHelpMessage(t *testing.T) { expected, e := ioutil.ReadFile(filename) require.NoError(t, e, "Failed to read the expected help output from %s", filename) + require.Equal(t, exitCode, cas.exitCode, "exit should have been called with the expected code") require.Equal(t, string(expected), stdErr) }) } } func TestLongHelpMessage(t *testing.T) { - var out, err string - defer captureAndRestoreOutput(&out, &err)() - - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() - app := App("app", "App Desc") app.LongDesc = "Longer App Desc" app.Spec = "[-o] ARG" @@ -985,27 +960,23 @@ func TestLongHelpMessage(t *testing.T) { app.String(StringArg{Name: "ARG", Value: "", Desc: "Argument"}) app.Action = func() {} - require.NoError(t, - app.Run([]string{"app", "-h"})) + + exitCode, _, stdErr, err := bufferizedRun(app, + []string{"app", "-h"}) + require.NoError(t, err) if *genGolden { require.NoError(t, - ioutil.WriteFile("testdata/long-help-output.txt.golden", []byte(err), 0644)) + ioutil.WriteFile("testdata/long-help-output.txt.golden", []byte(stdErr), 0644)) } expected, e := ioutil.ReadFile("testdata/long-help-output.txt") require.NoError(t, e, "Failed to read the expected help output from testdata/long-help-output.txt") - - require.Equal(t, expected, []byte(err)) + require.Equal(t, exitCode, 0, "exit should have been called") + require.Equal(t, expected, []byte(stdErr)) } func TestMultiLineDescInHelpMessage(t *testing.T) { - var out, err string - defer captureAndRestoreOutput(&out, &err)() - - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() - app := App("app", "App Desc") app.LongDesc = "Longer App Desc" app.Spec = "[-o] ARG" @@ -1015,25 +986,23 @@ func TestMultiLineDescInHelpMessage(t *testing.T) { app.String(StringArg{Name: "ARG", Value: "", Desc: "Argument\nDescription\nMultiple\nLines"}) app.Action = func() {} - require.NoError(t, - app.Run([]string{"app", "-h"})) + + exitCode, _, stdErr, err := bufferizedRun(app, + []string{"app", "-h"}) + require.NoError(t, err) if *genGolden { require.NoError(t, - ioutil.WriteFile("testdata/multi-line-desc-help-output.txt.golden", []byte(err), 0644)) + ioutil.WriteFile("testdata/multi-line-desc-help-output.txt.golden", []byte(stdErr), 0644)) } expected, e := ioutil.ReadFile("testdata/multi-line-desc-help-output.txt") require.NoError(t, e, "Failed to read the expected help output from testdata/long-help-output.txt") - - require.Equal(t, expected, []byte(err)) + require.Equal(t, exitCode, 0, "exit should have been called") + require.Equal(t, expected, []byte(stdErr)) } func TestVersionShortcut(t *testing.T) { - defer suppressOutput()() - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() - app := App("cp", "") app.Version("v version", "cp 1.2.3") @@ -1042,11 +1011,12 @@ func TestVersionShortcut(t *testing.T) { actionCalled = true } - require.NoError(t, - app.Run([]string{"cp", "--version"})) + exitCode, _, _, err := bufferizedRun(app, + []string{"cp", "--version"}) + require.NoError(t, err) require.False(t, actionCalled, "action should not have been called") - require.True(t, exitCalled, "exit should have been called") + require.Equal(t, exitCode, 0, "exit should have been called") } func TestSubCommands(t *testing.T) { @@ -1073,9 +1043,6 @@ func TestSubCommands(t *testing.T) { } func TestContinueOnError(t *testing.T) { - defer exitShouldNotCalled(t)() - defer suppressOutput()() - app := App("say", "") app.String(StringOpt{Name: "f", Value: "", Desc: ""}) app.Spec = "-f" @@ -1085,15 +1052,13 @@ func TestContinueOnError(t *testing.T) { called = true } - err := app.Run([]string{"say"}) + exitCode, _, _, err := bufferizedRun(app, []string{"say"}) require.NotNil(t, err) require.False(t, called, "Exec should NOT have been called") + require.Equal(t, exitCode, -1, "Exit should not have been called") } func TestContinueOnErrorWithHelpAndVersion(t *testing.T) { - defer exitShouldNotCalled(t)() - defer suppressOutput()() - app := App("say", "") app.Version("v", "1.0") app.String(StringOpt{Name: "f", Value: "", Desc: ""}) @@ -1105,58 +1070,47 @@ func TestContinueOnErrorWithHelpAndVersion(t *testing.T) { } { - err := app.Run([]string{"say", "-h"}) + exitCode, _, _, err := bufferizedRun(app, []string{"say", "-h"}) require.Nil(t, err) require.False(t, called, "Exec should NOT have been called") + require.Equal(t, exitCode, -1, "Exit should not have been called") } { - err := app.Run([]string{"say", "-v"}) + exitCode, _, _, err := bufferizedRun(app, []string{"say", "-v"}) require.Nil(t, err) require.False(t, called, "Exec should NOT have been called") + require.Equal(t, exitCode, -1, "Exit should not have been called") } } func TestExitOnError(t *testing.T) { - defer suppressOutput()() - - exitCalled := false - defer exitShouldBeCalledWith(t, 2, &exitCalled)() - app := App("x", "") app.ErrorHandling = flag.ExitOnError app.Spec = "Y" app.String(StringArg{Name: "Y", Value: "", Desc: ""}) - require.Error(t, - app.Run([]string{"x", "y", "z"})) - require.True(t, exitCalled, "exit should have been called") + exitCode, _, _, err := bufferizedRun(app, + []string{"x", "y", "z"}) + require.Error(t, err) + require.Equal(t, exitCode, 2, "exit should have been called") } func TestExitOnErrorWithHelp(t *testing.T) { - defer suppressOutput()() - - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() - app := App("x", "") app.Spec = "Y" app.ErrorHandling = flag.ExitOnError app.String(StringArg{Name: "Y", Value: "", Desc: ""}) - require.NoError(t, - app.Run([]string{"x", "-h"})) - require.True(t, exitCalled, "exit should have been called") + exitCode, _, _, err := bufferizedRun(app, + []string{"x", "-h"}) + require.NoError(t, err) + require.Equal(t, exitCode, 0, "exit should have been called") } func TestExitOnErrorWithVersion(t *testing.T) { - defer suppressOutput()() - - exitCalled := false - defer exitShouldBeCalledWith(t, 0, &exitCalled)() - app := App("x", "") app.Version("v", "1.0") app.Spec = "Y" @@ -1164,14 +1118,13 @@ func TestExitOnErrorWithVersion(t *testing.T) { app.String(StringArg{Name: "Y", Value: "", Desc: ""}) - require.NoError(t, - app.Run([]string{"x", "-v"})) - require.True(t, exitCalled, "exit should have been called") + exitCode, _, _, err := bufferizedRun(app, + []string{"x", "-v"}) + require.NoError(t, err) + require.Equal(t, exitCode, 0, "exit should have been called") } func TestPanicOnError(t *testing.T) { - defer suppressOutput()() - app := App("say", "") app.String(StringOpt{Name: "f", Value: "", Desc: ""}) app.Spec = "-f" @@ -1186,8 +1139,9 @@ func TestPanicOnError(t *testing.T) { require.False(t, called, "Exec should NOT have been called") } }() - require.NoError(t, - app.Run([]string{"say"})) + _, _, _, err := bufferizedRun(app, + []string{"say"}) + require.NoError(t, err) t.Fatalf("wanted panic") } @@ -2080,8 +2034,6 @@ func TestCommandAction(t *testing.T) { } func TestCommandAliases(t *testing.T) { - defer suppressOutput()() - cases := []struct { args []string errorExpected bool @@ -2117,7 +2069,7 @@ func TestCommandAliases(t *testing.T) { } }) - err := app.Run(cas.args) + _, _, _, err := bufferizedRun(app, cas.args) if cas.errorExpected { require.Error(t, err, "Run() should have returned with an error") @@ -2292,27 +2244,6 @@ func TestBeforeAndAfterFlowOrderWhenMultipleAftersPanic(t *testing.T) { require.Equal(t, 7, counter) } -func exitShouldBeCalledWith(t *testing.T, wantedExitCode int, called *bool) func() { - oldExiter := exiter - exiter = func(code int) { - require.Equal(t, wantedExitCode, code, "unwanted exit code") - *called = true - } - return func() { exiter = oldExiter } -} - -func exitShouldNotCalled(t *testing.T) func() { - oldExiter := exiter - exiter = func(code int) { - t.Errorf("exit should not have been called") - } - return func() { exiter = oldExiter } -} - -func suppressOutput() func() { - return captureAndRestoreOutput(nil, nil) -} - func setAndRestoreEnv(env map[string]string) func() { backup := map[string]string{} for k, v := range env { @@ -2327,45 +2258,14 @@ func setAndRestoreEnv(env map[string]string) func() { } } -func captureAndRestoreOutput(out, err *string) func() { - oldStdOut := stdOut - oldStdErr := stdErr - - if out == nil { - stdOut = ioutil.Discard - } else { - stdOut = trapWriter(out) - } - if err == nil { - stdErr = ioutil.Discard - } else { - stdErr = trapWriter(err) - } - - return func() { - stdOut = oldStdOut - stdErr = oldStdErr - } -} - -func trapWriter(writeTo *string) *writerTrap { - return &writerTrap{ - buffer: bytes.NewBuffer(nil), - writeTo: writeTo, - } -} - -type writerTrap struct { - buffer *bytes.Buffer - writeTo *string -} - -func (w *writerTrap) Write(p []byte) (n int, err error) { - n, err = w.buffer.Write(p) - if err == nil { - *(w.writeTo) = w.buffer.String() - } - return +func bufferizedRun(app *Cli, args []string) (exitCode int, stdout string, stderr string, err error) { + exitCode = -1 + var stdOutBuf, stdErrBuf bytes.Buffer + app.exiter = func(code int) { exitCode = code } + app.stdOut = &stdOutBuf + app.stdErr = &stdErrBuf + err = app.Run(args) + return exitCode, stdOutBuf.String(), stdErrBuf.String(), err } func callChecker(t *testing.T, wanted int, counter *int) func() { diff --git a/commands.go b/commands.go index 99b43ba..059a8ea 100644 --- a/commands.go +++ b/commands.go @@ -35,6 +35,9 @@ type Cmd struct { // The command error handling strategy ErrorHandling flag.ErrorHandling + // The root of the CLI tree, used to look up some wiring (such as stdout, exit strategy, etc) + cli *Cli + init CmdInitializer name string aliases []string @@ -128,7 +131,8 @@ the last argument, init, is a function that will be called by mow.cli to further func (c *Cmd) Command(name, desc string, init CmdInitializer) { aliases := strings.Fields(name) c.commands = append(c.commands, &Cmd{ - ErrorHandling: c.ErrorHandling, + ErrorHandling: c.ErrorHandling, // REVIEW: consider removing this; we can get it as `c.cli.ErrorHandling` now. + cli: c.cli, name: aliases[0], aliases: aliases, desc: desc, @@ -482,14 +486,14 @@ func (c *Cmd) doInit() error { func (c *Cmd) onError(err error) { if err == errHelpRequested || err == errVersionRequested { if c.ErrorHandling == flag.ExitOnError { - exiter(0) + c.cli.exiter(0) } return } switch c.ErrorHandling { case flag.ExitOnError: - exiter(2) + c.cli.exiter(2) case flag.PanicOnError: panic(err) } @@ -517,27 +521,27 @@ func (c *Cmd) PrintLongHelp() { func (c *Cmd) printHelp(longDesc bool) { full := append(c.parents, c.name) path := strings.Join(full, " ") - fmt.Fprintf(stdErr, "\nUsage: %s", path) + fmt.Fprintf(c.cli.stdErr, "\nUsage: %s", path) spec := strings.TrimSpace(c.Spec) if len(spec) > 0 { - fmt.Fprintf(stdErr, " %s", spec) + fmt.Fprintf(c.cli.stdErr, " %s", spec) } if len(c.commands) > 0 { - fmt.Fprint(stdErr, " COMMAND [arg...]") + fmt.Fprint(c.cli.stdErr, " COMMAND [arg...]") } - fmt.Fprint(stdErr, "\n\n") + fmt.Fprint(c.cli.stdErr, "\n\n") desc := c.desc if longDesc && len(c.LongDesc) > 0 { desc = c.LongDesc } if len(desc) > 0 { - fmt.Fprintf(stdErr, "%s\n", desc) + fmt.Fprintf(c.cli.stdErr, "%s\n", desc) } - w := tabwriter.NewWriter(stdErr, 15, 1, 3, ' ', 0) + w := tabwriter.NewWriter(c.cli.stdErr, 15, 1, 3, ' ', 0) if len(c.args) > 0 { fmt.Fprint(w, "\t\nArguments:\t\n") @@ -676,7 +680,7 @@ func (c *Cmd) parse(args []string, entry, inFlow, outFlow *flow.Step) error { } if err := c.fsm.Parse(args[:nargsLen]); err != nil { - fmt.Fprintf(stdErr, "Error: %s\n", err.Error()) + fmt.Fprintf(c.cli.stdErr, "Error: %s\n", err.Error()) c.PrintHelp() c.onError(err) return err @@ -686,7 +690,7 @@ func (c *Cmd) parse(args []string, entry, inFlow, outFlow *flow.Step) error { Do: c.Before, Error: outFlow, Desc: fmt.Sprintf("%s.Before", c.name), - Exiter: exiter, + Exiter: inFlow.Exiter, } inFlow.Success = newInFlow @@ -695,7 +699,7 @@ func (c *Cmd) parse(args []string, entry, inFlow, outFlow *flow.Step) error { Success: outFlow, Error: outFlow, Desc: fmt.Sprintf("%s.After", c.name), - Exiter: exiter, + Exiter: outFlow.Exiter, } args = args[nargsLen:] @@ -706,7 +710,7 @@ func (c *Cmd) parse(args []string, entry, inFlow, outFlow *flow.Step) error { Success: newOutFlow, Error: newOutFlow, Desc: fmt.Sprintf("%s.Action", c.name), - Exiter: exiter, + Exiter: inFlow.Exiter, } entry.Run(nil) @@ -731,10 +735,10 @@ func (c *Cmd) parse(args []string, entry, inFlow, outFlow *flow.Step) error { switch { case strings.HasPrefix(arg, "-"): err = fmt.Errorf("Error: illegal option %s", arg) - fmt.Fprintln(stdErr, err.Error()) + fmt.Fprintln(c.cli.stdErr, err.Error()) default: err = fmt.Errorf("Error: illegal input %s", arg) - fmt.Fprintln(stdErr, err.Error()) + fmt.Fprintln(c.cli.stdErr, err.Error()) } c.PrintHelp() c.onError(err) diff --git a/commands_test.go b/commands_test.go index 0ea8aac..6b2adb6 100644 --- a/commands_test.go +++ b/commands_test.go @@ -2,6 +2,7 @@ package cli import ( "flag" + "io/ioutil" "os" "strings" @@ -17,9 +18,8 @@ import ( ) func okCmd(t *testing.T, spec string, init CmdInitializer, args []string) { - defer suppressOutput()() - cmd := &Cmd{ + cli: &Cli{stdOut: ioutil.Discard, stdErr: ioutil.Discard}, name: "test", optionsIdx: map[string]*container.Container{}, argsIdx: map[string]*container.Container{}, @@ -37,9 +37,8 @@ func okCmd(t *testing.T, spec string, init CmdInitializer, args []string) { } func failCmd(t *testing.T, spec string, init CmdInitializer, args []string) { - defer suppressOutput()() - cmd := &Cmd{ + cli: &Cli{stdOut: ioutil.Discard, stdErr: ioutil.Discard}, name: "test", optionsIdx: map[string]*container.Container{}, argsIdx: map[string]*container.Container{},