Skip to content

Commit

Permalink
[!] add support for pgx.Batch, closes #199 (#200)
Browse files Browse the repository at this point in the history
* [!] add support for `pgx.Batch`, closes #199
* [+] add awaited pgx `Fn()` functionality
* [+] bump jackc/pgx/v5 from 5.5.5 to 5.6.0
  • Loading branch information
pashagolub authored May 28, 2024
1 parent e0fce08 commit 3f620d4
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 8 deletions.
80 changes: 80 additions & 0 deletions batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package pgxmock

import (
"context"
"errors"

pgx "github.com/jackc/pgx/v5"
pgconn "github.com/jackc/pgx/v5/pgconn"
)

type batchResults struct {
mock *pgxmock
batch *pgx.Batch
expectedBatch *ExpectedBatch
qqIdx int
err error
}

func (br *batchResults) nextQueryAndArgs() (query string, args []any, err error) {
if br.err != nil {
return "", nil, br.err
}
if br.batch == nil {
return "", nil, errors.New("no batch expectations set")
}
if br.qqIdx >= len(br.batch.QueuedQueries) {
return "", nil, errors.New("no more queries in batch")
}
bi := br.batch.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
br.qqIdx++
return
}

func (br *batchResults) Exec() (pgconn.CommandTag, error) {
query, arguments, err := br.nextQueryAndArgs()
if err != nil {
return pgconn.NewCommandTag(""), err
}
return br.mock.Exec(context.Background(), query, arguments...)
}

func (br *batchResults) Query() (pgx.Rows, error) {
query, arguments, err := br.nextQueryAndArgs()
if err != nil {
return nil, err
}
return br.mock.Query(context.Background(), query, arguments...)
}

func (br *batchResults) QueryRow() pgx.Row {
query, arguments, err := br.nextQueryAndArgs()
if err != nil {
return errRow{err: err}
}
return br.mock.QueryRow(context.Background(), query, arguments...)
}

func (br *batchResults) Close() error {
if br.err != nil {
return br.err
}
// Read and run fn for all remaining items
for br.err == nil && br.expectedBatch != nil && !br.expectedBatch.closed && br.qqIdx < len(br.batch.QueuedQueries) {
if qq := br.batch.QueuedQueries[br.qqIdx]; qq != nil {
br.err = errors.Join(br.err, br.callQuedQueryFn(qq))
}
}
br.expectedBatch.closed = true
return br.err
}

func (br *batchResults) callQuedQueryFn(qq *pgx.QueuedQuery) error {
if qq.Fn != nil {
return qq.Fn(br)
}
_, err := br.Exec()
return err
}
85 changes: 85 additions & 0 deletions batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package pgxmock

import (
"errors"
"testing"

pgx "github.com/jackc/pgx/v5"
pgconn "github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/assert"
)

func TestBatch(t *testing.T) {
t.Parallel()
mock, _ := NewConn()
a := assert.New(t)

// define our expectations
eb := mock.ExpectBatch()
eb.ExpectQuery("select").WillReturnRows(NewRows([]string{"sum"}).AddRow(2))
eb.ExpectExec("update").WithArgs(true, 1).WillReturnResult(NewResult("UPDATE", 1))

// run the test
batch := &pgx.Batch{}
batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
var n int32
return row.Scan(&n)
})
batch.Queue("update users set active = $1 where id = $2", true, 1).Exec(func(ct pgconn.CommandTag) (err error) {
if ct.RowsAffected() != 1 {
err = errors.New("expected 1 row to be affected")
}
return
})

err := mock.SendBatch(ctx, batch).Close()
a.NoError(err)
a.NoError(mock.ExpectationsWereMet())
}

func TestExplicitBatch(t *testing.T) {
t.Parallel()
mock, _ := NewConn()
a := assert.New(t)

// define our expectations
eb := mock.ExpectBatch()
eb.ExpectQuery("select").WillReturnRows(NewRows([]string{"sum"}).AddRow(2))
eb.ExpectQuery("select").WillReturnRows(NewRows([]string{"answer"}).AddRow(42))
eb.ExpectExec("update").WithArgs(true, 1).WillReturnResult(NewResult("UPDATE", 1))

// run the test
batch := &pgx.Batch{}
batch.Queue("select 1 + 1")
batch.Queue("select 42")
batch.Queue("update users set active = $1 where id = $2", true, 1)

var sum int
br := mock.SendBatch(ctx, batch)
err := br.QueryRow().Scan(&sum)
a.NoError(err)
a.Equal(2, sum)

var answer int
rows, err := br.Query()
a.NoError(err)
rows.Next()
err = rows.Scan(&answer)
a.NoError(err)
a.Equal(42, answer)

ct, err := br.Exec()
a.NoError(err)
a.True(ct.Update())
a.EqualValues(1, ct.RowsAffected())

// no more queries
_, err = br.Exec()
a.Error(err)
_, err = br.Query()
a.Error(err)
err = br.QueryRow().Scan(&sum)
a.Error(err)

a.NoError(mock.ExpectationsWereMet())
}
37 changes: 37 additions & 0 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,43 @@ func (e *ExpectedExec) WillReturnResult(result pgconn.CommandTag) *ExpectedExec
return e
}

