Skip to content

Commit 0d3b6ea

Browse files
committed
enhance capture functions and add comprehensive tests
1 parent ea84ecb commit 0d3b6ea

File tree

2 files changed

+194
-79
lines changed

2 files changed

+194
-79
lines changed

capture.go

+74-57
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,69 @@ import (
88
"testing"
99
)
1010

11-
// CaptureStdout captures the output of a function that writes to stdout.
12-
// All Capture functions are not thread-safe if used in parallel tests.
13-
// Usually it is better to pass a custom io.Writer to the function under test instead.
11+
// CaptureStdout captures os.Stdout output from the provided function.
1412
func CaptureStdout(t *testing.T, f func()) string {
1513
t.Helper()
16-
return capture(t, os.Stdout, f)
14+
old := os.Stdout
15+
r, w, err := os.Pipe()
16+
if err != nil {
17+
t.Fatal(err)
18+
}
19+
20+
os.Stdout = w
21+
defer func() { os.Stdout = old }()
22+
23+
var buf bytes.Buffer
24+
var wg sync.WaitGroup
25+
wg.Add(1)
26+
27+
go func() {
28+
defer wg.Done()
29+
_, err := io.Copy(&buf, r)
30+
if err != nil {
31+
t.Errorf("failed to read captured stdout: %v", err)
32+
}
33+
}()
34+
35+
f()
36+
_ = w.Close()
37+
wg.Wait()
38+
return buf.String()
1739
}
1840

19-
// CaptureStderr captures the output of a function that writes to stderr.
41+
// CaptureStderr captures os.Stderr output from the provided function.
2042
func CaptureStderr(t *testing.T, f func()) string {
2143
t.Helper()
22-
return capture(t, os.Stderr, f)
44+
old := os.Stderr
45+
r, w, err := os.Pipe()
46+
if err != nil {
47+
t.Fatal(err)
48+
}
49+
50+
os.Stderr = w
51+
defer func() { os.Stderr = old }()
52+
53+
var buf bytes.Buffer
54+
var wg sync.WaitGroup
55+
wg.Add(1)
56+
57+
go func() {
58+
defer wg.Done()
59+
_, err := io.Copy(&buf, r)
60+
if err != nil {
61+
t.Errorf("failed to read captured stderr: %v", err)
62+
}
63+
}()
64+
65+
f()
66+
_ = w.Close()
67+
wg.Wait()
68+
return buf.String()
2369
}
2470

