From 7cec59afd8fc086328cc65f36a86aac93ab59d3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20H=C3=B6rl?= Date: Wed, 15 Jan 2020 15:17:11 +0000 Subject: [PATCH] Allow adding headers to generated fakes When the `-header` flag is specified and is a valid file path, the content of that file will be prepended to the generated fake. This is useful for e.g. specify a licence header. The header flag can be specified in the `go:generate` and the `counterfeiter:generate` line or in both, where the later takes precedence over the former. --- arguments/parser.go | 8 + arguments/parser_test.go | 12 + arguments/usage.go | 26 ++ benchmark_test.go | 59 +++-- fixtures/headers/default.header.go.txt | 3 + .../defaultheaderfakes/fake_header_default.go | 40 +++ .../fake_header_specific.go | 40 +++ fixtures/headers/defaultheader/header.go | 9 + fixtures/headers/nodefaultheader/headers.go | 9 + .../fake_header_default.go | 37 +++ .../fake_header_specific.go | 40 +++ fixtures/headers/specific.header.go.txt | 3 + generator/fake.go | 8 +- generator/file_reader.go | 95 +++++++ generator/file_reader_other_test.go | 12 + generator/file_reader_test.go | 236 ++++++++++++++++++ generator/file_reader_windows_test.go | 12 + generator/function_template.go | 2 +- generator/generator_internals_test.go | 8 +- generator/interface_template.go | 2 +- generator/package_template.go | 2 +- integration/roundtrip_test.go | 46 ++-- .../expected_fake_writecloser.header.txt | 186 ++++++++++++++ ...=> expected_fake_writecloser.noheader.txt} | 0 main.go | 37 ++- 25 files changed, 873 insertions(+), 59 deletions(-) create mode 100644 fixtures/headers/default.header.go.txt create mode 100644 fixtures/headers/defaultheader/defaultheaderfakes/fake_header_default.go create mode 100644 fixtures/headers/defaultheader/defaultheaderfakes/fake_header_specific.go create mode 100644 fixtures/headers/defaultheader/header.go create mode 100644 fixtures/headers/nodefaultheader/headers.go create mode 100644 fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_default.go create mode 100644 fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_specific.go create mode 100644 fixtures/headers/specific.header.go.txt create mode 100644 generator/file_reader.go create mode 100644 generator/file_reader_other_test.go create mode 100644 generator/file_reader_test.go create mode 100644 generator/file_reader_windows_test.go create mode 100644 integration/testdata/expected_fake_writecloser.header.txt rename integration/testdata/{expected_fake_writecloser.txt => expected_fake_writecloser.noheader.txt} (100%) diff --git a/arguments/parser.go b/arguments/parser.go index 76cb7d7..db0d60f 100644 --- a/arguments/parser.go +++ b/arguments/parser.go @@ -40,6 +40,11 @@ func New(args []string, workingDir string, evaler Evaler, stater Stater) (*Parse false, "Identify all //counterfeiter:generate directives in the current working directory and generate fakes for them", ) + headerFlag := fs.String( + "header", + "", + "A path to a file that should be used as a header for the generated fake", + ) helpFlag := fs.Bool( "help", false, @@ -62,6 +67,7 @@ func New(args []string, workingDir string, evaler Evaler, stater Stater) (*Parse PrintToStdOut: any(args, "-"), GenerateInterfaceAndShimFromPackageDirectory: packageMode, GenerateMode: *generateFlag, + HeaderFile: *headerFlag, } if *generateFlag { return result, nil @@ -193,6 +199,8 @@ type ParsedArguments struct { PrintToStdOut bool GenerateMode bool + + HeaderFile string } func fixupUnexportedNames(interfaceName string) string { diff --git a/arguments/parser_test.go b/arguments/parser_test.go index 4852a95..6420372 100644 --- a/arguments/parser_test.go +++ b/arguments/parser_test.go @@ -316,6 +316,18 @@ func testParsingArguments(t *testing.T, when spec.G, it spec.S) { Expect(parsedArgs.DestinationPackageName).To(Equal("fake_command_runnerfakes")) }) }) + + when("when '-header' is used", func() { + it.Before(func() { + args = []string{"counterfeiter", "-header", "some/header/file", "some.interface"} + justBefore() + }) + + it("sets the HeaderFile attriburte on the parsedArgs struct", func() { + Expect(parsedArgs.HeaderFile).To(Equal("some/header/file")) + Expect(err).NotTo(HaveOccurred()) + }) + }) } func fakeFileInfo(filename string, isDir bool) os.FileInfo { diff --git a/arguments/usage.go b/arguments/usage.go index e2926b4..69cc027 100644 --- a/arguments/usage.go +++ b/arguments/usage.go @@ -4,6 +4,7 @@ const usage = ` USAGE counterfeiter [-generate>] [-o ] [-p] [--fake-name ] + [-header ] [] [-] ARGUMENTS @@ -75,6 +76,31 @@ OPTIONS # now generate fake in ${PWD}/osshim/os_fake (fake_os.go) go generate osshim/... + -header + Path to the file which should be used as a header for all generated fakes. + By default, no special header is used. + This is useful to e.g. add a licence header to every fake. + + If the generate mode is used and both the "go:generate" and the + "counterfeiter:generate" specify a header file, the header file from the + "counterfeiter:generate" line takes precedence. + + example: + # having the following code in a package ... + //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -header ./generic.go.txt -generate + //counterfeiter:generate -header ./specific.go.txt . MyInterface + //counterfeiter:generate . MyOtherInterface + //counterfeiter:generate . MyThirdInterface + + # ... generating the fakes ... + go generate . + + # writes "FakeMyInterface" with ./specific.go.txt as a header + # writes "FakeMyOtherInterface" & "FakeMyThirdInterface" with ./generic.go.txt as a header + + + counterfeiter -header ./fixtures/boilerplate.go.txt ./mypackage MyInterface + --fake-name Name of the fake struct to generate. By default, 'Fake' will be prepended to the name of the original interface. (ignored in diff --git a/benchmark_test.go b/benchmark_test.go index 19aa792..c46aa6f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -10,7 +10,7 @@ import ( "github.com/maxbrunsfeld/counterfeiter/v6/generator" ) -func BenchmarkWithoutCache(b *testing.B) { +func BenchmarkDoGenerate(b *testing.B) { b.StopTimer() workingDir, err := filepath.Abs(filepath.Join(".", "fixtures")) if err != nil { @@ -29,35 +29,40 @@ func BenchmarkWithoutCache(b *testing.B) { PrintToStdOut: false, } - cache := &generator.FakeCache{} - b.StartTimer() - for i := 0; i < b.N; i++ { - doGenerate(workingDir, args, cache) + caches := map[string]struct { + cache generator.Cacher + headerReader generator.FileReader + }{ + "without caches": { + cache: &generator.FakeCache{}, + headerReader: &generator.SimpleFileReader{}, + }, + "with caches": { + cache: &generator.Cache{}, + headerReader: &generator.CachedFileReader{}, + }, } -} -func BenchmarkWithCache(b *testing.B) { - b.StopTimer() - workingDir, err := filepath.Abs(filepath.Join(".", "fixtures")) - if err != nil { - b.Fatal(err) - } - log.SetOutput(ioutil.Discard) - - args := &arguments.ParsedArguments{ - GenerateInterfaceAndShimFromPackageDirectory: false, - SourcePackageDir: workingDir, - PackagePath: workingDir, - OutputPath: filepath.Join(workingDir, "fixturesfakes", "fake_something.go"), - DestinationPackageName: "fixturesfakes", - InterfaceName: "Something", - FakeImplName: "FakeSomething", - PrintToStdOut: false, + headers := map[string]string{ + "without headerfile": "", + "with headerfile": "headers/default.header.go.txt", } - cache := &generator.Cache{} - b.StartTimer() - for i := 0; i < b.N; i++ { - doGenerate(workingDir, args, cache) + for name, caches := range caches { + caches := caches + b.Run(name, func(b *testing.B) { + for name, headerFile := range headers { + headerFile := headerFile + b.Run(name, func(b *testing.B) { + args.HeaderFile = headerFile + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := doGenerate(workingDir, args, caches.cache, caches.headerReader); err != nil { + b.Errorf("Expected doGenerate not to return an error, got %v", err) + } + } + }) // b.Run for headerFiles + } + }) // b.Run for caches } } diff --git a/fixtures/headers/default.header.go.txt b/fixtures/headers/default.header.go.txt new file mode 100644 index 0000000..6efbeb7 --- /dev/null +++ b/fixtures/headers/default.header.go.txt @@ -0,0 +1,3 @@ +// This is a default header for all the fakes in this package +// + diff --git a/fixtures/headers/defaultheader/defaultheaderfakes/fake_header_default.go b/fixtures/headers/defaultheader/defaultheaderfakes/fake_header_default.go new file mode 100644 index 0000000..c8f51d3 --- /dev/null +++ b/fixtures/headers/defaultheader/defaultheaderfakes/fake_header_default.go @@ -0,0 +1,40 @@ +// This is a default header for all the fakes in this package +// + +// Code generated by counterfeiter. DO NOT EDIT. +package defaultheaderfakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/headers/defaultheader" +) + +type FakeHeaderDefault struct { + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHeaderDefault) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHeaderDefault) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ defaultheader.HeaderDefault = new(FakeHeaderDefault) diff --git a/fixtures/headers/defaultheader/defaultheaderfakes/fake_header_specific.go b/fixtures/headers/defaultheader/defaultheaderfakes/fake_header_specific.go new file mode 100644 index 0000000..246c333 --- /dev/null +++ b/fixtures/headers/defaultheader/defaultheaderfakes/fake_header_specific.go @@ -0,0 +1,40 @@ +// This is a specific header for only some of the fakes in this package +// + +// Code generated by counterfeiter. DO NOT EDIT. +package defaultheaderfakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/headers/defaultheader" +) + +type FakeHeaderSpecific struct { + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHeaderSpecific) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHeaderSpecific) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ defaultheader.HeaderSpecific = new(FakeHeaderSpecific) diff --git a/fixtures/headers/defaultheader/header.go b/fixtures/headers/defaultheader/header.go new file mode 100644 index 0000000..57b20e6 --- /dev/null +++ b/fixtures/headers/defaultheader/header.go @@ -0,0 +1,9 @@ +package defaultheader // import "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/headers/defaultheader" + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -header ../default.header.go.txt -generate + +//counterfeiter:generate . HeaderDefault +type HeaderDefault interface{} + +//counterfeiter:generate -header ../specific.header.go.txt . HeaderSpecific +type HeaderSpecific interface{} diff --git a/fixtures/headers/nodefaultheader/headers.go b/fixtures/headers/nodefaultheader/headers.go new file mode 100644 index 0000000..4bee1ee --- /dev/null +++ b/fixtures/headers/nodefaultheader/headers.go @@ -0,0 +1,9 @@ +package nodefaultheader // import "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/headers/nodefaultheader" + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +//counterfeiter:generate . HeaderDefault +type HeaderDefault interface{} + +//counterfeiter:generate -header ../specific.header.go.txt . HeaderSpecific +type HeaderSpecific interface{} diff --git a/fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_default.go b/fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_default.go new file mode 100644 index 0000000..abf2d07 --- /dev/null +++ b/fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_default.go @@ -0,0 +1,37 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package nodefaultheaderfakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/headers/nodefaultheader" +) + +type FakeHeaderDefault struct { + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHeaderDefault) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHeaderDefault) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ nodefaultheader.HeaderDefault = new(FakeHeaderDefault) diff --git a/fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_specific.go b/fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_specific.go new file mode 100644 index 0000000..ddffdc0 --- /dev/null +++ b/fixtures/headers/nodefaultheader/nodefaultheaderfakes/fake_header_specific.go @@ -0,0 +1,40 @@ +// This is a specific header for only some of the fakes in this package +// + +// Code generated by counterfeiter. DO NOT EDIT. +package nodefaultheaderfakes + +import ( + "sync" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/headers/nodefaultheader" +) + +type FakeHeaderSpecific struct { + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHeaderSpecific) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHeaderSpecific) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ nodefaultheader.HeaderSpecific = new(FakeHeaderSpecific) diff --git a/fixtures/headers/specific.header.go.txt b/fixtures/headers/specific.header.go.txt new file mode 100644 index 0000000..58aca01 --- /dev/null +++ b/fixtures/headers/specific.header.go.txt @@ -0,0 +1,3 @@ +// This is a specific header for only some of the fakes in this package +// + diff --git a/generator/fake.go b/generator/fake.go index eab38e7..329043b 100644 --- a/generator/fake.go +++ b/generator/fake.go @@ -37,6 +37,7 @@ type Fake struct { Imports Imports Methods []Method Function Method + Header string } // Method is a method of the interface. @@ -46,9 +47,13 @@ type Method struct { Returns Returns } +type fileReader interface { + Get(path string) (string, error) +} + // NewFake returns a Fake that loads the package and finds the interface or the // function. -func NewFake(fakeMode FakeMode, targetName string, packagePath string, fakeName string, destinationPackage string, workingDir string, cache Cacher) (*Fake, error) { +func NewFake(fakeMode FakeMode, targetName string, packagePath string, fakeName string, destinationPackage string, headerContent string, workingDir string, cache Cacher) (*Fake, error) { f := &Fake{ TargetName: targetName, TargetPackage: packagePath, @@ -56,6 +61,7 @@ func NewFake(fakeMode FakeMode, targetName string, packagePath string, fakeName Mode: fakeMode, DestinationPackage: destinationPackage, Imports: newImports(), + Header: headerContent, } f.Imports.Add("sync", "sync") diff --git a/generator/file_reader.go b/generator/file_reader.go new file mode 100644 index 0000000..56e2471 --- /dev/null +++ b/generator/file_reader.go @@ -0,0 +1,95 @@ +package generator + +import ( + "io" + "io/ioutil" + "os" + "path/filepath" +) + +const ( + MaxHeaderBytes = 50 * 1024 // 50kb for the header should be enough? +) + +type FileReader interface { + Get(cwd, path string) (content string, err error) +} + +type Opener func(string) (io.ReadCloser, error) + +var ( + defaultOpen Opener = func(p string) (io.ReadCloser, error) { return os.Open(p) } +) + +func (open Opener) readString(path string) (string, error) { + if open == nil { + open = defaultOpen + } + + f, err := open(path) + if err != nil { + return "", err + } + defer f.Close() + + lr := io.LimitReader(f, MaxHeaderBytes) + + b, err := ioutil.ReadAll(lr) + if err != nil { + return "", err + } + + return string(b), nil +} + +type SimpleFileReader struct { + Open Opener +} + +var _ FileReader = &SimpleFileReader{} + +func (r *SimpleFileReader) Get(cwd, path string) (string, error) { + if path == "" { + return "", nil + } + + p := normalisePath(cwd, path) + return r.Open.readString(p) +} + +type CachedFileReader struct { + Open Opener + cache map[string]string +} + +var _ FileReader = &CachedFileReader{} + +func (r *CachedFileReader) Get(cwd, path string) (string, error) { + if path == "" { + return "", nil + } + + p := normalisePath(cwd, path) + + if s, ok := r.cache[p]; ok { + return s, nil + } + + s, err := r.Open.readString(p) + if err != nil { + return "", err + } + + if r.cache == nil { + r.cache = map[string]string{} + } + r.cache[p] = s + return s, nil +} + +func normalisePath(cwd, path string) string { + if !filepath.IsAbs(path) { + path = filepath.Join(cwd, path) + } + return filepath.Clean(path) +} diff --git a/generator/file_reader_other_test.go b/generator/file_reader_other_test.go new file mode 100644 index 0000000..7ebbfdd --- /dev/null +++ b/generator/file_reader_other_test.go @@ -0,0 +1,12 @@ +// +build !windows + +package generator_test + +const ( + relFile = "file.ext" + absFile = "/file.ext" + workingDir = "/some/dir" + + relFileUp = "../file.ext" + expectedFileUp = "/some/file.ext" +) diff --git a/generator/file_reader_test.go b/generator/file_reader_test.go new file mode 100644 index 0000000..cee8f83 --- /dev/null +++ b/generator/file_reader_test.go @@ -0,0 +1,236 @@ +package generator_test + +import ( + "fmt" + "io" + "io/ioutil" + "reflect" + "strings" + "testing" + + "github.com/maxbrunsfeld/counterfeiter/v6/generator" +) + +func TestFileReader(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + readerCreator func(generator.Opener) generator.FileReader + open generator.Opener + + workingDir string + path string + + expectedErrMsg string + expectedContent string + expectedCalls []string + }{ + // SimpleFileReader + "[simple] when the filepath is empty, it's a noop": { + readerCreator: simpleReaderCreator, + }, + "[simple] when open returns an error, the error bubbles up": { + readerCreator: simpleReaderCreator, + open: openReturningErr("some error"), + path: relFile, + expectedErrMsg: "some error", + expectedCalls: []string{relFile, relFile}, + }, + "[simple] when open returns a reader, the readers content is read": { + readerCreator: simpleReaderCreator, + open: openReturningReader("some content 0"), + path: relFile, + expectedContent: "some content 0", + expectedCalls: []string{relFile, relFile}, + }, + "[simple] when the working directory is set but the filepath is absolut, the absolute path is used": { + readerCreator: simpleReaderCreator, + open: openReturningReader("some content 1"), + workingDir: workingDir, + path: absFile, + expectedContent: "some content 1", + expectedCalls: []string{absFile, absFile}, + }, + "[simple] when the working directory and a relative filepath is set, the paths are combined & cleaned": { + readerCreator: simpleReaderCreator, + open: openReturningReader("some content 2"), + workingDir: workingDir, + path: relFileUp, + expectedContent: "some content 2", + expectedCalls: []string{expectedFileUp, expectedFileUp}, + }, + "[simple] when the file content is longer then allowed it is truncated": { + readerCreator: simpleReaderCreator, + open: openReturningReader(longerContent), + path: relFile, + expectedContent: maxContent, + expectedCalls: []string{relFile, relFile}, + }, + "[simple] when the reader's Read returns an error, the error bubbles up": { + readerCreator: simpleReaderCreator, + open: openReturningFailingReader("some read error"), + path: relFile, + expectedErrMsg: "some read error", + expectedCalls: []string{relFile, relFile}, + }, + + // CachedFileReader + "[cached] when the filepath is empty, it's a noop": { + readerCreator: cachedReaderCreator, + }, + "[cached] when open returns an error, the error bubbles up": { + readerCreator: cachedReaderCreator, + open: openReturningErr("some error"), + path: relFile, + expectedErrMsg: "some error", + expectedCalls: []string{relFile, relFile}, // because on error, we don't cache + }, + "[cached] when open returns a reader, the readers content is read": { + readerCreator: cachedReaderCreator, + open: openReturningReader("some content 3"), + path: relFile, + expectedContent: "some content 3", + expectedCalls: []string{relFile}, + }, + "[cached] when the working directory is set but the filepath is absolut, the absolute path is used": { + readerCreator: cachedReaderCreator, + open: openReturningReader("some content 4"), + workingDir: workingDir, + path: absFile, + expectedContent: "some content 4", + expectedCalls: []string{absFile}, + }, + "[cached] when the working directory and a relative filepath is set, the paths are combined & cleaned": { + readerCreator: cachedReaderCreator, + open: openReturningReader("some content 5"), + workingDir: workingDir, + path: relFileUp, + expectedContent: "some content 5", + expectedCalls: []string{expectedFileUp}, + }, + "[cached] when the file content is longer then allowed it is truncated": { + readerCreator: cachedReaderCreator, + open: openReturningReader(longerContent), + path: relFile, + expectedContent: maxContent, + expectedCalls: []string{relFile}, + }, + "[cached] when the reader's Read returns an error, the error bubbles up": { + readerCreator: cachedReaderCreator, + open: openReturningFailingReader("some read error"), + path: relFile, + expectedErrMsg: "some read error", + expectedCalls: []string{relFile, relFile}, // because on error, we don't cache + }, + } + + for name, tc := range tests { + tc := tc + + t.Run(name, func(t *testing.T) { + t.Parallel() + + spy := &openSpy{Func: tc.open} + reader := tc.readerCreator(spy.Open) + + readAndCheckContent := func() { + content, err := reader.Get(tc.workingDir, tc.path) + checkErr(t, err, tc.expectedErrMsg) + + if a, e := content, tc.expectedContent; e != a { + t.Errorf("Expected content to be '%s', got '%s'", e, a) + } + } + + // let's run the tests twice, to check on caching + readAndCheckContent() + readAndCheckContent() + + if a, e := spy.Calls, tc.expectedCalls; !reflect.DeepEqual(e, a) { + t.Errorf("Expected open call args to be %#v, got %#v", e, a) + } + }) + } +} + +func simpleReaderCreator(o generator.Opener) generator.FileReader { + return &generator.SimpleFileReader{Open: o} +} +func cachedReaderCreator(o generator.Opener) generator.FileReader { + return &generator.CachedFileReader{Open: o} +} + +func openReturningErr(err string) generator.Opener { + return func(_ string) (io.ReadCloser, error) { + return nil, fmt.Errorf(err) + } +} +func openReturningReader(content string) generator.Opener { + return func(_ string) (io.ReadCloser, error) { + return ioutil.NopCloser(strings.NewReader(content)), nil + } +} +func openReturningFailingReader(err string) generator.Opener { + return func(_ string) (io.ReadCloser, error) { + r := &erroringReader{ + reader: ioutil.NopCloser(strings.NewReader(longerContent)), + err: fmt.Errorf(err), + } + return r, nil + } +} + +var maxContent = strings.Repeat("a", generator.MaxHeaderBytes-1) + "|" +var longerContent = maxContent + "bbbb" + +func checkErr(t *testing.T, err error, msg string) { + t.Helper() + + if msg == "" { + if err != nil { + t.Errorf("Expected no error to occur, got %v", err) + } + return + } + + if err == nil { + t.Errorf("Expected error '%s', got no error", msg) + return + } + + if a, e := err.Error(), msg; a != e { + t.Errorf("Expected error '%s', got: '%s'", e, a) + } +} + +type openSpy struct { + Func generator.Opener + Calls []string +} + +func (o *openSpy) Open(p string) (io.ReadCloser, error) { + o.Calls = append(o.Calls, p) + return o.Func(p) +} + +var _ generator.Opener = (&openSpy{}).Open + +type erroringReader struct { + reader io.ReadCloser + callCount int + err error +} + +func (r *erroringReader) Read(p []byte) (int, error) { + r.callCount++ + if r.callCount >= 2 { + return 0, r.err + } + return r.reader.Read(p) +} + +func (r *erroringReader) Close() error { + return r.reader.Close() +} + +var _ io.ReadCloser = &erroringReader{} diff --git a/generator/file_reader_windows_test.go b/generator/file_reader_windows_test.go new file mode 100644 index 0000000..9fd821f --- /dev/null +++ b/generator/file_reader_windows_test.go @@ -0,0 +1,12 @@ +// +build windows + +package generator_test + +const ( + relFile = "file.ext" + absFile = "c:\\file.ext" + workingDir = "c:\\some\\dir" + + relFileUp = "..\\file.ext" + expectedFileUp = "c:\\some\\file.ext" +) diff --git a/generator/function_template.go b/generator/function_template.go index 533ef72..0043029 100644 --- a/generator/function_template.go +++ b/generator/function_template.go @@ -12,7 +12,7 @@ var functionFuncs = template.FuncMap{ "IsExported": isExported, } -const functionTemplate string = `// Code generated by counterfeiter. DO NOT EDIT. +const functionTemplate string = `{{.Header}}// Code generated by counterfeiter. DO NOT EDIT. package {{.DestinationPackage}} import ( diff --git a/generator/generator_internals_test.go b/generator/generator_internals_test.go index ace056c..c35c5c0 100644 --- a/generator/generator_internals_test.go +++ b/generator/generator_internals_test.go @@ -31,7 +31,7 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { when("the target is a nonexistent package", func() { it("errors", func() { c := &Cache{} - f, err = NewFake(InterfaceOrFunction, "NonExistent", "nonexistentpackage", "FakeNonExistent", "nonexistentpackagefakes", "", c) + f, err = NewFake(InterfaceOrFunction, "NonExistent", "nonexistentpackage", "FakeNonExistent", "nonexistentpackagefakes", "", "", c) Expect(err).To(HaveOccurred()) Expect(f).To(BeNil()) }) @@ -40,7 +40,7 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { when("the target is a package with a nonexistent interface", func() { it("errors", func() { c := &Cache{} - f, err = NewFake(InterfaceOrFunction, "NonExistent", "os", "FakeNonExistent", "osfakes", "", c) + f, err = NewFake(InterfaceOrFunction, "NonExistent", "os", "FakeNonExistent", "osfakes", "", "", c) Expect(err).To(HaveOccurred()) Expect(f).To(BeNil()) }) @@ -49,7 +49,7 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { when("the target is an interface that exists", func() { it("succeeds", func() { c := &Cache{} - f, err = NewFake(InterfaceOrFunction, "FileInfo", "os", "FakeFileInfo", "osfakes", "", c) + f, err = NewFake(InterfaceOrFunction, "FileInfo", "os", "FakeFileInfo", "osfakes", "", "", c) Expect(err).NotTo(HaveOccurred()) Expect(f).NotTo(BeNil()) Expect(f.TargetAlias).To(Equal("os")) @@ -80,7 +80,7 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { when("the target is a function that exists", func() { it("succeeds", func() { c := &Cache{} - f, err = NewFake(InterfaceOrFunction, "HandlerFunc", "net/http", "FakeHandlerFunc", "httpfakes", "", c) + f, err = NewFake(InterfaceOrFunction, "HandlerFunc", "net/http", "FakeHandlerFunc", "httpfakes", "", "", c) Expect(err).NotTo(HaveOccurred()) Expect(f).NotTo(BeNil()) diff --git a/generator/interface_template.go b/generator/interface_template.go index 75f424d..07f6154 100644 --- a/generator/interface_template.go +++ b/generator/interface_template.go @@ -13,7 +13,7 @@ var interfaceFuncs = template.FuncMap{ "Title": strings.Title, } -const interfaceTemplate string = `// Code generated by counterfeiter. DO NOT EDIT. +const interfaceTemplate string = `{{.Header}}// Code generated by counterfeiter. DO NOT EDIT. package {{.DestinationPackage}} import ( diff --git a/generator/package_template.go b/generator/package_template.go index 641e5bb..1aa68c3 100644 --- a/generator/package_template.go +++ b/generator/package_template.go @@ -12,7 +12,7 @@ var packageFuncs = template.FuncMap{ "Generate": func() string { return "go:generate" }, // yes, this seems insane but ensures that we can use `go generate ./...` from the main package } -const packageTemplate string = `// Code generated by counterfeiter. DO NOT EDIT. +const packageTemplate string = `{{.Header}}// Code generated by counterfeiter. DO NOT EDIT. package {{.DestinationPackage}} import ( diff --git a/integration/roundtrip_test.go b/integration/roundtrip_test.go index b18afdd..8ed118c 100644 --- a/integration/roundtrip_test.go +++ b/integration/roundtrip_test.go @@ -117,29 +117,37 @@ func runTests(useGopath bool, t *testing.T, when spec.G, it spec.S) { }) when("generating a fake for stdlib interfaces", func() { - it("succeeds", func() { - initModuleFunc() - cache := &generator.FakeCache{} - f, err := generator.NewFake(generator.InterfaceOrFunction, "WriteCloser", "io", "FakeWriteCloser", "custom", baseDir, cache) - Expect(err).NotTo(HaveOccurred()) - b, err := f.Generate(true) // Flip to false to see output if goimports fails - Expect(err).NotTo(HaveOccurred()) - if writeToTestData { - WriteOutput(b, filepath.Join("testdata", "output", "write_closer", "actual.go")) - } - WriteOutput(b, filepath.Join(baseDir, "fixturesfakes", "fake_write_closer.go")) - RunBuild(baseDir) - b2, err := ioutil.ReadFile(filepath.Join("testdata", "expected_fake_writecloser.txt")) - Expect(err).NotTo(HaveOccurred()) - Expect(string(b2)).To(Equal(string(b))) - }) + const ( + noHeader = "noheader" + withHeader = "header" + ) + t := func(header, variant string) { + it("succeeds", func() { + initModuleFunc() + cache := &generator.FakeCache{} + f, err := generator.NewFake(generator.InterfaceOrFunction, "WriteCloser", "io", "FakeWriteCloser", "custom", header, baseDir, cache) + Expect(err).NotTo(HaveOccurred()) + b, err := f.Generate(true) // Flip to false to see output if goimports fails + Expect(err).NotTo(HaveOccurred()) + if writeToTestData { + WriteOutput(b, filepath.Join("testdata", "output", "write_closer", "actual."+variant+".go")) + } + WriteOutput(b, filepath.Join(baseDir, "fixturesfakes", "fake_write_closer."+variant+".go")) + RunBuild(baseDir) + b2, err := ioutil.ReadFile(filepath.Join("testdata", "expected_fake_writecloser."+variant+".txt")) + Expect(err).NotTo(HaveOccurred()) + Expect(string(b2)).To(Equal(string(b))) + }) + } + t("", noHeader) + t("// some header\n//\n\n", withHeader) }) when("generating an interface for a package", func() { it("succeeds", func() { initModuleFunc() cache := &generator.FakeCache{} - f, err := generator.NewFake(generator.Package, "", "os", "Os", "custom", baseDir, cache) + f, err := generator.NewFake(generator.Package, "", "os", "Os", "custom", "", baseDir, cache) Expect(err).NotTo(HaveOccurred()) b, err := f.Generate(true) // Flip to false to see output if goimports fails Expect(err).NotTo(HaveOccurred()) @@ -175,7 +183,7 @@ func runTests(useGopath bool, t *testing.T, when spec.G, it spec.S) { WriteOutput([]byte(fmt.Sprintf("module github.com/maxbrunsfeld/counterfeiter/v6/fixtures%s\n", suffix)), filepath.Join(baseDir, "go.mod")) } cache := &generator.FakeCache{} - f, err := generator.NewFake(generator.InterfaceOrFunction, interfaceName, fmt.Sprintf("github.com/maxbrunsfeld/counterfeiter/v6/fixtures%s", suffix), "Fake"+interfaceName, "fixturesfakes", baseDir, cache) + f, err := generator.NewFake(generator.InterfaceOrFunction, interfaceName, fmt.Sprintf("github.com/maxbrunsfeld/counterfeiter/v6/fixtures%s", suffix), "Fake"+interfaceName, "fixturesfakes", "", baseDir, cache) Expect(err).NotTo(HaveOccurred()) b, err := f.Generate(true) // Flip to false to see output if goimports fails Expect(err).NotTo(HaveOccurred()) @@ -223,7 +231,7 @@ func runTests(useGopath bool, t *testing.T, when spec.G, it spec.S) { pkgPath = pkgPath + "/" + offset } cache := &generator.FakeCache{} - f, err := generator.NewFake(generator.InterfaceOrFunction, interfaceName, pkgPath, "Fake"+interfaceName, fakePackageName, baseDir, cache) + f, err := generator.NewFake(generator.InterfaceOrFunction, interfaceName, pkgPath, "Fake"+interfaceName, fakePackageName, "", baseDir, cache) Expect(err).NotTo(HaveOccurred()) b, err := f.Generate(false) // Flip to false to see output if goimports fails Expect(err).NotTo(HaveOccurred()) diff --git a/integration/testdata/expected_fake_writecloser.header.txt b/integration/testdata/expected_fake_writecloser.header.txt new file mode 100644 index 0000000..13d7f56 --- /dev/null +++ b/integration/testdata/expected_fake_writecloser.header.txt @@ -0,0 +1,186 @@ +// some header +// + +// Code generated by counterfeiter. DO NOT EDIT. +package custom + +import ( + "io" + "sync" +) + +type FakeWriteCloser struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + WriteStub func([]byte) (int, error) + writeMutex sync.RWMutex + writeArgsForCall []struct { + arg1 []byte + } + writeReturns struct { + result1 int + result2 error + } + writeReturnsOnCall map[int]struct { + result1 int + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeWriteCloser) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.closeReturns + return fakeReturns.result1 +} + +func (fake *FakeWriteCloser) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeWriteCloser) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeWriteCloser) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWriteCloser) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWriteCloser) Write(arg1 []byte) (int, error) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.writeMutex.Lock() + ret, specificReturn := fake.writeReturnsOnCall[len(fake.writeArgsForCall)] + fake.writeArgsForCall = append(fake.writeArgsForCall, struct { + arg1 []byte + }{arg1Copy}) + fake.recordInvocation("Write", []interface{}{arg1Copy}) + fake.writeMutex.Unlock() + if fake.WriteStub != nil { + return fake.WriteStub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + fakeReturns := fake.writeReturns + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeWriteCloser) WriteCallCount() int { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return len(fake.writeArgsForCall) +} + +func (fake *FakeWriteCloser) WriteCalls(stub func([]byte) (int, error)) { + fake.writeMutex.Lock() + defer fake.writeMutex.Unlock() + fake.WriteStub = stub +} + +func (fake *FakeWriteCloser) WriteArgsForCall(i int) []byte { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + argsForCall := fake.writeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeWriteCloser) WriteReturns(result1 int, result2 error) { + fake.writeMutex.Lock() + defer fake.writeMutex.Unlock() + fake.WriteStub = nil + fake.writeReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeWriteCloser) WriteReturnsOnCall(i int, result1 int, result2 error) { + fake.writeMutex.Lock() + defer fake.writeMutex.Unlock() + fake.WriteStub = nil + if fake.writeReturnsOnCall == nil { + fake.writeReturnsOnCall = make(map[int]struct { + result1 int + result2 error + }) + } + fake.writeReturnsOnCall[i] = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeWriteCloser) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeWriteCloser) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ io.WriteCloser = new(FakeWriteCloser) diff --git a/integration/testdata/expected_fake_writecloser.txt b/integration/testdata/expected_fake_writecloser.noheader.txt similarity index 100% rename from integration/testdata/expected_fake_writecloser.txt rename to integration/testdata/expected_fake_writecloser.noheader.txt diff --git a/main.go b/main.go index 6dcd92e..e76ccf4 100644 --- a/main.go +++ b/main.go @@ -53,10 +53,13 @@ func run() error { } var cache generator.Cacher + var headerReader generator.FileReader if disableCache() { cache = &generator.FakeCache{} + headerReader = &generator.SimpleFileReader{} } else { cache = &generator.Cache{} + headerReader = &generator.CachedFileReader{} } var invocations []command.Invocation var args *arguments.ParsedArguments @@ -75,7 +78,16 @@ func run() error { if err != nil { return err } - err = generate(cwd, a, cache) + + // If the '//counterfeiter:generate ...' line does not have a '-header' + // flag, we use the one from the "global" + // '//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate -header /some/header.txt' + // line (which defaults to none). By doing so, we can configure the header + // once per package, which is probably the most common case for adding + // licence headers (i.e. all the fakes will have the same licence headers). + a.HeaderFile = or(a.HeaderFile, args.HeaderFile) + + err = generate(cwd, a, cache, headerReader) if err != nil { return err } @@ -83,6 +95,15 @@ func run() error { return nil } +func or(opts ...string) string { + for _, s := range opts { + if s != "" { + return s + } + } + return "" +} + func isDebug() bool { return os.Getenv("COUNTERFEITER_DEBUG") != "" } @@ -91,12 +112,12 @@ func disableCache() bool { return os.Getenv("COUNTERFEITER_DISABLECACHE") != "" } -func generate(workingDir string, args *arguments.ParsedArguments, cache generator.Cacher) error { +func generate(workingDir string, args *arguments.ParsedArguments, cache generator.Cacher, headerReader generator.FileReader) error { if err := reportStarting(workingDir, args.OutputPath, args.FakeImplName); err != nil { return err } - b, err := doGenerate(workingDir, args, cache) + b, err := doGenerate(workingDir, args, cache, headerReader) if err != nil { return err } @@ -108,12 +129,18 @@ func generate(workingDir string, args *arguments.ParsedArguments, cache generato return nil } -func doGenerate(workingDir string, args *arguments.ParsedArguments, cache generator.Cacher) ([]byte, error) { +func doGenerate(workingDir string, args *arguments.ParsedArguments, cache generator.Cacher, headerReader generator.FileReader) ([]byte, error) { mode := generator.InterfaceOrFunction if args.GenerateInterfaceAndShimFromPackageDirectory { mode = generator.Package } - f, err := generator.NewFake(mode, args.InterfaceName, args.PackagePath, args.FakeImplName, args.DestinationPackageName, workingDir, cache) + + headerContent, err := headerReader.Get(workingDir, args.HeaderFile) + if err != nil { + return nil, err + } + + f, err := generator.NewFake(mode, args.InterfaceName, args.PackagePath, args.FakeImplName, args.DestinationPackageName, headerContent, workingDir, cache) if err != nil { return nil, err }