// ExpectedBatch is used to manage pgx.Batch expectations.
// Returned by pgxmock.ExpectBatch.
type ExpectedBatch struct {
commonExpectation
mock *pgxmock
expectedQueries []*queryBasedExpectation
closed bool
mustBeClosed bool
}

// ExpectExec allows to expect Queue().Exec() on this batch.
func (e *ExpectedBatch) ExpectExec(query string) *ExpectedExec {
ee := &ExpectedExec{}
ee.expectSQL = query
e.expectedQueries = append(e.expectedQueries, &ee.queryBasedExpectation)
e.mock.expectations = append(e.mock.expectations, ee)
return ee
}

// ExpectQuery allows to expect Queue().Query() or Queue().QueryRow() on this batch.
func (e *ExpectedBatch) ExpectQuery(query string) *ExpectedQuery {
eq := &ExpectedQuery{}
eq.expectSQL = query
e.expectedQueries = append(e.expectedQueries, &eq.queryBasedExpectation)
e.mock.expectations = append(e.mock.expectations, eq)
return eq
}

// String returns string representation
func (e *ExpectedBatch) String() string {
msg := "ExpectedBatch => expecting call to SendBatch()\n"
if e.mustBeClosed {
msg += "\t- batch must be closed\n"
}
return msg + e.commonExpectation.String()
}

// ExpectedPrepare is used to manage pgx.Prepare or pgx.Tx.Prepare expectations.
// Returned by pgxmock.ExpectPrepare.
type ExpectedPrepare struct {
Expand Down
10 changes: 5 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ module github.com/pashagolub/pgxmock/v3
go 1.21

require (
github.com/jackc/pgx/v5 v5.5.5
github.com/jackc/pgx/v5 v5.6.0
github.com/stretchr/testify v1.9.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/text v0.15.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
12 changes: 12 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/pgx/v5 v5.5.6-0.20240512140347-523411a3fbcb h1:lFE9u8joHPJgkA0qSn/nTHdB98KQDTvGgSJ+tGUVey8=
github.com/jackc/pgx/v5 v5.5.6-0.20240512140347-523411a3fbcb/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
Expand All @@ -25,10 +31,16 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
Expand Down
38 changes: 36 additions & 2 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ type Expecter interface {
// If any of them was not met - an error is returned.
ExpectationsWereMet() error

// ExpectBatch expects pgx.Batch to be called. The *ExpectedBatch
// allows to mock database response
ExpectBatch() *ExpectedBatch

// ExpectClose queues an expectation for this database
// action to be triggered. The *ExpectedClose allows
// to mock database response
Expand Down Expand Up @@ -149,6 +153,12 @@ func (c *pgxmock) AcquireFunc(_ context.Context, _ func(*pgxpool.Conn) error) er
}

// region Expectations
func (c *pgxmock) ExpectBatch() *ExpectedBatch {
e := &ExpectedBatch{mock: c}
c.expectations = append(c.expectations, e)
return e
}

func (c *pgxmock) ExpectClose() *ExpectedClose {
e := &ExpectedClose{}
c.expectations = append(c.expectations, e)
Expand Down Expand Up @@ -331,8 +341,32 @@ func (c *pgxmock) CopyFrom(ctx context.Context, tableName pgx.Identifier, column
return ex.rowsAffected, ex.waitForDelay(ctx)
}

func (c *pgxmock) SendBatch(context.Context, *pgx.Batch) pgx.BatchResults {
return nil
func (c *pgxmock) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
ex, err := findExpectationFunc[*ExpectedBatch](c, "Batch()", func(batchExp *ExpectedBatch) error {
if len(batchExp.expectedQueries) != len(b.QueuedQueries) {
return fmt.Errorf("SendBatch: number of queries in batch '%d' was not expected, expected number of queries is '%d'",
len(b.QueuedQueries), len(batchExp.expectedQueries))
}
for i, query := range b.QueuedQueries {
if err := c.queryMatcher.Match(batchExp.expectedQueries[i].expectSQL, query.SQL); err != nil {
return err
}
if rewrittenSQL, err := batchExp.expectedQueries[i].argsMatches(query.SQL, query.Arguments); err != nil {
return err
} else if rewrittenSQL != "" && batchExp.expectedQueries[i].expectRewrittenSQL != "" {
if err := c.queryMatcher.Match(batchExp.expectedQueries[i].expectRewrittenSQL, rewrittenSQL); err != nil {
return err
}
}
}
return nil
})
br := &batchResults{mock: c, batch: b, expectedBatch: ex, err: err}
if err != nil {
return br
}
br.err = ex.waitForDelay(ctx)
return br
}

func (c *pgxmock) LargeObjects() pgx.LargeObjects {
Expand Down
1 change: 0 additions & 1 deletion pgxmock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,6 @@ func TestUnmockedMethods(t *testing.T) {
a.NotNil(mock.AsConn().Config())
a.NotNil(mock.AcquireAllIdle(ctx))
a.Nil(mock.AcquireFunc(ctx, func(*pgxpool.Conn) error { return nil }))
a.Nil(mock.SendBatch(ctx, nil))
a.Zero(mock.LargeObjects())
a.Panics(func() { _ = mock.Conn() })
}
Expand Down

0 comments on commit 3f620d4

Please sign in to comment.