From 4b6327f28e3cd7e9aa1840df7b13edbab5b5a999 Mon Sep 17 00:00:00 2001 From: Larry Clapp Date: Fri, 8 Mar 2024 13:20:19 -0500 Subject: [PATCH] Cancellable reads; EOF on "read" sets var to "" - Make reading cancellable. Not all input from stdin, just that done directly by the shell (i.e. the "read" builtin). Exec'ed programs still read directly from stdin's os.File and are not cancellable. - If you press ^D (EOF) when reading into a shell variable, set the variable to "". This is consistent with bash & zsh. --- go.mod | 1 + go.sum | 2 ++ interp/builtin.go | 52 ++++++++++++++++++++++++++++-------- interp/interp_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++ interp/runner.go | 2 +- 5 files changed, 106 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 03824d50..8dd2d312 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-quicktest/qt v1.101.0 github.com/google/go-cmp v0.6.0 github.com/google/renameio/v2 v2.0.0 + github.com/muesli/cancelreader v0.2.2 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e github.com/rogpeppe/go-internal v1.12.0 golang.org/x/sync v0.6.0 diff --git a/go.sum b/go.sum index ff4788f1..4c15ee1a 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= diff --git a/interp/builtin.go b/interp/builtin.go index b24a3d6a..0342a7f0 100644 --- a/interp/builtin.go +++ b/interp/builtin.go @@ -9,12 +9,13 @@ import ( "context" "errors" "fmt" - "io" "os" "path/filepath" "strconv" "strings" + "sync" + "github.com/muesli/cancelreader" "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/syntax" ) @@ -589,10 +590,7 @@ func (r *Runner) builtinCode(ctx context.Context, pos syntax.Pos, name string, a r.out(prompt) } - line, err := r.readLine(raw) - if err != nil { - return 1 - } + line, err := r.readLine(ctx, raw) if len(args) == 0 { args = append(args, shellReplyVar) } @@ -606,6 +604,12 @@ func (r *Runner) builtinCode(ctx context.Context, pos syntax.Pos, name string, a r.setVarString(name, val) } + // We can get data back from readLine and an error at the same time, so + // check err after we process the data. + if err != nil { + return 1 + } + return 0 case "getopts": @@ -917,7 +921,7 @@ func (r *Runner) printOptLine(name string, enabled, supported bool) { r.outf("%s\t%s\t(%q not supported)\n", name, state, r.optStatusText(!enabled)) } -func (r *Runner) readLine(raw bool) ([]byte, error) { +func (r *Runner) readLine(ctx context.Context, raw bool) ([]byte, error) { if r.stdin == nil { return nil, errors.New("interp: can't read, there's no stdin") } @@ -925,9 +929,38 @@ func (r *Runner) readLine(raw bool) ([]byte, error) { var line []byte esc := false + stdin := r.stdin + if osFile, ok := stdin.(*os.File); ok { + cr, err := cancelreader.NewReader(osFile) + if err != nil { + return nil, err + } + stdin = cr + done := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + select { + case <-ctx.Done(): + cr.Cancel() + case <-done: + } + wg.Done() + }() + defer func() { + close(done) + wg.Wait() + // Could put the Close in the above goroutine, but if "read" is + // immediately called again, the Close might overlap with creating a + // new cancelreader. Want this cancelreader to be completely closed + // by the time readLine returns. + cr.Close() + }() + } + for { var buf [1]byte - n, err := r.stdin.Read(buf[:]) + n, err := stdin.Read(buf[:]) if n > 0 { b := buf[0] switch { @@ -945,11 +978,8 @@ func (r *Runner) readLine(raw bool) ([]byte, error) { esc = false } } - if err == io.EOF && len(line) > 0 { - return line, nil - } if err != nil { - return nil, err + return line, err } } } diff --git a/interp/interp_test.go b/interp/interp_test.go index 8a6a528a..f65b721c 100644 --- a/interp/interp_test.go +++ b/interp/interp_test.go @@ -2903,6 +2903,22 @@ done <<< 2`, "read -r -p 'Prompt and raw flag together: ' a <<< '\\a\\b\\c'; echo $a", "Prompt and raw flag together: \\a\\b\\c\n #IGNORE bash requires a terminal", }, + { + `a=a; echo | (read a; echo -n "$a")`, + "", + }, + { + `a=b; read a < /dev/null; echo -n "$a"`, + "", + }, + { + "a=c; echo x | (read a; echo -n $a)", + "x", + }, + { + "a=d; echo -n y | (read a; echo -n $a)", + "y", + }, // getopts { @@ -3943,6 +3959,51 @@ func TestRunnerContext(t *testing.T) { } } +func TestCancelreader(t *testing.T) { + t.Parallel() + + p := syntax.NewParser() + file := parse(t, p, "read x") + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + // Make the linter happy, even though we deliberately wait for the + // timeout. + defer cancel() + + var stdinRead *os.File + if runtime.GOOS == "windows" { + // On Windows, the cancelreader only works on stdin + stdinRead = os.Stdin + } else { + var stdinWrite *os.File + var err error + stdinRead, stdinWrite, err = os.Pipe() + if err != nil { + t.Fatalf("Error calling os.Pipe: %v", err) + } + defer func() { + stdinWrite.Close() + stdinRead.Close() + }() + } + r, _ := interp.New(interp.StdIO(stdinRead, nil, nil)) + now := time.Now() + errChan := make(chan error) + go func() { + errChan <- r.Run(ctx, file) + }() + + timeout := 500 * time.Millisecond + select { + case err := <-errChan: + if err == nil || err.Error() != "exit status 1" || ctx.Err() != context.DeadlineExceeded { + t.Fatalf("'read x' did not timeout correctly; err: %v, ctx.Err(): %v; dur: %v", + err, ctx.Err(), time.Since(now)) + } + case <-time.After(timeout): + t.Fatalf("program was not killed in %s", timeout) + } +} + func TestRunnerAltNodes(t *testing.T) { t.Parallel() diff --git a/interp/runner.go b/interp/runner.go index 79bf0fd4..7df486fd 100644 --- a/interp/runner.go +++ b/interp/runner.go @@ -515,7 +515,7 @@ func (r *Runner) cmd(ctx context.Context, cm syntax.Command) { } r.errf("%s", ps3) - line, err := r.readLine(true) + line, err := r.readLine(ctx, true) if err != nil { r.exit = 1 return nil