diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c2f4459..559fd57 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,7 +20,7 @@ jobs: - name: Set up Golang uses: actions/setup-go@v4 with: - go-version: '1.20' + go-version: '1.21' - name: Get dependencies run: | diff --git a/README.md b/README.md index 633cdb5..69144ad 100644 --- a/README.md +++ b/README.md @@ -10,18 +10,18 @@ It's based on the well-known [sqlmock](https://github.com/DATA-DOG/go-sqlmock) l **pgxmock** has one and only purpose - to simulate **pgx** behavior in tests, without needing a real database connection. It helps to maintain correct **TDD** workflow. -- written based on **go1.15** version, however, should be compatible with **go1.11** and above; +- written based on **go1.21** version; - does not require any modifications to your source code; - has strict by default expectation order matching; - has no third party dependencies except **pgx** packages. ## Install - go get github.com/pashagolub/pgxmock/v2 + go get github.com/pashagolub/pgxmock/v3 ## Documentation and Examples -Visit [godoc](http://pkg.go.dev/github.com/pashagolub/pgxmock/v2) for general examples and public api reference. +Visit [godoc](http://pkg.go.dev/github.com/pashagolub/pgxmock/v3) for general examples and public api reference. See implementation examples: @@ -92,7 +92,7 @@ import ( "fmt" "testing" - "github.com/pashagolub/pgxmock/v2" + "github.com/pashagolub/pgxmock/v3" ) // a successful case @@ -175,7 +175,7 @@ provide a standard sql parsing matchers. ## Matching arguments like time.Time There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case -**pgxmock** provides an [Argument](https://pkg.go.dev/github.com/pashagolub/pgxmock/v2#Argument) interface which +**pgxmock** provides an [Argument](https://pkg.go.dev/github.com/pashagolub/pgxmock/v3#Argument) interface which can be used in more sophisticated matching. Here is a simple example of time argument matching: ``` go diff --git a/argument.go b/argument.go index 6d84670..37a84a3 100644 --- a/argument.go +++ b/argument.go @@ -20,3 +20,4 @@ type anyArgument struct{} func (a anyArgument) Match(_ interface{}) bool { return true } + diff --git a/driver.go b/driver.go index 7ce1b71..d215f84 100644 --- a/driver.go +++ b/driver.go @@ -12,29 +12,19 @@ type pgxmockConn struct { } // NewConn creates PgxConnIface database connection and a mock to manage expectations. -// Accepts options, like ValueConverterOption, to use a ValueConverter from -// a specific driver. -// Pings db so that all expectations could be -// asserted. +// Accepts options, like QueryMatcherOption, to match SQL query strings in more sophisticated ways. func NewConn(options ...func(*pgxmock) error) (PgxConnIface, error) { smock := &pgxmockConn{} smock.ordered = true return smock, smock.open(options) } -func (c *pgxmockConn) Close(ctx context.Context) error { - return c.close(ctx) -} - type pgxmockPool struct { pgxmock } // NewPool creates PgxPoolIface pool of database connections and a mock to manage expectations. -// Accepts options, like ValueConverterOption, to use a ValueConverter from -// a specific driver. -// Pings db so that all expectations could be -// asserted. +// Accepts options, like QueryMatcherOption, to match SQL query strings in more sophisticated ways. func NewPool(options ...func(*pgxmock) error) (PgxPoolIface, error) { smock := &pgxmockPool{} smock.ordered = true @@ -42,7 +32,7 @@ func NewPool(options ...func(*pgxmock) error) (PgxPoolIface, error) { } func (p *pgxmockPool) Close() { - _ = p.close(context.Background()) + p.pgxmock.Close(context.Background()) } func (p *pgxmockPool) Acquire(context.Context) (*pgxpool.Conn, error) { diff --git a/driver_test.go b/driver_test.go index 419c9df..2b4206b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -33,6 +33,10 @@ func TestPools(t *testing.T) { if mock == mock2 { t.Errorf("expected not the same mock instance, but it is the same") } + conn := mock.AsConn() + if conn == nil { + t.Error("expected connection strruct, but got nil") + } mock.Close() mock2.Close() } diff --git a/examples/basic/basic_test.go b/examples/basic/basic_test.go index 1cd6f95..728723c 100644 --- a/examples/basic/basic_test.go +++ b/examples/basic/basic_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/pashagolub/pgxmock/v2" + "github.com/pashagolub/pgxmock/v3" ) // a successful case diff --git a/examples/blog/blog_test.go b/examples/blog/blog_test.go index 7d2a268..e2c7de9 100644 --- a/examples/blog/blog_test.go +++ b/examples/blog/blog_test.go @@ -9,7 +9,7 @@ import ( "net/http/httptest" "testing" - "github.com/pashagolub/pgxmock/v2" + "github.com/pashagolub/pgxmock/v3" ) func (a *api) assertJSON(actual []byte, data interface{}, t *testing.T) { diff --git a/expectations.go b/expectations.go index fe8eca8..884e3fc 100644 --- a/expectations.go +++ b/expectations.go @@ -1,6 +1,8 @@ package pgxmock import ( + "context" + "errors" "fmt" "reflect" "strings" @@ -11,189 +13,191 @@ import ( pgconn "github.com/jackc/pgx/v5/pgconn" ) -// an expectation interface -type expectation interface { +// an Expectation interface +type Expectation interface { + error() error + required() bool fulfilled() bool - Lock() - Unlock() - String() string + fulfill() + sync.Locker + fmt.Stringer +} + +type CallModifyer interface { + Maybe() CallModifyer + Times(n uint) CallModifyer + WillDelayFor(duration time.Duration) CallModifyer + WillReturnError(err error) + WillPanic(v any) } // common expectation struct // satisfies the expectation interface type commonExpectation struct { sync.Mutex - triggered bool - err error + triggered uint // how many times method was called + err error // should method return error + optional bool // can method be skipped + panicArgument any // panic value to return for recovery + plannedDelay time.Duration // should method delay before return + plannedCalls uint // how many sequentional calls should be made } -func (e *commonExpectation) fulfilled() bool { - return e.triggered +func (e *commonExpectation) error() error { + return e.err } -// ExpectedClose is used to manage pgx.Close expectation -// returned by pgxmock.ExpectClose. -type ExpectedClose struct { - commonExpectation +func (e *commonExpectation) fulfill() { + e.triggered++ } -// WillReturnError allows to set an error for pgx.Close action -func (e *ExpectedClose) WillReturnError(err error) *ExpectedClose { - e.err = err - return e +func (e *commonExpectation) fulfilled() bool { + return e.triggered >= max(e.plannedCalls, 1) } -// String returns string representation -func (e *ExpectedClose) String() string { - msg := "ExpectedClose => expecting database Close" - if e.err != nil { - msg += fmt.Sprintf(", which should return error: %s", e.err) - } - return msg +func (e *commonExpectation) required() bool { + return !e.optional } -// ExpectedBegin is used to manage *pgx.Begin expectation -// returned by pgxmock.ExpectBegin. -type ExpectedBegin struct { - commonExpectation - delay time.Duration - opts pgx.TxOptions +func (e *commonExpectation) waitForDelay(ctx context.Context) (err error) { + select { + case <-time.After(e.plannedDelay): + err = e.error() + case <-ctx.Done(): + err = ctx.Err() + } + if e.panicArgument != nil { + panic(e.panicArgument) + } + return err } -// WillReturnError allows to set an error for pgx.Begin action -func (e *ExpectedBegin) WillReturnError(err error) *ExpectedBegin { - e.err = err +// Maybe allows the expected method call to be optional. +// Not calling an optional method will not cause an error while asserting expectations +func (e *commonExpectation) Maybe() CallModifyer { + e.optional = true return e } -// String returns string representation -func (e *ExpectedBegin) String() string { - msg := "ExpectedBegin => expecting database transaction Begin" - if e.err != nil { - msg += fmt.Sprintf(", which should return error: %s", e.err) - } - return msg +// Times indicates that that the expected method should only fire the indicated number of times. +// Zero value is ignored and means the same as one. +func (e *commonExpectation) Times(n uint) CallModifyer { + e.plannedCalls = n + return e } // WillDelayFor allows to specify duration for which it will delay // result. May be used together with Context -func (e *ExpectedBegin) WillDelayFor(duration time.Duration) *ExpectedBegin { - e.delay = duration +func (e *commonExpectation) WillDelayFor(duration time.Duration) CallModifyer { + e.plannedDelay = duration return e } -// ExpectedCommit is used to manage pgx.Tx.Commit expectation -// returned by pgxmock.ExpectCommit. -type ExpectedCommit struct { - commonExpectation +// WillReturnError allows to set an error for the expected method +func (e *commonExpectation) WillReturnError(err error) { + e.err = err } -// WillReturnError allows to set an error for pgx.Tx.Close action -func (e *ExpectedCommit) WillReturnError(err error) *ExpectedCommit { - e.err = err - return e +var errPanic = errors.New("pgxmock panic") + +// WillPanic allows to force the expected method to panic +func (e *commonExpectation) WillPanic(v any) { + e.err = errPanic + e.panicArgument = v } // String returns string representation -func (e *ExpectedCommit) String() string { - msg := "ExpectedCommit => expecting transaction Commit" +func (e *commonExpectation) String() string { + w := new(strings.Builder) if e.err != nil { - msg += fmt.Sprintf(", which should return error: %s", e.err) + if e.err != errPanic { + fmt.Fprintf(w, "\t- returns error: %v\n", e.err) + } else { + fmt.Fprintf(w, "\t- panics with: %v\n", e.panicArgument) + } } - return msg + if e.plannedDelay > 0 { + fmt.Fprintf(w, "\t- delayed execution for: %v\n", e.plannedDelay) + } + if e.optional { + fmt.Fprint(w, "\t- execution is optional\n") + } + if e.plannedCalls > 0 { + fmt.Fprintf(w, "\t- execution calls awaited: %d\n", e.plannedCalls) + } + return w.String() } -// ExpectedRollback is used to manage pgx.Tx.Rollback expectation -// returned by pgxmock.ExpectRollback. -type ExpectedRollback struct { - commonExpectation +// queryBasedExpectation is a base class that adds a query matching logic +type queryBasedExpectation struct { + expectSQL string + args []interface{} } -// WillReturnError allows to set an error for pgx.Tx.Rollback action -func (e *ExpectedRollback) WillReturnError(err error) *ExpectedRollback { - e.err = err - return e -} +func (e *queryBasedExpectation) argsMatches(args []interface{}) error { + if len(args) != len(e.args) { + return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) + } + for k, v := range args { + // custom argument matcher + if matcher, ok := e.args[k].(Argument); ok { + if !matcher.Match(v) { + return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) + } + continue + } -// String returns string representation -func (e *ExpectedRollback) String() string { - msg := "ExpectedRollback => expecting transaction Rollback" - if e.err != nil { - msg += fmt.Sprintf(", which should return error: %s", e.err) + if darg := e.args[k]; !reflect.DeepEqual(darg, v) { + return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v) + } } - return msg + return nil } -// ExpectedQuery is used to manage *pgx.Conn.Query, *pgx.Conn.QueryRow, *pgx.Tx.Query, -// *pgx.Tx.QueryRow, *pgx.Stmt.Query or *pgx.Stmt.QueryRow expectations. -// Returned by pgxmock.ExpectQuery. -type ExpectedQuery struct { - queryBasedExpectation - rows pgx.Rows - delay time.Duration - rowsMustBeClosed bool - rowsWereClosed bool +// ExpectedClose is used to manage pgx.Close expectation +// returned by pgxmock.ExpectClose +type ExpectedClose struct { + commonExpectation } -// WithArgs will match given expected args to actual database query arguments. -// if at least one argument does not match, it will return an error. For specific -// arguments an pgxmock.Argument interface can be used to match an argument. -func (e *ExpectedQuery) WithArgs(args ...interface{}) *ExpectedQuery { - e.args = args - return e +// String returns string representation +func (e *ExpectedClose) String() string { + return "ExpectedClose => expecting call to Close()\n" + e.commonExpectation.String() } -// RowsWillBeClosed expects this query rows to be closed. -func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery { - e.rowsMustBeClosed = true - return e +// ExpectedBegin is used to manage *pgx.Begin expectation +// returned by pgxmock.ExpectBegin. +type ExpectedBegin struct { + commonExpectation + opts pgx.TxOptions } -// WillReturnError allows to set an error for expected database query -func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery { - e.err = err - return e +// String returns string representation +func (e *ExpectedBegin) String() string { + msg := "ExpectedBegin => expecting call to Begin() or to BeginTx()\n" + if e.opts != (pgx.TxOptions{}) { + msg += fmt.Sprintf("\t- transaction options awaited: %+v\n", e.opts) + } + return msg + e.commonExpectation.String() } -// WillDelayFor allows to specify duration for which it will delay -// result. May be used together with Context -func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery { - e.delay = duration - return e +// ExpectedCommit is used to manage pgx.Tx.Commit expectation +// returned by pgxmock.ExpectCommit. +type ExpectedCommit struct { + commonExpectation } // String returns string representation -func (e *ExpectedQuery) String() string { - msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:" - msg += "\n - matches sql: '" + e.expectSQL + "'" - - if len(e.args) == 0 { - msg += "\n - is without arguments" - } else { - msg += "\n - is with arguments:\n" - for i, arg := range e.args { - msg += fmt.Sprintf(" %d - %+v\n", i, arg) - } - msg = strings.TrimSpace(msg) - } - - if e.rows != nil { - msg += fmt.Sprintf("\n - %s", e.rows) - } - - if e.err != nil { - msg += fmt.Sprintf("\n - should return error: %s", e.err) - } - - return msg +func (e *ExpectedCommit) String() string { + return "ExpectedCommit => expecting call to Tx.Commit()\n" + e.commonExpectation.String() } // ExpectedExec is used to manage pgx.Exec, pgx.Tx.Exec or pgx.Stmt.Exec expectations. // Returned by pgxmock.ExpectExec. type ExpectedExec struct { + commonExpectation queryBasedExpectation result pgconn.CommandTag - delay time.Duration } // WithArgs will match given expected args to actual database exec operation arguments. @@ -204,51 +208,29 @@ func (e *ExpectedExec) WithArgs(args ...interface{}) *ExpectedExec { return e } -// WillReturnError allows to set an error for expected database exec action -func (e *ExpectedExec) WillReturnError(err error) *ExpectedExec { - e.err = err - return e -} - -// WillDelayFor allows to specify duration for which it will delay -// result. May be used together with Context -func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec { - e.delay = duration - return e -} - // String returns string representation func (e *ExpectedExec) String() string { - msg := "ExpectedExec => expecting Exec or ExecContext which:" - msg += "\n - matches sql: '" + e.expectSQL + "'" + msg := "ExpectedExec => expecting call to Exec():\n" + msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) if len(e.args) == 0 { - msg += "\n - is without arguments" + msg += "\t- is without arguments\n" } else { - msg += "\n - is with arguments:\n" - var margs []string + msg += "\t- is with arguments:\n" for i, arg := range e.args { - margs = append(margs, fmt.Sprintf(" %d - %+v", i, arg)) + msg += fmt.Sprintf("\t\t%d - %+v\n", i, arg) } - msg += strings.Join(margs, "\n") } - if e.result.String() > "" { - msg += "\n - should return Result having:" - msg += fmt.Sprintf("\n RowsAffected: %d", e.result.RowsAffected()) - } - - if e.err != nil { - msg += fmt.Sprintf("\n - should return error: %s", e.err) + msg += fmt.Sprintf("\t- returns result: %s\n", e.result) } - return msg + return msg + e.commonExpectation.String() } // WillReturnResult arranges for an expected Exec() to return a particular -// result, there is pgxmock.NewResult(lastInsertID int64, affectedRows int64) method -// to build a corresponding result. Or if actions needs to be tested against errors -// pgxmock.NewErrorResult(err error) to return a given error. +// result, there is pgxmock.NewResult(op string, rowsAffected int64) method +// to build a corresponding result. func (e *ExpectedExec) WillReturnResult(result pgconn.CommandTag) *ExpectedExec { e.result = result return e @@ -261,34 +243,25 @@ type ExpectedPrepare struct { mock *pgxmock expectStmtName string expectSQL string - closeErr error + deallocateErr error mustBeClosed bool - wasClosed bool - delay time.Duration -} - -// WillReturnError allows to set an error for the expected pgx.Prepare or pgx.Tx.Prepare action. -func (e *ExpectedPrepare) WillReturnError(err error) *ExpectedPrepare { - e.err = err - return e + deallocated bool } // WillReturnCloseError allows to set an error for this prepared statement Close action func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare { - e.closeErr = err + e.deallocateErr = err return e } -// WillDelayFor allows to specify duration for which it will delay -// result. May be used together with Context -func (e *ExpectedPrepare) WillDelayFor(duration time.Duration) *ExpectedPrepare { - e.delay = duration - return e +// WillBeClosed is for backward compatibility only and will be removed soon. +// One should use WillBeDeallocated() instead +func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { + return e.WillBeDeallocated() } -// WillBeClosed expects this prepared statement to -// be closed. -func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { +// WillBeDeallocated expects this prepared statement to be deallocated +func (e *ExpectedPrepare) WillBeDeallocated() *ExpectedPrepare { e.mustBeClosed = true return e } @@ -298,8 +271,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare { func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { eq := &ExpectedQuery{} eq.expectSQL = e.expectStmtName - // eq.converter = e.mock.converter - e.mock.expected = append(e.mock.expected, eq) + e.mock.expectations = append(e.mock.expectations, eq) return eq } @@ -308,64 +280,73 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery { func (e *ExpectedPrepare) ExpectExec() *ExpectedExec { eq := &ExpectedExec{} eq.expectSQL = e.expectStmtName - // eq.converter = e.mock.converter - e.mock.expected = append(e.mock.expected, eq) + e.mock.expectations = append(e.mock.expectations, eq) return eq } // String returns string representation func (e *ExpectedPrepare) String() string { - msg := "ExpectedPrepare => expecting Prepare statement which:" - msg += "\n - matches statement name: '" + e.expectStmtName + "'" - msg += "\n - matches sql: '" + e.expectSQL + "'" - - if e.err != nil { - msg += fmt.Sprintf("\n - should return error: %s", e.err) - } - - if e.closeErr != nil { - msg += fmt.Sprintf("\n - should return error on Close: %s", e.closeErr) + msg := "ExpectedPrepare => expecting call to Prepare():" + msg += fmt.Sprintf("\t- matches statement name: '%s'", e.expectStmtName) + msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) + if e.deallocateErr != nil { + msg += fmt.Sprintf("\t- returns error on Close: %s", e.deallocateErr) } - - return msg + return msg + e.commonExpectation.String() } -// query based expectation -// adds a query matching logic -type queryBasedExpectation struct { +// ExpectedPing is used to manage Ping() expectations +type ExpectedPing struct { commonExpectation - expectSQL string - // converter driver.ValueConverter - args []interface{} } -// ExpectedPing is used to manage pgx.Ping expectations. -// Returned by pgxmock.ExpectPing. -type ExpectedPing struct { +// String returns string representation +func (e *ExpectedPing) String() string { + msg := "ExpectedPing => expecting call to Ping()\n" + return msg + e.commonExpectation.String() +} + +// ExpectedQuery is used to manage *pgx.Conn.Query, *pgx.Conn.QueryRow, *pgx.Tx.Query, +// *pgx.Tx.QueryRow, *pgx.Stmt.Query or *pgx.Stmt.QueryRow expectations +type ExpectedQuery struct { commonExpectation - delay time.Duration + queryBasedExpectation + rows pgx.Rows + rowsMustBeClosed bool + rowsWereClosed bool } -// WillDelayFor allows to specify duration for which it will delay result. May -// be used together with Context. -func (e *ExpectedPing) WillDelayFor(duration time.Duration) *ExpectedPing { - e.delay = duration +// WithArgs will match given expected args to actual database query arguments. +// if at least one argument does not match, it will return an error. For specific +// arguments an pgxmock.Argument interface can be used to match an argument. +func (e *ExpectedQuery) WithArgs(args ...interface{}) *ExpectedQuery { + e.args = args return e } -// WillReturnError allows to set an error for expected database ping -func (e *ExpectedPing) WillReturnError(err error) *ExpectedPing { - e.err = err +// RowsWillBeClosed expects this query rows to be closed. +func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery { + e.rowsMustBeClosed = true return e } // String returns string representation -func (e *ExpectedPing) String() string { - msg := "ExpectedPing => expecting database Ping" - if e.err != nil { - msg += fmt.Sprintf(", which should return error: %s", e.err) +func (e *ExpectedQuery) String() string { + msg := "ExpectedQuery => expecting call to Query() or to QueryRow():\n" + msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) + + if len(e.args) == 0 { + msg += "\t- is without arguments\n" + } else { + msg += "\t- is with arguments:\n" + for i, arg := range e.args { + msg += fmt.Sprintf("\t\t%d - %+v\n", i, arg) + } } - return msg + if e.rows != nil { + msg += fmt.Sprintf("%s\n", e.rows) + } + return msg + e.commonExpectation.String() } // WillReturnRows specifies the set of resulting rows that will be returned @@ -375,42 +356,6 @@ func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { return e } -func (e *queryBasedExpectation) argsMatches(args []interface{}) error { - if len(args) != len(e.args) { - return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) - } - for k, v := range args { - // custom argument matcher - matcher, ok := e.args[k].(Argument) - if ok { - if !matcher.Match(v) { - return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) - } - continue - } - darg := e.args[k] - if !reflect.DeepEqual(darg, v) { - return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v) - } - } - return nil -} - -func (e *queryBasedExpectation) attemptArgMatch(args []interface{}) (err error) { - // catch panic - defer func() { - if e := recover(); e != nil { - _, ok := e.(error) - if !ok { - err = fmt.Errorf(e.(string)) - } - } - }() - - err = e.argsMatches(args) - return -} - // ExpectedCopyFrom is used to manage *pgx.Conn.CopyFrom expectations. // Returned by *Pgxmock.ExpectCopyFrom. type ExpectedCopyFrom struct { @@ -418,20 +363,6 @@ type ExpectedCopyFrom struct { expectedTableName pgx.Identifier expectedColumns []string rowsAffected int64 - delay time.Duration -} - -// WillReturnError allows to set an error for expected database exec action -func (e *ExpectedCopyFrom) WillReturnError(err error) *ExpectedCopyFrom { - e.err = err - return e -} - -// WillDelayFor allows to specify duration for which it will delay -// result. May be used together with Context -func (e *ExpectedCopyFrom) WillDelayFor(duration time.Duration) *ExpectedCopyFrom { - e.delay = duration - return e } // String returns string representation @@ -441,16 +372,13 @@ func (e *ExpectedCopyFrom) String() string { msg += fmt.Sprintf("\n - matches column names: '%+v'", e.expectedColumns) if e.err != nil { - msg += fmt.Sprintf("\n - should return error: %s", e.err) + msg += fmt.Sprintf("\n - should returns error: %s", e.err) } return msg } -// WillReturnResult arranges for an expected Exec() to return a particular -// result, there is pgxmock.NewResult(lastInsertID int64, affectedRows int64) method -// to build a corresponding result. Or if actions needs to be tested against errors -// pgxmock.NewErrorResult(err error) to return a given error. +// WillReturnResult arranges for an expected CopyFrom() to return a number of rows affected func (e *ExpectedCopyFrom) WillReturnResult(result int64) *ExpectedCopyFrom { e.rowsAffected = result return e @@ -464,3 +392,18 @@ type ExpectedReset struct { func (e *ExpectedReset) String() string { return "ExpectedReset => expecting database Reset" } + +// ExpectedRollback is used to manage pgx.Tx.Rollback expectation +// returned by pgxmock.ExpectRollback. +type ExpectedRollback struct { + commonExpectation +} + +// String returns string representation +func (e *ExpectedRollback) String() string { + msg := "ExpectedRollback => expecting transaction Rollback" + if e.err != nil { + msg += fmt.Sprintf(", which should return error: %s", e.err) + } + return msg +} diff --git a/expectations_test.go b/expectations_test.go index 5fbec73..e8e9eac 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -9,54 +9,129 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" ) -func TestCopyFromBug(t *testing.T) { +var ctx = context.Background() + +func TestTimes(t *testing.T) { + t.Parallel() + mock, _ := NewConn() + a := assert.New(t) + mock.ExpectPing().Times(2) + err := mock.Ping(ctx) + a.NoError(err) + a.Error(mock.ExpectationsWereMet()) // must be two Ping() calls + err = mock.Ping(ctx) + a.NoError(err) + a.NoError(mock.ExpectationsWereMet()) +} + +func TestMaybe(t *testing.T) { + t.Parallel() + mock, _ := NewConn() + a := assert.New(t) + mock.ExpectPing().Maybe() + mock.ExpectBegin().Maybe() + mock.ExpectQuery("SET TIME ZONE 'Europe/Rome'").Maybe() //only if we're in Italy + cmdtag := pgconn.NewCommandTag("SELECT 1") + mock.ExpectExec("select").WillReturnResult(cmdtag) + mock.ExpectCommit().Maybe() + + res, err := mock.Exec(ctx, "select version()") + a.Equal(cmdtag, res) + a.NoError(err) + a.NoError(mock.ExpectationsWereMet()) +} + +func TestPanic(t *testing.T) { + t.Parallel() mock, _ := NewConn() + a := assert.New(t) defer func() { - err := mock.ExpectationsWereMet() - if err != nil { - t.Errorf("expectation were not met: %s", err) - } + a.NotNil(recover(), "The code did not panic") + a.NoError(mock.ExpectationsWereMet()) }() + ex := mock.ExpectPing() + ex.WillPanic("i'm tired") + fmt.Println(ex) + a.NoError(mock.Ping(ctx)) +} + +func TestCallModifier(t *testing.T) { + t.Parallel() + mock, _ := NewConn() + a := assert.New(t) + + mock.ExpectPing().WillDelayFor(time.Second).Maybe().Times(4) + + c, f := context.WithCancel(ctx) + f() + a.Error(mock.Ping(c), "should raise error for cancelled context") + + a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() call is optional + + a.NoError(mock.Ping(ctx)) + a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() was called actually +} + +func TestCopyFromBug(t *testing.T) { + mock, _ := NewConn() + a := assert.New(t) + mock.ExpectCopyFrom(pgx.Identifier{"foo"}, []string{"bar"}).WillReturnResult(1) var rows [][]any rows = append(rows, []any{"baz"}) - _, err := mock.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"bar"}, pgx.CopyFromRows(rows)) - if err != nil { - t.Errorf("unexpected error: %s", err) - } + r, err := mock.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"bar"}, pgx.CopyFromRows(rows)) + a.EqualValues(len(rows), r) + a.NoError(err) + a.NoError(mock.ExpectationsWereMet()) } func ExampleExpectedExec() { mock, _ := NewConn() - result := NewErrorResult(fmt.Errorf("some error")) - mock.ExpectExec("^INSERT (.+)").WillReturnResult(result) - res, _ := mock.Exec(context.Background(), "INSERT something") - s := res.String() - fmt.Println(s) - // Output: some error -} - -func TestUnmonitoredPing(t *testing.T) { - mock, _ := NewConn() - p := mock.ExpectPing() - if p != nil { - t.Error("ExpectPing should return nil since MonitorPingsOption = false ") - } + ex := mock.ExpectExec("^INSERT (.+)").WillReturnResult(NewResult("INSERT", 15)) + ex.WillDelayFor(time.Second).Maybe().Times(2) + + fmt.Print(ex) + res, _ := mock.Exec(ctx, "INSERT something") + fmt.Println(res) + ex.WithArgs(42) + fmt.Print(ex) + res, _ = mock.Exec(ctx, "INSERT something", 42) + fmt.Print(res) + // Output: + // ExpectedExec => expecting call to Exec(): + // - matches sql: '^INSERT (.+)' + // - is without arguments + // - returns result: INSERT 15 + // - delayed execution for: 1s + // - execution is optional + // - execution calls awaited: 2 + // INSERT 15 + // ExpectedExec => expecting call to Exec(): + // - matches sql: '^INSERT (.+)' + // - is with arguments: + // 0 - 42 + // - returns result: INSERT 15 + // - delayed execution for: 1s + // - execution is optional + // - execution calls awaited: 2 + // INSERT 15 } func TestUnexpectedPing(t *testing.T) { - mock, _ := NewConn(MonitorPingsOption(true)) - err := mock.Ping(context.Background()) + mock, _ := NewConn() + err := mock.Ping(ctx) if err == nil { t.Error("Ping should return error for unexpected call") } mock.ExpectExec("foo") - err = mock.Ping(context.Background()) + err = mock.Ping(ctx) if err == nil { t.Error("Ping should return error for unexpected call") } @@ -64,12 +139,12 @@ func TestUnexpectedPing(t *testing.T) { func TestUnexpectedPrepare(t *testing.T) { mock, _ := NewConn() - _, err := mock.Prepare(context.Background(), "foo", "bar") + _, err := mock.Prepare(ctx, "foo", "bar") if err == nil { t.Error("Prepare should return error for unexpected call") } mock.ExpectExec("foo") - _, err = mock.Prepare(context.Background(), "foo", "bar") + _, err = mock.Prepare(ctx, "foo", "bar") if err == nil { t.Error("Prepare should return error for unexpected call") } @@ -77,19 +152,20 @@ func TestUnexpectedPrepare(t *testing.T) { func TestUnexpectedCopyFrom(t *testing.T) { mock, _ := NewConn() - _, err := mock.CopyFrom(context.Background(), pgx.Identifier{"schema", "table"}, []string{"foo", "bar"}, nil) + _, err := mock.CopyFrom(ctx, pgx.Identifier{"schema", "table"}, []string{"foo", "bar"}, nil) if err == nil { t.Error("CopyFrom should return error for unexpected call") } mock.ExpectExec("foo") - _, err = mock.CopyFrom(context.Background(), pgx.Identifier{"schema", "table"}, []string{"foo", "bar"}, nil) + _, err = mock.CopyFrom(ctx, pgx.Identifier{"schema", "table"}, []string{"foo", "bar"}, nil) if err == nil { t.Error("CopyFrom should return error for unexpected call") } } func TestBuildQuery(t *testing.T) { - mock, _ := NewConn(MonitorPingsOption(true)) + mock, _ := NewConn() + a := assert.New(t) query := ` SELECT name, @@ -105,18 +181,19 @@ func TestBuildQuery(t *testing.T) { ` mock.ExpectPing().WillDelayFor(1 * time.Second).WillReturnError(errors.New("no ping please")) - mock.ExpectQuery(query) - mock.ExpectExec(query) + mock.ExpectQuery(query).WillReturnError(errors.New("oops")) + mock.ExpectExec(query).WillReturnResult(NewResult("SELECT", 1)) mock.ExpectPrepare("foo", query) - _ = mock.Ping(context.Background()) - mock.QueryRow(context.Background(), query) - _, _ = mock.Exec(context.Background(), query) - _, _ = mock.Prepare(context.Background(), "foo", query) + err := mock.Ping(ctx) + a.Error(err) + mock.QueryRow(ctx, query) + _, err = mock.Exec(ctx, query) + a.NoError(err) + _, err = mock.Prepare(ctx, "foo", query) + a.NoError(err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Error(err) - } + a.NoError(mock.ExpectationsWereMet()) } func TestQueryRowScan(t *testing.T) { @@ -138,7 +215,7 @@ func TestQueryRowScan(t *testing.T) { expectedIntValue := 2 expectedArrayValue := []string{"Three", "Four"} mock.ExpectQuery(query).WillReturnRows(mock.NewRows([]string{"One", "Two", "Three"}).AddRow(expectedStringValue, expectedIntValue, []string{"Three", "Four"})) - row := mock.QueryRow(context.Background(), query) + row := mock.QueryRow(ctx, query) var stringValue string var intValue int var arrayValue []string @@ -164,7 +241,7 @@ func TestMissingWithArgs(t *testing.T) { // No arguments expected mock.ExpectExec("INSERT something") // Receiving argument - _, err := mock.Exec(context.Background(), "INSERT something", "something") + _, err := mock.Exec(ctx, "INSERT something", "something") if err == nil { t.Error("arguments do not match error was expected") } diff --git a/go.mod b/go.mod index 996be9b..071a6e9 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,22 @@ -module github.com/pashagolub/pgxmock/v2 +module github.com/pashagolub/pgxmock/v3 -go 1.20 +go 1.21 -require github.com/jackc/pgx/v5 v5.4.3 +require ( + github.com/jackc/pgx/v5 v5.4.3 + github.com/stretchr/testify v1.8.1 +) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/sync v0.1.0 // indirect golang.org/x/text v0.9.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9f5dac9..c75d161 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= @@ -8,12 +10,22 @@ github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= @@ -21,5 +33,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/options.go b/options.go index c33b121..5b24fbb 100644 --- a/options.go +++ b/options.go @@ -18,19 +18,3 @@ func QueryMatcherOption(queryMatcher QueryMatcher) func(*pgxmock) error { return nil } } - -// MonitorPingsOption determines whether calls to Ping on the driver should be -// observed and mocked. -// -// If true is passed, we will check these calls were expected. Expectations can -// be registered using the ExpectPing() method on the mock. -// -// If false is passed or this option is omitted, calls to Ping will not be -// considered when determining expectations and calls to ExpectPing will have -// no effect. -func MonitorPingsOption(monitorPings bool) func(*pgxmock) error { - return func(s *pgxmock) error { - s.monitorPings = monitorPings - return nil - } -} diff --git a/pgxmock.go b/pgxmock.go index 733908c..6051301 100644 --- a/pgxmock.go +++ b/pgxmock.go @@ -12,11 +12,8 @@ package pgxmock import ( "context" - "errors" "fmt" - "log" "reflect" - "time" pgx "github.com/jackc/pgx/v5" pgconn "github.com/jackc/pgx/v5/pgconn" @@ -150,9 +147,7 @@ type PgxPoolIface interface { type pgxmock struct { ordered bool queryMatcher QueryMatcher - monitorPings bool - - expected []expectation + expectations []Expectation } func (c *pgxmock) Config() *pgxpool.Config { @@ -170,7 +165,7 @@ func (c *pgxmock) AcquireFunc(_ context.Context, _ func(*pgxpool.Conn) error) er // region Expectations func (c *pgxmock) ExpectClose() *ExpectedClose { e := &ExpectedClose{} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } @@ -179,9 +174,9 @@ func (c *pgxmock) MatchExpectationsInOrder(b bool) { } func (c *pgxmock) ExpectationsWereMet() error { - for _, e := range c.expected { + for _, e := range c.expectations { e.Lock() - fulfilled := e.fulfilled() + fulfilled := e.fulfilled() || !e.required() e.Unlock() if !fulfilled { @@ -190,7 +185,7 @@ func (c *pgxmock) ExpectationsWereMet() error { // for expected prepared statement check whether it was closed if expected if prep, ok := e.(*ExpectedPrepare); ok { - if prep.mustBeClosed && !prep.wasClosed { + if prep.mustBeClosed && !prep.deallocated { return fmt.Errorf("expected prepared statement to be closed, but it was not: %s", prep) } } @@ -208,70 +203,63 @@ func (c *pgxmock) ExpectationsWereMet() error { func (c *pgxmock) ExpectQuery(expectedSQL string) *ExpectedQuery { e := &ExpectedQuery{} e.expectSQL = expectedSQL - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectCommit() *ExpectedCommit { e := &ExpectedCommit{} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectRollback() *ExpectedRollback { e := &ExpectedRollback{} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectBegin() *ExpectedBegin { e := &ExpectedBegin{} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin { e := &ExpectedBegin{opts: txOptions} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec { e := &ExpectedExec{} e.expectSQL = expectedSQL - // e.converter = c.converter - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectCopyFrom(expectedTableName pgx.Identifier, expectedColumns []string) *ExpectedCopyFrom { - e := &ExpectedCopyFrom{} - e.expectedTableName = expectedTableName - e.expectedColumns = expectedColumns - c.expected = append(c.expected, e) + e := &ExpectedCopyFrom{expectedTableName: expectedTableName, expectedColumns: expectedColumns} + c.expectations = append(c.expectations, e) return e } // ExpectReset expects Reset to be called. func (c *pgxmock) ExpectReset() *ExpectedReset { e := &ExpectedReset{} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectPing() *ExpectedPing { - if !c.monitorPings { - log.Println("ExpectPing will have no effect as monitoring pings is disabled. Use MonitorPingsOption to enable.") - return nil - } e := &ExpectedPing{} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } func (c *pgxmock) ExpectPrepare(expectedStmtName, expectedSQL string) *ExpectedPrepare { e := &ExpectedPrepare{expectSQL: expectedSQL, expectStmtName: expectedStmtName, mock: c} - c.expected = append(c.expected, e) + c.expectations = append(c.expectations, e) return e } @@ -313,60 +301,23 @@ func (c *pgxmock) open(options []func(*pgxmock) error) error { return err } } - // if c.converter == nil { - // c.converter = driver.DefaultParameterConverter - // } + if c.queryMatcher == nil { c.queryMatcher = QueryMatcherRegexp } - if c.monitorPings { - // We call Ping on the driver shortly to verify startup assertions by - // driving internal behaviour of the sql standard library. We don't - // want this call to ping to be monitored for expectation purposes so - // temporarily disable. - c.monitorPings = false - defer func() { c.monitorPings = true }() - } - return c.Ping(context.TODO()) + return nil } // Close a mock database driver connection. It may or may not // be called depending on the circumstances, but if it is called // there must be an *ExpectedClose expectation satisfied. -func (c *pgxmock) close(context.Context) error { - var expected *ExpectedClose - var fulfilled int - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if expected, ok = next.(*ExpectedClose); ok { - break - } - - next.Unlock() - if c.ordered { - return fmt.Errorf("call to database Close, was not expected, next expectation is: %s", next) - } - } - - if expected == nil { - msg := "call to database Close was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg - } - return fmt.Errorf(msg) +func (c *pgxmock) Close(ctx context.Context) error { + ex, err := findExpectation[*ExpectedClose](c, "Close()") + if err != nil { + return err } - - expected.triggered = true - expected.Unlock() - return expected.err + return ex.waitForDelay(ctx) } func (c *pgxmock) Conn() *pgx.Conn { @@ -374,69 +325,19 @@ func (c *pgxmock) Conn() *pgx.Conn { } func (c *pgxmock) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, _ pgx.CopyFromSource) (int64, error) { - ex, err := c.copyFrom(tableName, columnNames) - if ex != nil { - select { - case <-time.After(ex.delay): - if err != nil { - return ex.rowsAffected, err - } - return ex.rowsAffected, nil - case <-ctx.Done(): - return -1, ErrCancelled - } - } - return -1, err -} - -func (c *pgxmock) copyFrom(tableName pgx.Identifier, columnNames []string) (*ExpectedCopyFrom, error) { - var expected *ExpectedCopyFrom - var fulfilled int - var ok bool - - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if c.ordered { - if expected, ok = next.(*ExpectedCopyFrom); ok { - break - } - - next.Unlock() - return nil, fmt.Errorf("call to CopyFrom statement with table name '%s', was not expected, next expectation is: %s", tableName, next) - } - - if pr, ok := next.(*ExpectedCopyFrom); ok { - if reflect.DeepEqual(pr.expectedTableName, tableName) && reflect.DeepEqual(pr.expectedColumns, columnNames) { - expected = pr - break - } + ex, err := findExpectationFunc[*ExpectedCopyFrom](c, "BeginTx()", func(copyExp *ExpectedCopyFrom) error { + if !reflect.DeepEqual(copyExp.expectedTableName, tableName) { + return fmt.Errorf("CopyFrom: table name '%s' was not expected, expected table name is '%s'", tableName, copyExp.expectedTableName) } - next.Unlock() - } - - if expected == nil { - msg := "call to CopyFrom table name '%s' was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg + if !reflect.DeepEqual(copyExp.expectedColumns, columnNames) { + return fmt.Errorf("CopyFrom: column names '%v' were not expected, expected column names are '%v'", columnNames, copyExp.expectedColumns) } - return nil, fmt.Errorf(msg, tableName) - } - defer expected.Unlock() - if !reflect.DeepEqual(expected.expectedTableName, tableName) { - return nil, fmt.Errorf("CopyFrom: table name '%s' was not expected, expected table name is '%s'", tableName, expected.expectedTableName) - } - if !reflect.DeepEqual(expected.expectedColumns, columnNames) { - return nil, fmt.Errorf("CopyFrom: column names '%v' were not expected, expected column names are '%v'", columnNames, expected.expectedColumns) + return nil + }) + if err != nil { + return -1, err } - - expected.triggered = true - return expected, expected.err + return ex.rowsAffected, ex.waitForDelay(ctx) } func (c *pgxmock) SendBatch(context.Context, *pgx.Batch) pgx.BatchResults { @@ -447,294 +348,100 @@ func (c *pgxmock) LargeObjects() pgx.LargeObjects { return pgx.LargeObjects{} } -func (c *pgxmock) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { - ex, err := c.begin(txOptions) - if ex != nil { - time.Sleep(ex.delay) - } - if err != nil { - return nil, err - } - - return c, ctx.Err() -} - func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) { return c.BeginTx(ctx, pgx.TxOptions{}) } -func (c *pgxmock) begin(txOptions pgx.TxOptions) (*ExpectedBegin, error) { - var expected *ExpectedBegin - var ok bool - var fulfilled int - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if expected, ok = next.(*ExpectedBegin); ok { - break - } - - next.Unlock() - if c.ordered { - return nil, fmt.Errorf("call to database transaction Begin, was not expected, next expectation is: %s", next) - } - } - if expected == nil { - msg := "call to database transaction Begin was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg +func (c *pgxmock) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { + ex, err := findExpectationFunc[*ExpectedBegin](c, "BeginTx()", func(beginExp *ExpectedBegin) error { + if beginExp.opts != txOptions { + return fmt.Errorf("BeginTx: call with transaction options '%v' was not expected: %s", txOptions, beginExp) } - return nil, fmt.Errorf(msg) - } - defer expected.Unlock() - if expected.opts != txOptions { - return nil, fmt.Errorf("Begin: call with transaction options '%v' was not expected, expected name is '%v'", txOptions, expected.opts) - } - expected.triggered = true - - return expected, expected.err -} - -func (c *pgxmock) Prepare(ctx context.Context, name, query string) (*pgconn.StatementDescription, error) { - ex, err := c.prepare(name, query) - if ex != nil { - time.Sleep(ex.delay) - } + return nil + }) if err != nil { return nil, err } - - return &pgconn.StatementDescription{Name: name, SQL: query}, ctx.Err() + if err = ex.waitForDelay(ctx); err != nil { + return nil, err + } + return c, nil } -func (c *pgxmock) prepare(name string, query string) (*ExpectedPrepare, error) { - var expected *ExpectedPrepare - var fulfilled int - var ok bool - - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if c.ordered { - if expected, ok = next.(*ExpectedPrepare); ok { - break - } - - next.Unlock() - return nil, fmt.Errorf("call to Prepare statement with query '%s', was not expected, next expectation is: %s", query, next) - } - - if pr, ok := next.(*ExpectedPrepare); ok { - if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil { - expected = pr - break - } +func (c *pgxmock) Prepare(ctx context.Context, name, query string) (*pgconn.StatementDescription, error) { + ex, err := findExpectationFunc[*ExpectedPrepare](c, "Exec()", func(prepareExp *ExpectedPrepare) error { + if err := c.queryMatcher.Match(prepareExp.expectSQL, query); err != nil { + return err } - next.Unlock() - } - - if expected == nil { - msg := "call to Prepare '%s' query was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg + if prepareExp.expectStmtName != name { + return fmt.Errorf("Prepare: prepared statement name '%s' was not expected, expected name is '%s'", name, prepareExp.expectStmtName) } - return nil, fmt.Errorf(msg, query) - } - defer expected.Unlock() - if expected.expectStmtName != name { - return nil, fmt.Errorf("Prepare: prepared statement name '%s' was not expected, expected name is '%s'", name, expected.expectStmtName) + return nil + }) + if err != nil { + return nil, err } - if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { - return nil, fmt.Errorf("Prepare: %v", err) + if err = ex.waitForDelay(ctx); err != nil { + return nil, err } - - expected.triggered = true - return expected, expected.err + return &pgconn.StatementDescription{Name: name, SQL: query}, nil } func (c *pgxmock) Deallocate(ctx context.Context, name string) error { - var expected *ExpectedPrepare - for _, next := range c.expected { + var ( + expected *ExpectedPrepare + ok bool + ) + for _, next := range c.expectations { next.Lock() - if pr, ok := next.(*ExpectedPrepare); ok && pr.expectStmtName == name { - expected = pr - next.Unlock() + expected, ok = next.(*ExpectedPrepare) + ok = ok && expected.expectStmtName == name + next.Unlock() + if ok { break } - next.Unlock() } if expected == nil { return fmt.Errorf("Deallocate: prepared statement name '%s' doesn't exist", name) } - expected.wasClosed = true - return ctx.Err() + expected.deallocated = true + return expected.waitForDelay(ctx) } func (c *pgxmock) Commit(ctx context.Context) error { - var expected *ExpectedCommit - var fulfilled int - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if expected, ok = next.(*ExpectedCommit); ok { - break - } - - next.Unlock() - if c.ordered { - return fmt.Errorf("call to Commit transaction, was not expected, next expectation is: %s", next) - } - } - if expected == nil { - msg := "call to Commit transaction was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg - } - return fmt.Errorf(msg) - } - - expected.triggered = true - expected.Unlock() - if expected.err != nil { - return expected.err + ex, err := findExpectation[*ExpectedCommit](c, "Commit()") + if err != nil { + return err } - return ctx.Err() + return ex.waitForDelay(ctx) } func (c *pgxmock) Rollback(ctx context.Context) error { - var expected *ExpectedRollback - var fulfilled int - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if expected, ok = next.(*ExpectedRollback); ok { - break - } - - next.Unlock() - if c.ordered { - return fmt.Errorf("call to Rollback transaction, was not expected, next expectation is: %s", next) - } - } - if expected == nil { - msg := "call to Rollback transaction was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg - } - return fmt.Errorf(msg) - } - - expected.triggered = true - expected.Unlock() - if expected.err != nil { - return expected.err + ex, err := findExpectation[*ExpectedRollback](c, "Rollback()") + if err != nil { + return err } - return ctx.Err() + return ex.waitForDelay(ctx) } -// ErrCancelled defines an error value, which can be expected in case of -// such cancellation error. -var ErrCancelled = errors.New("canceling query due to user request") - // Implement the "QueryerContext" interface func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { - ex, err := c.query(sql, args) - if ex != nil { - select { - case <-time.After(ex.delay): - if err != nil { - return nil, err - } - return ex.rows, nil - case <-ctx.Done(): - return nil, ErrCancelled - } - } - - return nil, err -} - -func (c *pgxmock) query(query string, args []interface{}) (*ExpectedQuery, error) { - var expected *ExpectedQuery - var fulfilled int - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if c.ordered { - if expected, ok = next.(*ExpectedQuery); ok { - break - } - next.Unlock() - return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + ex, err := findExpectationFunc[*ExpectedQuery](c, "Query()", func(queryExp *ExpectedQuery) error { + if err := c.queryMatcher.Match(queryExp.expectSQL, sql); err != nil { + return err } - if qr, ok := next.(*ExpectedQuery); ok { - if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { - next.Unlock() - continue - } - if err := qr.attemptArgMatch(args); err == nil { - expected = qr - break - } + if err := queryExp.argsMatches(args); err != nil { + return err } - next.Unlock() - } - - if expected == nil { - msg := "call to Query '%s' with args %+v was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg + if queryExp.err == nil && queryExp.rows == nil { + return fmt.Errorf("Query must return a result rows or raise an error: %v", queryExp) } - return nil, fmt.Errorf(msg, query, args) - } - - defer expected.Unlock() - - if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { - return nil, fmt.Errorf("Query: %v", err) - } - - if err := expected.argsMatches(args); err != nil { - return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) - } - - expected.triggered = true - if expected.err != nil { - return expected, expected.err // mocked to return error - } - - if expected.rows == nil { - return nil, fmt.Errorf("Query '%s' with args %+v, must return a pgx.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) + return nil + }) + if err != nil { + return nil, err } - return expected, nil + return ex.rows, ex.waitForDelay(ctx) } type errRow struct { @@ -746,123 +453,60 @@ func (er errRow) Scan(...interface{}) error { } func (c *pgxmock) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { - ex, err := c.query(sql, args) - if ex != nil { - select { - case <-time.After(ex.delay): - if (err != nil) || (ex.rows == nil) { - return errRow{err} - } - _ = ex.rows.Next() - return ex.rows - case <-ctx.Done(): - return errRow{ctx.Err()} - } + rows, err := c.Query(ctx, sql, args...) + if err != nil { + return errRow{err} } - return errRow{err} + _ = rows.Next() + return rows } -// Implement the "ExecerContext" interface func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) (pgconn.CommandTag, error) { - ex, err := c.exec(query, args) - if ex != nil { - select { - case <-time.After(ex.delay): - if err != nil { - return pgconn.NewCommandTag(""), err - } - return ex.result, nil - case <-ctx.Done(): - return pgconn.NewCommandTag(""), ErrCancelled - } - } - return pgconn.NewCommandTag(""), err -} - -func (c *pgxmock) exec(query string, args []interface{}) (*ExpectedExec, error) { - var expected *ExpectedExec - var fulfilled int - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if c.ordered { - if expected, ok = next.(*ExpectedExec); ok { - break - } - next.Unlock() - return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + ex, err := findExpectationFunc[*ExpectedExec](c, "Exec()", func(execExp *ExpectedExec) error { + if err := c.queryMatcher.Match(execExp.expectSQL, query); err != nil { + return err } - if exec, ok := next.(*ExpectedExec); ok { - if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { - next.Unlock() - continue - } - - if err := exec.attemptArgMatch(args); err == nil { - expected = exec - break - } + if err := execExp.argsMatches(args); err != nil { + return err } - next.Unlock() - } - if expected == nil { - msg := "call to ExecQuery '%s' with args %+v was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg + if execExp.result.String() == "" && execExp.err == nil { + return fmt.Errorf("Exec must return a result or raise an error: %s", execExp) } - return nil, fmt.Errorf(msg, query, args) - } - defer expected.Unlock() - - if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { - return nil, fmt.Errorf("ExecQuery: %v", err) - } - - if err := expected.argsMatches(args); err != nil { - return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) - } - - expected.triggered = true - if expected.err != nil { - return expected, expected.err // mocked to return error - } - - if expected.result.String() == "" { - return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a pgconn.CommandTag, but it was not set for expectation %T as %+v", query, args, expected, expected) + return nil + }) + if err != nil { + return pgconn.NewCommandTag(""), err } - - return expected, nil + return ex.result, ex.waitForDelay(ctx) } -// Implement the "Pinger" interface - the explicit DB driver ping was only added to database/sql in Go 1.8 -func (c *pgxmock) Ping(ctx context.Context) error { - if !c.monitorPings { - return nil +func (c *pgxmock) Ping(ctx context.Context) (err error) { + ex, err := findExpectation[*ExpectedPing](c, "Ping()") + if err != nil { + return err } + return ex.waitForDelay(ctx) +} - ex, err := c.ping() - if ex != nil { - select { - case <-ctx.Done(): - return ErrCancelled - case <-time.After(ex.delay): - } +func (c *pgxmock) Reset() { + ex, err := findExpectation[*ExpectedReset](c, "Reset()") + if err != nil { + return } + _ = ex.waitForDelay(context.Background()) +} - return err +type ExpectationType[t any] interface { + *t + Expectation } -func (c *pgxmock) ping() (*ExpectedPing, error) { - var expected *ExpectedPing +func findExpectationFunc[ET ExpectationType[t], t any](c *pgxmock, method string, cmp func(ET) error) (ET, error) { + var expected ET var fulfilled int var ok bool - for _, next := range c.expected { + var err error + for _, next := range c.expectations { next.Lock() if next.fulfilled() { next.Unlock() @@ -870,47 +514,39 @@ func (c *pgxmock) ping() (*ExpectedPing, error) { continue } - if expected, ok = next.(*ExpectedPing); ok { - break + if expected, ok = next.(ET); ok { + err = cmp(expected) + if err == nil { + break + } } - - next.Unlock() if c.ordered { - return nil, fmt.Errorf("call to database Ping, was not expected, next expectation is: %s", next) + if (!ok || err != nil) && !next.required() { + next.Unlock() + continue + } + next.Unlock() + if err != nil { + return nil, err + } + return nil, fmt.Errorf("call to method %s, was not expected, next expectation is: %s", method, next) } + next.Unlock() } if expected == nil { - msg := "call to database Ping was not expected" - if fulfilled == len(c.expected) { + msg := fmt.Sprintf("call to method %s was not expected", method) + if fulfilled == len(c.expectations) { msg = "all expectations were already fulfilled, " + msg } return nil, fmt.Errorf(msg) } + defer expected.Unlock() - expected.triggered = true - expected.Unlock() - return expected, expected.err + expected.fulfill() + return expected, nil } -func (c *pgxmock) Reset() { - var expected *ExpectedReset - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - continue - } - - if expected, ok = next.(*ExpectedReset); ok { - break - } - next.Unlock() - } - if expected == nil { - return - } - expected.triggered = true - expected.Unlock() +func findExpectation[ET ExpectationType[t], t any](c *pgxmock, method string) (ET, error) { + return findExpectationFunc[ET, t](c, method, func(_ ET) error { return nil }) } diff --git a/pgxmock_test.go b/pgxmock_test.go index 7ebcf8c..3ab286c 100644 --- a/pgxmock_test.go +++ b/pgxmock_test.go @@ -10,6 +10,8 @@ import ( "time" pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/assert" ) func cancelOrder(db pgxIface, orderID int) error { @@ -109,43 +111,32 @@ func TestMockQuery(t *testing.T) { func TestMockCopyFrom(t *testing.T) { t.Parallel() - mock, err := NewConn() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close(context.Background()) - + mock, _ := NewConn() + a := assert.New(t) mock.ExpectCopyFrom(pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}). WillReturnResult(2).WillDelayFor(1 * time.Second) - _, err = mock.CopyFrom(context.Background(), pgx.Identifier{"error", "error"}, []string{"error"}, nil) - if err == nil { - t.Error("error is expected while executing CopyFrom") - } - if mock.ExpectationsWereMet() == nil { - t.Error("there must be unfulfilled expectations") - } + res, err := mock.CopyFrom(context.Background(), pgx.Identifier{"error", "error"}, []string{"error"}, nil) + a.Error(err, "incorrect table should raise an error") + a.EqualValues(res, -1) + a.Error(mock.ExpectationsWereMet(), "there must be unfulfilled expectations") - rows, err := mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}, nil) - if err != nil { - t.Errorf("error '%s' was not expected while executing CopyFrom", err) - } + res, err = mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"error"}, nil) + a.Error(err, "incorrect columns should raise an error") + a.EqualValues(res, -1) + a.Error(mock.ExpectationsWereMet(), "there must be unfulfilled expectations") - if rows != 2 { - t.Errorf("expected RowsAffected to be 2, but got %d instead", rows) - } + res, err = mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}, nil) + a.NoError(err) + a.EqualValues(res, 2) mock.ExpectCopyFrom(pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}). WillReturnError(errors.New("error is here")) _, err = mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}, nil) - if err == nil { - t.Error("error is expected while executing CopyFrom") - } + a.Error(err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } + a.NoError(mock.ExpectationsWereMet()) } func TestMockQueryTypes(t *testing.T) { @@ -203,86 +194,64 @@ func TestMockQueryTypes(t *testing.T) { func TestTransactionExpectations(t *testing.T) { t.Parallel() - mock, err := NewConn() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close(context.Background()) + mock, _ := NewConn() + a := assert.New(t) // begin and commit mock.ExpectBegin() mock.ExpectCommit() - tx, err := mock.Begin(context.Background()) - if err != nil { - t.Errorf("an error '%s' was not expected when beginning a transaction", err) - } - - err = tx.Commit(context.Background()) - if err != nil { - t.Errorf("an error '%s' was not expected when committing a transaction", err) - } + tx, err := mock.Begin(ctx) + a.NoError(err) + err = tx.Commit(ctx) + a.NoError(err) // beginTx and commit - mock.ExpectBeginTx(pgx.TxOptions{}) + mock.ExpectBeginTx(pgx.TxOptions{AccessMode: pgx.ReadOnly}) mock.ExpectCommit() - tx, err = mock.BeginTx(context.Background(), pgx.TxOptions{}) - if err != nil { - t.Errorf("an error '%s' was not expected when beginning a transaction", err) - } + _, err = mock.BeginTx(ctx, pgx.TxOptions{}) + a.Error(err, "wrong tx access mode should raise error") - err = tx.Commit(context.Background()) - if err != nil { - t.Errorf("an error '%s' was not expected when committing a transaction", err) - } + tx, err = mock.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) + a.NoError(err) + err = tx.Commit(ctx) + a.NoError(err) // begin and rollback mock.ExpectBegin() mock.ExpectRollback() - tx, err = mock.Begin(context.Background()) - if err != nil { - t.Errorf("an error '%s' was not expected when beginning a transaction", err) - } - - err = tx.Rollback(context.Background()) - if err != nil { - t.Errorf("an error '%s' was not expected when rolling back a transaction", err) - } + tx, err = mock.Begin(ctx) + a.NoError(err) + err = tx.Rollback(ctx) + a.NoError(err) // begin with an error - mock.ExpectBegin().WillReturnError(fmt.Errorf("some err")) + mock.ExpectBegin().WillReturnError(errors.New("some err")) - _, err = mock.Begin(context.Background()) - if err == nil { - t.Error("an error was expected when beginning a transaction, but got none") - } + _, err = mock.Begin(ctx) + a.Error(err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } + a.NoError(mock.ExpectationsWereMet()) } func TestPrepareExpectations(t *testing.T) { t.Parallel() - mock, err := NewConn() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close(context.Background()) + mock, _ := NewConn() + a := assert.New(t) mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?"). - WillDelayFor(1 * time.Second). - WillReturnCloseError(errors.New("invaders must die")) + WillReturnCloseError(errors.New("invaders must die")). + WillDelayFor(1 * time.Second) - stmt, err := mock.Prepare(context.Background(), "foo", "SELECT (.+) FROM articles WHERE id = $1") - if err != nil { - t.Errorf("error '%s' was not expected while creating a prepared statement", err) - } - if stmt == nil { - t.Errorf("stmt was expected while creating a prepared statement") - } + stmt, err := mock.Prepare(context.Background(), "baz", "SELECT (.+) FROM articles WHERE id = ?") + a.Error(err, "wrong prepare stmt name should raise an error") + a.Nil(stmt) + + stmt, err = mock.Prepare(context.Background(), "foo", "SELECT (.+) FROM articles WHERE id = $1") + a.NoError(err) + a.NotNil(stmt) // expect something else, w/o ExpectPrepare() var id int @@ -294,24 +263,15 @@ func TestPrepareExpectations(t *testing.T) { WillReturnRows(rs) err = mock.QueryRow(context.Background(), "foo", 5).Scan(&id, &title) - if err != nil { - t.Errorf("error '%s' was not expected while retrieving mock rows", err) - } + a.NoError(err) mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?"). WillReturnError(fmt.Errorf("Some DB error occurred")) stmt, err = mock.Prepare(context.Background(), "foo", "SELECT id FROM articles WHERE id = $1") - if err == nil { - t.Error("error was expected while creating a prepared statement") - } - if stmt != nil { - t.Errorf("stmt was not expected while creating a prepared statement returning error") - } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } + a.Error(err) + a.Nil(stmt) + a.NoError(mock.ExpectationsWereMet()) } func TestPreparedQueryExecutions(t *testing.T) { @@ -1034,6 +994,7 @@ func TestExpectedCloseOrder(t *testing.T) { } defer mock.Close(context.Background()) mock.ExpectClose().WillReturnError(fmt.Errorf("Close failed")) + t.Log() _, _ = mock.Begin(context.Background()) if err := mock.ExpectationsWereMet(); err == nil { t.Error("expected error on ExpectationsWereMet") @@ -1054,29 +1015,25 @@ func TestExpectedBeginOrder(t *testing.T) { } func TestPreparedStatementCloseExpectation(t *testing.T) { - // Open new mock database - mock, err := NewConn() - if err != nil { - fmt.Println("error creating mock database") - return - } - defer mock.Close(context.Background()) + t.Parallel() + mock, _ := NewConn() + a := assert.New(t) - ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS").WillBeClosed() + ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS").WillBeDeallocated() ep.ExpectExec().WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1)) - _, err = mock.Prepare(context.Background(), "foo", "INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") - if err != nil { - t.Fatal(err) - } + stmt, err := mock.Prepare(context.Background(), "foo", "INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") + a.NoError(err) + a.NotNil(stmt) - if _, err := mock.Exec(context.Background(), "foo", 1, "Hello"); err != nil { - t.Fatal(err) - } + _, err = mock.Exec(context.Background(), "foo", 1, "Hello") + a.NoError(err) - if err := mock.Deallocate(context.Background(), "foo"); err != nil { - t.Fatal(err) - } + err = mock.Deallocate(context.Background(), "baz") + a.Error(err, "wrong prepares stmt name should raise an error") + + err = mock.Deallocate(context.Background(), "foo") + a.NoError(err) if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) @@ -1094,8 +1051,8 @@ func TestExecExpectationErrorDelay(t *testing.T) { // test that return of error is delayed delay := time.Millisecond * 100 mock.ExpectExec("^INSERT INTO articles").WithArgs(AnyArg()). - WillReturnError(errors.New("slow fail")). - WillDelayFor(delay) + WillDelayFor(delay). + WillReturnError(errors.New("slow fail")) start := time.Now() res, err := mock.Exec(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") @@ -1172,9 +1129,9 @@ func TestQueryWithTimeout(t *testing.T) { rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). - WillDelayFor(50 * time.Millisecond). // Query will take longer than timeout WithArgs(5). - WillReturnRows(rs) + WillReturnRows(rs). + WillDelayFor(50 * time.Millisecond) // Query will take longer than timeout _, err = queryWithTimeout(10*time.Millisecond, mock, "SELECT (.+) FROM articles WHERE id = ?", 5) if err == nil { @@ -1209,71 +1166,33 @@ func queryWithTimeout(t time.Duration, db pgxIface, query string, args ...interf } } -func TestCon(t *testing.T) { - mock, err := NewConn() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close(context.Background()) - defer func() { - if r := recover(); r == nil { - t.Errorf("The Conn() did not panic") - } - }() - _ = mock.Conn() -} - -func TestConnInfo(t *testing.T) { - mock, err := NewConn() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close(context.Background()) - - _ = mock.Config() -} - -func TestPgConn(t *testing.T) { - mock, err := NewConn() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close(context.Background()) - - _ = mock.PgConn() +func TestUnmockedMethods(t *testing.T) { + mock, _ := NewPool() + a := assert.New(t) + a.NotNil(mock.Config()) + a.NotNil(mock.PgConn()) + a.NotNil(mock.AcquireAllIdle(ctx)) + a.Nil(mock.AcquireFunc(ctx, func(*pgxpool.Conn) error { return nil })) + a.Nil(mock.SendBatch(ctx, nil)) + a.Zero(mock.LargeObjects()) + a.Panics(func() { _ = mock.Conn() }) } func TestNewRowsWithColumnDefinition(t *testing.T) { - mock, err := NewConn() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close(context.Background()) + mock, _ := NewConn() r := mock.NewRowsWithColumnDefinition(*mock.NewColumn("foo")) - if len(r.defs) != 1 { - t.Error("NewRows failed") - } + assert.Equal(t, 1, len(r.defs)) } func TestExpectReset(t *testing.T) { - mock, err := NewPool() - if err != nil { - t.Errorf("an error '%s' was not expected when opening a stub database connection", err) - } - defer mock.Close() - + mock, _ := NewPool() + a := assert.New(t) // Successful scenario - _ = mock.ExpectReset() + mock.ExpectReset() mock.Reset() - err = mock.ExpectationsWereMet() - if err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } + a.NoError(mock.ExpectationsWereMet()) // Unsuccessful scenario mock.ExpectReset() - err = mock.ExpectationsWereMet() - if err == nil { - t.Error("was expecting an error, but there was none") - } + a.Error(mock.ExpectationsWereMet()) } diff --git a/query_test.go b/query_test.go index bc0718f..9b7bba1 100644 --- a/query_test.go +++ b/query_test.go @@ -9,7 +9,7 @@ import ( func ExampleQueryMatcher() { // configure to use case sensitive SQL query matcher // instead of default regular expression matcher - mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual), MonitorPingsOption(true)) + mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual)) if err != nil { fmt.Println("failed to open pgxmock database:", err) } diff --git a/result.go b/result.go index 94921f3..e2f0dcb 100644 --- a/result.go +++ b/result.go @@ -6,14 +6,8 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) -// NewResult creates a new sql driver Result +// NewResult creates a new pgconn.CommandTag result // for Exec based query mocks. func NewResult(op string, rowsAffected int64) pgconn.CommandTag { - return pgconn.NewCommandTag(fmt.Sprint(op, rowsAffected)) -} - -// NewErrorResult creates a new sql driver Result -// which returns an error given for both interface methods -func NewErrorResult(err error) pgconn.CommandTag { - return pgconn.NewCommandTag(err.Error()) + return pgconn.NewCommandTag(fmt.Sprintf("%s %d", op, rowsAffected)) } diff --git a/rows.go b/rows.go index 8997ad6..6483bcb 100644 --- a/rows.go +++ b/rows.go @@ -23,9 +23,9 @@ var CSVColumnParser = func(s string) interface{} { } type rowSets struct { - sets []*Rows - pos int - ex *ExpectedQuery + sets []*Rows + RowSetNo int + ex *ExpectedQuery } func (rs *rowSets) Conn() *pgx.Conn { @@ -33,16 +33,16 @@ func (rs *rowSets) Conn() *pgx.Conn { } func (rs *rowSets) Err() error { - r := rs.sets[rs.pos] - return r.nextErr[r.pos-1] + r := rs.sets[rs.RowSetNo] + return r.nextErr[r.recNo-1] } func (rs *rowSets) CommandTag() pgconn.CommandTag { - return rs.sets[rs.pos].commandTag + return rs.sets[rs.RowSetNo].commandTag } func (rs *rowSets) FieldDescriptions() []pgconn.FieldDescription { - return rs.sets[rs.pos].defs + return rs.sets[rs.RowSetNo].defs } // func (rs *rowSets) Columns() []string { @@ -56,21 +56,21 @@ func (rs *rowSets) Close() { // advances to next row func (rs *rowSets) Next() bool { - r := rs.sets[rs.pos] - r.pos++ - return r.pos <= len(r.rows) + r := rs.sets[rs.RowSetNo] + r.recNo++ + return r.recNo <= len(r.rows) } // Values returns the decoded row values. As with Scan(), it is an error to // call Values without first calling Next() and checking that it returned // true. func (rs *rowSets) Values() ([]interface{}, error) { - r := rs.sets[rs.pos] - return r.rows[r.pos-1], r.nextErr[r.pos-1] + r := rs.sets[rs.RowSetNo] + return r.rows[r.recNo-1], r.nextErr[r.recNo-1] } func (rs *rowSets) Scan(dest ...interface{}) error { - r := rs.sets[rs.pos] + r := rs.sets[rs.RowSetNo] if len(dest) == 1 { if rc, ok := dest[0].(pgx.RowScanner); ok { return rc.ScanRow(rs) @@ -82,7 +82,7 @@ func (rs *rowSets) Scan(dest ...interface{}) error { if len(r.rows) == 0 { return pgx.ErrNoRows } - for i, col := range r.rows[r.pos-1] { + for i, col := range r.rows[r.recNo-1] { if dest[i] == nil { //behave compatible with pgx continue @@ -116,14 +116,14 @@ func (rs *rowSets) Scan(dest ...interface{}) error { } } - return r.nextErr[r.pos-1] + return r.nextErr[r.recNo-1] } func (rs *rowSets) RawValues() [][]byte { - r := rs.sets[rs.pos] + r := rs.sets[rs.RowSetNo] dest := make([][]byte, len(r.defs)) - for i, col := range r.rows[r.pos-1] { + for i, col := range r.rows[r.recNo-1] { if b, ok := rawBytes(col); ok { dest[i] = b continue @@ -137,23 +137,23 @@ func (rs *rowSets) RawValues() [][]byte { // transforms to debuggable printable string func (rs *rowSets) String() string { if rs.empty() { - return "with empty rows" + return "\t- returns no data" } - msg := "should return rows:\n" + msg := "\t- returns data:\n" if len(rs.sets) == 1 { for n, row := range rs.sets[0].rows { - msg += fmt.Sprintf(" row %d - %+v\n", n, row) + msg += fmt.Sprintf("\t\trow %d - %+v\n", n, row) } - return strings.TrimSpace(msg) + return msg } for i, set := range rs.sets { - msg += fmt.Sprintf(" result set: %d\n", i) + msg += fmt.Sprintf("\t\tresult set: %d\n", i) for n, row := range set.rows { - msg += fmt.Sprintf(" row %d - %+v\n", n, row) + msg += fmt.Sprintf("\t\t\trow %d: %+v\n", n, row) } } - return strings.TrimSpace(msg) + return msg } func (rs *rowSets) empty() bool { @@ -182,7 +182,7 @@ type Rows struct { commandTag pgconn.CommandTag defs []pgconn.FieldDescription rows [][]interface{} - pos int + recNo int nextErr map[int]error closeErr error } diff --git a/rows_test.go b/rows_test.go index 61a83a0..0d8cb15 100644 --- a/rows_test.go +++ b/rows_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" ) func TestPointerToInterfaceArgument(t *testing.T) { @@ -225,15 +226,16 @@ func ExampleRows_expectToBeClosed() { fmt.Println("got error:", err) } - // Output: got error: expected query rows to be closed, but it was not: ExpectedQuery => expecting Query, QueryContext or QueryRow which: - // - matches sql: 'SELECT' - // - is without arguments - // - should return rows: - // result set: 0 - // row 0 - [1 john] - // result set: 1 - // row 0 - [1 john] - // row 1 - [2 anna] + /*Output: got error: expected query rows to be closed, but it was not: ExpectedQuery => expecting call to Query() or to QueryRow(): + - matches sql: 'SELECT' + - is without arguments + - returns data: + result set: 0 + row 0: [1 john] + result set: 1 + row 0: [1 john] + row 1: [2 anna] + */ } func ExampleRows_customDriverValue() { @@ -436,70 +438,6 @@ func ExampleRows_rawValues() { // } -// func TestQueryRowBytesNotInvalidatedByNext_bytesIntoBytes(t *testing.T) { -// t.Parallel() -// rows := NewRows([]string{"raw"}). -// AddRow([]byte(`one binary value with some text!`)). -// AddRow([]byte(`two binary value with even more text than the first one`)) -// scan := func(rs *sql.Rows) ([]byte, error) { -// var b []byte -// return b, rs.Scan(&b) -// } -// want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} -// queryRowBytesNotInvalidatedByNext(t, rows, scan, want) -// } - -// func TestQueryRowBytesNotInvalidatedByNext_stringIntoBytes(t *testing.T) { -// t.Parallel() -// rows := NewRows([]string{"raw"}). -// AddRow(`one binary value with some text!`). -// AddRow(`two binary value with even more text than the first one`) -// scan := func(rs *sql.Rows) ([]byte, error) { -// var b []byte -// return b, rs.Scan(&b) -// } -// want := [][]byte{[]byte(`one binary value with some text!`), []byte(`two binary value with even more text than the first one`)} -// queryRowBytesNotInvalidatedByNext(t, rows, scan, want) -// } - -// func TestQueryRowBytesInvalidatedByClose_bytesIntoRawBytes(t *testing.T) { -// t.Parallel() -// replace := []byte(invalid) -// rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) -// scan := func(rs *sql.Rows) ([]byte, error) { -// var raw sql.RawBytes -// return raw, rs.Scan(&raw) -// } -// want := struct { -// Initial []byte -// Replaced []byte -// }{ -// Initial: []byte(`one binary value with some text!`), -// Replaced: replace[:len(replace)-7], -// } -// queryRowBytesInvalidatedByClose(t, rows, scan, want) -// } - -// func TestQueryRowBytesNotInvalidatedByClose_bytesIntoBytes(t *testing.T) { -// t.Parallel() -// rows := NewRows([]string{"raw"}).AddRow([]byte(`one binary value with some text!`)) -// scan := func(rs *sql.Rows) ([]byte, error) { -// var b []byte -// return b, rs.Scan(&b) -// } -// queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) -// } - -// func TestQueryRowBytesNotInvalidatedByClose_stringIntoBytes(t *testing.T) { -// t.Parallel() -// rows := NewRows([]string{"raw"}).AddRow(`one binary value with some text!`) -// scan := func(rs *sql.Rows) ([]byte, error) { -// var b []byte -// return b, rs.Scan(&b) -// } -// queryRowBytesNotInvalidatedByClose(t, rows, scan, []byte(`one binary value with some text!`)) -// } - func TestRowsScanError(t *testing.T) { t.Parallel() mock, err := NewConn() @@ -681,179 +619,6 @@ func TestEmptyRowSets(t *testing.T) { } } -// func queryRowBytesInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []struct { -// Initial []byte -// Replaced []byte -// }) { -// mock, err := New() -// if err != nil { -// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) -// } -// defer mock.Close(context.Background()) -// mock.ExpectQuery("SELECT").WillReturnRows(rows) - -// rs, err := mock.Query(context.Background(), "SELECT") -// if err != nil { -// t.Fatalf("failed to query rows: %s", err) -// } - -// if !rs.Next() || rs.Err() != nil { -// t.Fatal("unexpected error on first row retrieval") -// } -// var count int -// for i := 0; ; i++ { -// count++ -// b, err := scan(rs) -// if err != nil { -// t.Fatalf("unexpected error scanning row: %s", err) -// } -// if exp := want[i].Initial; !bytes.Equal(b, exp) { -// t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) -// } -// next := rs.Next() -// if exp := want[i].Replaced; !bytes.Equal(b, exp) { -// t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) -// } -// if !next { -// break -// } -// } -// if err := rs.Err(); err != nil { -// t.Fatalf("row iteration failed: %s", err) -// } -// if exp := len(want); count != exp { -// t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) -// } - -// if err := mock.ExpectationsWereMet(); err != nil { -// t.Fatal(err) -// } -// } - -// func queryRowBytesNotInvalidatedByNext(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want [][]byte) { -// mock, err := New() -// if err != nil { -// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) -// } -// defer mock.Close(context.Background()) -// mock.ExpectQuery("SELECT").WillReturnRows(rows) - -// rs, err := mock.Query(context.Background(), "SELECT") -// if err != nil { -// t.Fatalf("failed to query rows: %s", err) -// } - -// if !rs.Next() || rs.Err() != nil { -// t.Fatal("unexpected error on first row retrieval") -// } -// var count int -// for i := 0; ; i++ { -// count++ -// b, err := scan(rs) -// if err != nil { -// t.Fatalf("unexpected error scanning row: %s", err) -// } -// if exp := want[i]; !bytes.Equal(b, exp) { -// t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) -// } -// next := rs.Next() -// if exp := want[i]; !bytes.Equal(b, exp) { -// t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", exp, len(exp), b, b, len(b)) -// } -// if !next { -// break -// } -// } -// if err := rs.Err(); err != nil { -// t.Fatalf("row iteration failed: %s", err) -// } -// if exp := len(want); count != exp { -// t.Fatalf("incorrect number of rows exp: %d, but got %d", exp, count) -// } - -// if err := mock.ExpectationsWereMet(); err != nil { -// t.Fatal(err) -// } -// } - -// func queryRowBytesInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want struct { -// Initial []byte -// Replaced []byte -// }) { -// mock, err := New() -// if err != nil { -// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) -// } -// defer mock.Close(context.Background()) -// mock.ExpectQuery("SELECT").WillReturnRows(rows) - -// rs, err := mock.Query(context.Background(), "SELECT") -// if err != nil { -// t.Fatalf("failed to query rows: %s", err) -// } - -// if !rs.Next() || rs.Err() != nil { -// t.Fatal("unexpected error on first row retrieval") -// } -// b, err := scan(rs) -// if err != nil { -// t.Fatalf("unexpected error scanning row: %s", err) -// } -// if !bytes.Equal(b, want.Initial) { -// t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want.Initial, len(want.Initial), b, b, len(b)) -// } -// rs.Close() - -// if !bytes.Equal(b, want.Replaced) { -// t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want.Replaced, len(want.Replaced), b, b, len(b)) -// } -// if err := rs.Err(); err != nil { -// t.Fatalf("row iteration failed: %s", err) -// } - -// if err := mock.ExpectationsWereMet(); err != nil { -// t.Fatal(err) -// } -// } - -// func queryRowBytesNotInvalidatedByClose(t *testing.T, rows *Rows, scan func(*sql.Rows) ([]byte, error), want []byte) { -// mock, err := New() -// if err != nil { -// t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) -// } -// defer mock.Close(context.Background()) -// mock.ExpectQuery("SELECT").WillReturnRows(rows) - -// rs, err := mock.Query(context.Background(), "SELECT") -// if err != nil { -// t.Fatalf("failed to query rows: %s", err) -// } - -// if !rs.Next() || rs.Err() != nil { -// t.Fatal("unexpected error on first row retrieval") -// } -// b, err := scan(rs) -// if err != nil { -// t.Fatalf("unexpected error scanning row: %s", err) -// } -// if !bytes.Equal(b, want) { -// t.Fatalf("expected raw value to be '%s' (len:%d), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) -// } -// if err := rs.Close(); err != nil { -// t.Fatalf("unexpected error closing rows: %s", err) -// } -// if !bytes.Equal(b, want) { -// t.Fatalf("expected raw value to be replaced with '%s' (len:%d) after calling Next(), but got [%T]:%s (len:%d)", want, len(want), b, b, len(b)) -// } -// if err := rs.Err(); err != nil { -// t.Fatalf("row iteration failed: %s", err) -// } - -// if err := mock.ExpectationsWereMet(); err != nil { -// t.Fatal(err) -// } -// } - func TestMockQueryWithCollect(t *testing.T) { t.Parallel() mock, err := NewConn() @@ -878,10 +643,6 @@ func TestMockQueryWithCollect(t *testing.T) { defer rows.Close() - //if !rows.Next() { - // t.Error("it must have had one row as result, but got empty result set instead") - //} - rawMap, err := pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[rowStructType]) if err != nil { t.Errorf("error '%s' was not expected while trying to collect rows", err) @@ -906,3 +667,7 @@ func TestMockQueryWithCollect(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestRowsConn(t *testing.T) { + assert.Nil(t, (&rowSets{}).Conn()) +} diff --git a/sql_test.go b/sql_test.go index bf9d382..d91081c 100644 --- a/sql_test.go +++ b/sql_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - pgxmock "github.com/pashagolub/pgxmock/v2" + pgxmock "github.com/pashagolub/pgxmock/v3" ) func TestScanTime(t *testing.T) {