diff --git a/mssmt/tree_test.go b/mssmt/tree_test.go index c07a10bfc..82bacade8 100644 --- a/mssmt/tree_test.go +++ b/mssmt/tree_test.go @@ -39,25 +39,20 @@ func genTestStores(t *testing.T) map[string]makeTestTreeStoreFunc { constructors := make(map[string]makeTestTreeStoreFunc) for _, driver := range mssmt.RegisteredTreeStores() { - var makeFunc makeTestTreeStoreFunc - if driver.Name == "sqlite3" { - makeFunc = func() (mssmt.TreeStore, error) { - dbFileName := filepath.Join( - t.TempDir(), "tmp.db", - ) - - treeStore, err := driver.New(dbFileName, "test") - if err != nil { - return nil, fmt.Errorf("unable to "+ - "create new sqlite tree "+ - "store: %v", err) - } + constructors[driver.Name] = func() (mssmt.TreeStore, error) { + dbFileName := filepath.Join( + t.TempDir(), "tmp.db", + ) - return treeStore, nil + treeStore, err := driver.New(dbFileName, "test") + if err != nil { + return nil, fmt.Errorf("unable to "+ + "create new sqlite tree "+ + "store: %v", err) } - } - constructors[driver.Name] = makeFunc + return treeStore, nil + } } constructors["default"] = func() (mssmt.TreeStore, error) { diff --git a/tapdb/mssmt.go b/tapdb/mssmt.go index 8a839e4fb..bbddbb3e5 100644 --- a/tapdb/mssmt.go +++ b/tapdb/mssmt.go @@ -402,3 +402,38 @@ func (t *taprootAssetTreeStoreTx) UpdateRoot(rootNode *mssmt.BranchNode) error { Namespace: t.namespace, }) } + +func init() { + driver := mssmt.TreeStoreDriver{ + Name: activeTestDB, + New: func(args ...interface{}) (mssmt.TreeStore, error) { + dbPath, ok := args[0].(string) + if !ok { + return nil, fmt.Errorf("invalid db path: "+ + "want string, got %T", args[0]) + } + namespace, ok := args[1].(string) + if !ok { + return nil, fmt.Errorf("invalid db path: "+ + "want string, got %T", args[0]) + } + + sqlDB, err := NewDbHandleFromPath(dbPath) + if err != nil { + return nil, err + } + + txCreator := func(tx *sql.Tx) TreeStore { + return sqlDB.WithTx(tx) + } + + treeDB := NewTransactionExecutor(sqlDB, txCreator) + + return NewTaprootAssetTreeStore(treeDB, namespace), nil + }, + } + if err := mssmt.RegisterTreeStore(&driver); err != nil { + panic(fmt.Errorf("failed to register db=%v): %v", + activeTestDB, err)) + } +} diff --git a/tapdb/postgres.go b/tapdb/postgres.go index 8dc3c01a0..e13998b2a 100644 --- a/tapdb/postgres.go +++ b/tapdb/postgres.go @@ -173,6 +173,25 @@ func NewTestPostgresDB(t *testing.T) *PostgresStore { return store } +// NewPostgresDB is a helper function that creates a Postgres database for +// testing. +func NewPostgresDB() (*PostgresStore, error) { + var t testing.T + sqlFixture := NewTestPgFixture(&t, DefaultPostgresFixtureLifetime, true) + if t.Failed() { + return nil, fmt.Errorf("unable to make postgres db") + } + + store, err := NewPostgresStore(sqlFixture.GetConfig()) + if err != nil { + return nil, err + } + + // sqlFixture.TearDown(t) + + return store, nil +} + // NewTestPostgresDBWithVersion is a helper function that creates a Postgres // database for testing and migrates it to the given version. func NewTestPostgresDBWithVersion(t *testing.T, version uint) *PostgresStore { diff --git a/tapdb/sqlite.go b/tapdb/sqlite.go index aa3fbdb3e..54a841a3e 100644 --- a/tapdb/sqlite.go +++ b/tapdb/sqlite.go @@ -196,6 +196,20 @@ func NewTestSqliteDbHandleFromPath(t *testing.T, dbPath string) *SqliteStore { return sqlDB } +// NewSqliteDbHandleFromPath is a helper function that creates a SQLite +// database handle given a database file path. +func NewSqliteDbHandleFromPath(dbPath string) (*SqliteStore, error) { + sqlDB, err := NewSqliteStore(&SqliteConfig{ + DatabaseFileName: dbPath, + SkipMigrations: false, + }) + if err != nil { + return nil, err + } + + return sqlDB, nil +} + // NewTestSqliteDBWithVersion is a helper function that creates an SQLite // database for testing and migrates it to the given version. func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore { diff --git a/tapdb/sqlutils_test.go b/tapdb/sqlutils_test.go index a4b51fcb5..ffdaf71ec 100644 --- a/tapdb/sqlutils_test.go +++ b/tapdb/sqlutils_test.go @@ -280,13 +280,6 @@ func newDbHandleFromDb(db *BaseDB) *DbHandler { } } -// NewDbHandleFromPath creates a new database store handle given a database file -// path. -func NewDbHandleFromPath(t *testing.T, dbPath string) *DbHandler { - db := NewTestDbHandleFromPath(t, dbPath) - return newDbHandleFromDb(db.BaseDB) -} - // NewDbHandle creates a new database store handle. func NewDbHandle(t *testing.T) *DbHandler { // Create a new test database with the default database file path. diff --git a/tapdb/test_postgres.go b/tapdb/test_postgres.go index 4e284e10b..8e16cbeeb 100644 --- a/tapdb/test_postgres.go +++ b/tapdb/test_postgres.go @@ -6,17 +6,26 @@ import ( "testing" ) +const activeTestDB = "postgres" + // NewTestDB is a helper function that creates a Postgres database for testing. func NewTestDB(t *testing.T) *PostgresStore { return NewTestPostgresDB(t) } // NewTestDbHandleFromPath is a helper function that creates a new handle to an -// existing SQLite database for testing. +// existing Postgres database for testing. func NewTestDbHandleFromPath(t *testing.T, dbPath string) *PostgresStore { return NewTestPostgresDB(t) } +// NewDbHandleFromPath is a helper function that creates a new handle to an +// existing Postgres database for testing. This version returns an error if an +// an issue is hit during init. +func NewDbHandleFromPath(dbPath string) (*PostgresStore, error) { + return NewPostgresDB() +} + // NewTestDBWithVersion is a helper function that creates a Postgres database // for testing and migrates it to the given version. func NewTestDBWithVersion(t *testing.T, version uint) *PostgresStore { diff --git a/tapdb/test_sqlite.go b/tapdb/test_sqlite.go index 40e2b3673..931f9f75c 100644 --- a/tapdb/test_sqlite.go +++ b/tapdb/test_sqlite.go @@ -6,6 +6,8 @@ import ( "testing" ) +const activeTestDB = "sqlite3" + // NewTestDB is a helper function that creates an SQLite database for testing. func NewTestDB(t *testing.T) *SqliteStore { return NewTestSqliteDB(t) @@ -17,6 +19,13 @@ func NewTestDbHandleFromPath(t *testing.T, dbPath string) *SqliteStore { return NewTestSqliteDbHandleFromPath(t, dbPath) } +// NewDbHandleFromPath is a helper function that creates a new handle to an +// existing SQLite database for testing. This version returns an error if an +// issue is encountered during init. +func NewDbHandleFromPath(dbPath string) (*SqliteStore, error) { + return NewSqliteDbHandleFromPath(dbPath) +} + // NewTestDBWithVersion is a helper function that creates an SQLite database for // testing and migrates it to the given version. func NewTestDBWithVersion(t *testing.T, version uint) *SqliteStore {