25-
// CaptureStdoutAndStderr captures the output of a function that writes to
26-
// stdout and stderr.
27-
func CaptureStdoutAndStderr(t *testing.T, f func()) (o, e string) {
71+
// CaptureStdoutAndStderr captures os.Stdout and os.Stderr output from the provided function.
72+
func CaptureStdoutAndStderr(t *testing.T, f func()) (stdout, stderr string) {
2873
t.Helper()
29-
3074
oldout, olderr := os.Stdout, os.Stderr
3175
rOut, wOut, err := os.Pipe()
3276
if err != nil {
@@ -36,65 +80,38 @@ func CaptureStdoutAndStderr(t *testing.T, f func()) (o, e string) {
3680
if err != nil {
3781
t.Fatal(err)
3882
}
39-
os.Stdout, os.Stderr = wOut, wErr
83+
84+
os.Stdout = wOut
85+
os.Stderr = wErr
4086
defer func() {
41-
os.Stdout, os.Stderr = oldout, olderr
87+
os.Stdout = oldout
88+
os.Stderr = olderr
4289
}()
43-
outCh, errCh := make(chan string), make(chan string)
4490

91+
var outBuf, errBuf bytes.Buffer
4592
var wg sync.WaitGroup
4693
wg.Add(2)
4794

48-
go func() { //nolint
49-
var buf bytes.Buffer
50-
wg.Done()
51-
if _, err := io.Copy(&buf, rOut); err != nil {
52-
t.Fatal(err) //nolint
95+
go func() {
96+
defer wg.Done()
97+
_, err := io.Copy(&outBuf, rOut)
98+
if err != nil {
99+
t.Errorf("failed to read captured stdout: %v", err)
53100
}
54-
outCh <- buf.String()
55101
}()
56102

57-
go func() { //nolint
58-
var buf bytes.Buffer
59-
wg.Done()
60-
if _, err := io.Copy(&buf, rErr); err != nil {
61-
t.Fatal(err) //nolint
103+
go func() {
104+
defer wg.Done()
105+
_, err := io.Copy(&errBuf, rErr)
106+
if err != nil {
107+
t.Errorf("failed to read captured stderr: %v", err)
62108
}
63-
errCh <- buf.String()
64109
}()
65110

66-
wg.Wait()
67111
f()
112+
_ = wOut.Close()
113+
_ = wErr.Close()
114+
wg.Wait()
68115

69-
if err := wOut.Close(); err != nil {
70-
t.Fatal(err)
71-
}
72-
if err := wErr.Close(); err != nil {
73-
t.Fatal(err)
74-
}
75-
76-
stdout, stderr := <-outCh, <-errCh
77-
return stdout, stderr
78-
}
79-
80-
func capture(t *testing.T, out *os.File, f func()) string {
81-
old := out
82-
r, w, err := os.Pipe()
83-
if err != nil {
84-
t.Fatal(err)
85-
}
86-
*out = *w
87-
defer func() { *out = *old }()
88-
89-
f()
90-
91-
_ = w.Close()
92-
93-
var buf bytes.Buffer
94-
_, err = io.Copy(&buf, r)
95-
if err != nil {
96-
t.Fatal(err)
97-
}
98-
99-
return buf.String()
116+
return outBuf.String(), errBuf.String()
100117
}

capture_test.go

+120-22
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,138 @@ package testutils
33
import (
44
"fmt"
55
"os"
6+
"strings"
67
"testing"
78
)
89

910
func TestCaptureStdout(t *testing.T) {
10-
want := "hello world\n"
11-
got := CaptureStdout(t, func() {
12-
fmt.Fprintf(os.Stdout, want)
13-
})
14-
if want != got {
15-
t.Errorf("want %q, got %q", want, got)
11+
tests := []struct {
12+
name string
13+
want string
14+
f func()
15+
}{
16+
{
17+
name: "simple output",
18+
want: "hello world\n",
19+
f: func() {
20+
fmt.Fprintf(os.Stdout, "hello world\n")
21+
},
22+
},
23+
{
24+
name: "multiple lines",
25+
want: "line1\nline2\n",
26+
f: func() {
27+
fmt.Fprintln(os.Stdout, "line1")
28+
fmt.Fprintln(os.Stdout, "line2")
29+
},
30+
},
31+
{
32+
name: "empty output",
33+
want: "",
34+
f: func() {},
35+
},
36+
}
37+
38+
for _, tt := range tests {
39+
t.Run(tt.name, func(t *testing.T) {
40+
got := CaptureStdout(t, tt.f)
41+
if got != tt.want {
42+
t.Errorf("CaptureStdout() = %q, want %q", got, tt.want)
43+
}
44+
})
1645
}
1746
}
1847

1948
func TestCaptureStderr(t *testing.T) {
20-
want := "hello world\n"
21-
got := CaptureStderr(t, func() {
22-
fmt.Fprintf(os.Stderr, want)
23-
})
24-
if want != got {
25-
t.Errorf("want %q, got %q", want, got)
49+
tests := []struct {
50+
name string
51+
want string
52+
f func()
53+
}{
54+
{
55+
name: "simple output",
56+
want: "hello world\n",
57+
f: func() {
58+
fmt.Fprintf(os.Stderr, "hello world\n")
59+
},
60+
},
61+
{
62+
name: "multiple lines",
63+
want: "line1\nline2\n",
64+
f: func() {
65+
fmt.Fprintln(os.Stderr, "line1")
66+
fmt.Fprintln(os.Stderr, "line2")
67+
},
68+
},
69+
{
70+
name: "empty output",
71+
want: "",
72+
f: func() {},
73+
},
74+
}
75+
76+
for _, tt := range tests {
77+
t.Run(tt.name, func(t *testing.T) {
78+
got := CaptureStderr(t, tt.f)
79+
if got != tt.want {
80+
t.Errorf("CaptureStderr() = %q, want %q", got, tt.want)
81+
}
82+
})
2683
}
2784
}
2885

2986
func TestCaptureStdoutAndStderr(t *testing.T) {
30-
wantOut := "hello world\n"
31-
wantErr := "hello world\n"
32-
gotOut, gotErr := CaptureStdoutAndStderr(t, func() {
33-
fmt.Fprintf(os.Stdout, wantOut)
34-
fmt.Fprintf(os.Stderr, wantErr)
35-
})
36-
if wantOut != gotOut {
37-
t.Errorf("want %q, got %q", wantOut, gotOut)
87+
tests := []struct {
88+
name string
89+
wantOut string
90+
wantErr string
91+
f func()
92+
}{
93+
{
94+
name: "both outputs",
95+
wantOut: "stdout\n",
96+
wantErr: "stderr\n",
97+
f: func() {
98+
fmt.Fprintln(os.Stdout, "stdout")
99+
fmt.Fprintln(os.Stderr, "stderr")
100+
},
101+
},
102+
{
103+
name: "only stdout",
104+
wantOut: "stdout\n",
105+
wantErr: "",
106+
f: func() {
107+
fmt.Fprintln(os.Stdout, "stdout")
108+
},
109+
},
110+
{
111+
name: "only stderr",
112+
wantOut: "",
113+
wantErr: "stderr\n",
114+
f: func() {
115+
fmt.Fprintln(os.Stderr, "stderr")
116+
},
117+
},
118+
{
119+
name: "large output",
120+
wantOut: strings.Repeat("a", 100000) + "\n",
121+
wantErr: strings.Repeat("b", 100000) + "\n",
122+
f: func() {
123+
fmt.Fprintln(os.Stdout, strings.Repeat("a", 100000))
124+
fmt.Fprintln(os.Stderr, strings.Repeat("b", 100000))
125+
},
126+
},
38127
}
39-
if wantErr != gotErr {
40-
t.Errorf("want %q, got %q", wantErr, gotErr)
128+
129+
for _, tt := range tests {
130+
t.Run(tt.name, func(t *testing.T) {
131+
gotOut, gotErr := CaptureStdoutAndStderr(t, tt.f)
132+
if gotOut != tt.wantOut {
133+
t.Errorf("CaptureStdoutAndStderr() stdout = %q, want %q", gotOut, tt.wantOut)
134+
}
135+
if gotErr != tt.wantErr {
136+
t.Errorf("CaptureStdoutAndStderr() stderr = %q, want %q", gotErr, tt.wantErr)
137+
}
138+
})
41139
}
42140
}

0 commit comments

Comments
 (0)