From 6f99a7c939492b8e1809b7e01191a068a07fc737 Mon Sep 17 00:00:00 2001 From: Spike Curtis <spike@coder.com> Date: Wed, 6 Nov 2024 15:01:46 +0400 Subject: [PATCH] feat: add IgnoreErrorFn to slogtest options Signed-off-by: Spike Curtis <spike@coder.com> --- sloggers/slogtest/t.go | 35 ++++++++++++++++++++------ sloggers/slogtest/t_test.go | 49 +++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 8 deletions(-) diff --git a/sloggers/slogtest/t.go b/sloggers/slogtest/t.go index df0225a..b1fbc86 100644 --- a/sloggers/slogtest/t.go +++ b/sloggers/slogtest/t.go @@ -46,6 +46,13 @@ type Options struct { // as these are nearly always benign in testing. Override to []error{} (zero // length error slice) to disable the whitelist entirely. IgnoredErrorIs []error + // IgnoreErrorFn, if non-nil, defines a function that should return true if + // the given SinkEntry should not error the test on Error or Critical. The + // result of this function is logically ORed with ignore directives defined + // by IgnoreErrors and IgnoredErrorIs. To depend exclusively on + // IgnoreErrorFn, set IgnoreErrors=false and IgnoredErrorIs=[]error{} (zero + // length error slice). + IgnoreErrorFn func(slog.SinkEntry) bool } var DefaultIgnoredErrorIs = []error{context.Canceled, context.DeadlineExceeded} @@ -117,17 +124,16 @@ func (ts *testSink) shouldIgnoreError(ent slog.SinkEntry) bool { if ts.opts.IgnoreErrors { return true } - for _, f := range ent.Fields { - if f.Name == "error" { - if err, ok := f.Value.(error); ok { - for _, ig := range ts.opts.IgnoredErrorIs { - if xerrors.Is(err, ig) { - return true - } - } + if err, ok := FindFirstError(ent); ok { + for _, ig := range ts.opts.IgnoredErrorIs { + if xerrors.Is(err, ig) { + return true } } } + if ts.opts.IgnoreErrorFn != nil { + return ts.opts.IgnoreErrorFn(ent) + } return false } @@ -162,3 +168,16 @@ func Fatal(t testing.TB, msg string, fields ...any) { slog.Helper() l(t).Fatal(ctx, msg, fields...) } + +// FindFirstError finds the first slog.Field named "error" that contains an +// error value. +func FindFirstError(ent slog.SinkEntry) (err error, ok bool) { + for _, f := range ent.Fields { + if f.Name == "error" { + if err, ok = f.Value.(error); ok { + return err, true + } + } + } + return nil, false +} diff --git a/sloggers/slogtest/t_test.go b/sloggers/slogtest/t_test.go index 0637bc5..2aa09d5 100644 --- a/sloggers/slogtest/t_test.go +++ b/sloggers/slogtest/t_test.go @@ -2,6 +2,7 @@ package slogtest_test import ( "context" + "fmt" "testing" "golang.org/x/xerrors" @@ -108,6 +109,46 @@ func TestIgnoreErrorIs_Explicit(t *testing.T) { l.Fatal(bg, "hello", slog.Error(xerrors.Errorf("test %w:", ignored))) } +func TestIgnoreErrorFn(t *testing.T) { + t.Parallel() + + tb := &fakeTB{} + ignored := testCodedError{code: 777} + notIgnored := testCodedError{code: 911} + l := slogtest.Make(tb, &slogtest.Options{IgnoreErrorFn: func(ent slog.SinkEntry) bool { + err, ok := slogtest.FindFirstError(ent) + if !ok { + t.Error("did not contain an error") + return false + } + ce := testCodedError{} + if !xerrors.As(err, &ce) { + return false + } + return ce.code != 911 + }}) + + l.Error(bg, "ignored", slog.Error(xerrors.Errorf("test %w:", ignored))) + assert.Equal(t, "errors", 0, tb.errors) + + l.Error(bg, "not ignored", slog.Error(xerrors.Errorf("test %w:", notIgnored))) + assert.Equal(t, "errors", 1, tb.errors) + + // still ignored by default for IgnoredErrorIs + l.Error(bg, "canceled", slog.Error(xerrors.Errorf("test %w:", context.Canceled))) + assert.Equal(t, "errors", 1, tb.errors) + + l.Error(bg, "new", slog.Error(xerrors.New("test"))) + assert.Equal(t, "errors", 2, tb.errors) + + defer func() { + recover() + assert.Equal(t, "fatals", 1, tb.fatals) + }() + + l.Fatal(bg, "hello", slog.Error(xerrors.Errorf("test %w:", ignored))) +} + func TestCleanup(t *testing.T) { t.Parallel() @@ -163,3 +204,11 @@ func (tb *fakeTB) Fatal(v ...interface{}) { func (tb *fakeTB) Cleanup(fn func()) { tb.cleanups = append(tb.cleanups, fn) } + +type testCodedError struct { + code int +} + +func (e testCodedError) Error() string { + return fmt.Sprintf("code: %d", e.code) +}