From 6f99a7c939492b8e1809b7e01191a068a07fc737 Mon Sep 17 00:00:00 2001
From: Spike Curtis <spike@coder.com>
Date: Wed, 6 Nov 2024 15:01:46 +0400
Subject: [PATCH] feat: add IgnoreErrorFn to slogtest options

Signed-off-by: Spike Curtis <spike@coder.com>
---
 sloggers/slogtest/t.go      | 35 ++++++++++++++++++++------
 sloggers/slogtest/t_test.go | 49 +++++++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 8 deletions(-)

diff --git a/sloggers/slogtest/t.go b/sloggers/slogtest/t.go
index df0225a..b1fbc86 100644
--- a/sloggers/slogtest/t.go
+++ b/sloggers/slogtest/t.go
@@ -46,6 +46,13 @@ type Options struct {
 	// as these are nearly always benign in testing. Override to []error{} (zero
 	// length error slice) to disable the whitelist entirely.
 	IgnoredErrorIs []error
+	// IgnoreErrorFn, if non-nil, defines a function that should return true if
+	// the given SinkEntry should not error the test on Error or Critical.  The
+	// result of this function is logically ORed with ignore directives defined
+	// by IgnoreErrors and IgnoredErrorIs. To depend exclusively on
+	// IgnoreErrorFn, set IgnoreErrors=false and IgnoredErrorIs=[]error{} (zero
+	// length error slice).
+	IgnoreErrorFn func(slog.SinkEntry) bool
 }
 
 var DefaultIgnoredErrorIs = []error{context.Canceled, context.DeadlineExceeded}
@@ -117,17 +124,16 @@ func (ts *testSink) shouldIgnoreError(ent slog.SinkEntry) bool {
 	if ts.opts.IgnoreErrors {
 		return true
 	}
-	for _, f := range ent.Fields {
-		if f.Name == "error" {
-			if err, ok := f.Value.(error); ok {
-				for _, ig := range ts.opts.IgnoredErrorIs {
-					if xerrors.Is(err, ig) {
-						return true
-					}
-				}
+	if err, ok := FindFirstError(ent); ok {
+		for _, ig := range ts.opts.IgnoredErrorIs {
+			if xerrors.Is(err, ig) {
+				return true
 			}
 		}
 	}
+	if ts.opts.IgnoreErrorFn != nil {
+		return ts.opts.IgnoreErrorFn(ent)
+	}
 	return false
 }
 
@@ -162,3 +168,16 @@ func Fatal(t testing.TB, msg string, fields ...any) {
 	slog.Helper()
 	l(t).Fatal(ctx, msg, fields...)
 }
+
+// FindFirstError finds the first slog.Field named "error" that contains an
+// error value.
+func FindFirstError(ent slog.SinkEntry) (err error, ok bool) {
+	for _, f := range ent.Fields {
+		if f.Name == "error" {
+			if err, ok = f.Value.(error); ok {
+				return err, true
+			}
+		}
+	}
+	return nil, false
+}
diff --git a/sloggers/slogtest/t_test.go b/sloggers/slogtest/t_test.go
index 0637bc5..2aa09d5 100644
--- a/sloggers/slogtest/t_test.go
+++ b/sloggers/slogtest/t_test.go
@@ -2,6 +2,7 @@ package slogtest_test
 
 import (
 	"context"
+	"fmt"
 	"testing"
 
 	"golang.org/x/xerrors"
@@ -108,6 +109,46 @@ func TestIgnoreErrorIs_Explicit(t *testing.T) {
 	l.Fatal(bg, "hello", slog.Error(xerrors.Errorf("test %w:", ignored)))
 }
 
+func TestIgnoreErrorFn(t *testing.T) {
+	t.Parallel()
+
+	tb := &fakeTB{}
+	ignored := testCodedError{code: 777}
+	notIgnored := testCodedError{code: 911}
+	l := slogtest.Make(tb, &slogtest.Options{IgnoreErrorFn: func(ent slog.SinkEntry) bool {
+		err, ok := slogtest.FindFirstError(ent)
+		if !ok {
+			t.Error("did not contain an error")
+			return false
+		}
+		ce := testCodedError{}
+		if !xerrors.As(err, &ce) {
+			return false
+		}
+		return ce.code != 911
+	}})
+
+	l.Error(bg, "ignored", slog.Error(xerrors.Errorf("test %w:", ignored)))
+	assert.Equal(t, "errors", 0, tb.errors)
+
+	l.Error(bg, "not ignored", slog.Error(xerrors.Errorf("test %w:", notIgnored)))
+	assert.Equal(t, "errors", 1, tb.errors)
+
+	// still ignored by default for IgnoredErrorIs
+	l.Error(bg, "canceled", slog.Error(xerrors.Errorf("test %w:", context.Canceled)))
+	assert.Equal(t, "errors", 1, tb.errors)
+
+	l.Error(bg, "new", slog.Error(xerrors.New("test")))
+	assert.Equal(t, "errors", 2, tb.errors)
+
+	defer func() {
+		recover()
+		assert.Equal(t, "fatals", 1, tb.fatals)
+	}()
+
+	l.Fatal(bg, "hello", slog.Error(xerrors.Errorf("test %w:", ignored)))
+}
+
 func TestCleanup(t *testing.T) {
 	t.Parallel()
 
@@ -163,3 +204,11 @@ func (tb *fakeTB) Fatal(v ...interface{}) {
 func (tb *fakeTB) Cleanup(fn func()) {
 	tb.cleanups = append(tb.cleanups, fn)
 }
+
+type testCodedError struct {
+	code int
+}
+
+func (e testCodedError) Error() string {
+	return fmt.Sprintf("code: %d", e.code)
+}