Skip to content

Commit

Permalink
[+] add pgx.QueryRewriter support (#166)
Browse files Browse the repository at this point in the history
Co-authored-by: elij <[email protected]>
Co-authored-by: eli <[email protected]>
  • Loading branch information
3 people authored Oct 18, 2023
1 parent 1af56ff commit 9d28bb3
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 27 deletions.
113 changes: 113 additions & 0 deletions argument_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package pgxmock

import (
"context"
"errors"
"testing"
"time"

pgx "github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
)

type AnyTime struct{}
Expand Down Expand Up @@ -35,6 +39,30 @@ func TestAnyTimeArgument(t *testing.T) {
}
}

func TestAnyTimeNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

mock.ExpectExec("INSERT INTO users").
WithArgs(pgx.NamedArgs{"name": "john", "time": AnyTime{}}).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(),
"INSERT INTO users(name, created_at) VALUES (@name, @time)",
pgx.NamedArgs{"name": "john", "time": time.Now()},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestByteSliceArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
Expand All @@ -55,6 +83,68 @@ func TestByteSliceArgument(t *testing.T) {
}
}

type failQryRW struct {
pgx.QueryRewriter
}

func (fqrw failQryRW) RewriteQuery(_ context.Context, _ *pgx.Conn, sql string, _ []any) (newSQL string, newArgs []any, err error) {
return "", nil, errors.New("cannot rewrite query " + sql)
}

func TestExpectQueryRewriterFail(t *testing.T) {

t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

mock.ExpectQuery(`INSERT INTO users\(username\) VALUES \(\@user\)`).
WithRewrittenSQL(`INSERT INTO users\(username\) VALUES \(\$1\)`).
WithArgs(failQryRW{})
_, err = mock.Query(context.Background(), "INSERT INTO users(username) VALUES (@user)", "baz")
assert.Error(t, err)
}

func TestQueryRewriterFail(t *testing.T) {

t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
mock.ExpectExec(`INSERT INTO .+`).WithArgs("foo")
_, err = mock.Exec(context.Background(), "INSERT INTO users(username) VALUES (@user)", failQryRW{})
assert.Error(t, err)

}

func TestByteSliceNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

username := []byte("user")
mock.ExpectExec(`INSERT INTO users\(username\) VALUES \(\@user\)`).
WithArgs(pgx.NamedArgs{"user": username}).
WithRewrittenSQL(`INSERT INTO users\(username\) VALUES \(\$1\)`).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(),
"INSERT INTO users(username) VALUES (@user)",
pgx.NamedArgs{"user": username},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestAnyArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
Expand All @@ -75,3 +165,26 @@ func TestAnyArgument(t *testing.T) {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestAnyNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

mock.ExpectExec("INSERT INTO users").
WithArgs("john", AnyArg()).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (@name, @created)",
pgx.NamedArgs{"name": "john", "created": time.Now()},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
63 changes: 48 additions & 15 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,28 +130,46 @@ func (e *commonExpectation) String() string {

// queryBasedExpectation is a base class that adds a query matching logic
type queryBasedExpectation struct {
expectSQL string
args []interface{}
}

func (e *queryBasedExpectation) argsMatches(args []interface{}) error {
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
expectSQL string
expectRewrittenSQL string
args []interface{}
}

func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) (rewrittenSQL string, err error) {
eargs := e.args
// check for any QueryRewriter arguments: only supported as the first argument
if len(args) == 1 {
if qrw, ok := args[0].(pgx.QueryRewriter); ok {
// note: pgx.Conn is not currently used by the query rewriter
if rewrittenSQL, args, err = qrw.RewriteQuery(context.Background(), nil, sql, args); err != nil {
return rewrittenSQL, fmt.Errorf("error rewriting query: %w", err)
}
}
// also do rewriting on the expected args if a QueryRewriter is present
if len(eargs) == 1 {
if qrw, ok := eargs[0].(pgx.QueryRewriter); ok {
if _, eargs, err = qrw.RewriteQuery(context.Background(), nil, sql, eargs); err != nil {
return "", fmt.Errorf("error rewriting query expectation: %w", err)
}
}
}
}
if len(args) != len(eargs) {
return rewrittenSQL, fmt.Errorf("expected %d, but got %d arguments", len(eargs), len(args))
}
for k, v := range args {
// custom argument matcher
if matcher, ok := e.args[k].(Argument); ok {
if matcher, ok := eargs[k].(Argument); ok {
if !matcher.Match(v) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
return rewrittenSQL, fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}

if darg := e.args[k]; !reflect.DeepEqual(darg, v) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v)
if darg := eargs[k]; !reflect.DeepEqual(darg, v) {
return rewrittenSQL, fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v)
}
}
return nil
return
}

// ExpectedClose is used to manage pgx.Close expectation
Expand Down Expand Up @@ -208,6 +226,13 @@ func (e *ExpectedExec) WithArgs(args ...interface{}) *ExpectedExec {
return e
}

// WithRewrittenSQL will match given expected expression to a rewritten SQL statement by
// an pgx.QueryRewriter argument
func (e *ExpectedExec) WithRewrittenSQL(sql string) *ExpectedExec {
e.expectRewrittenSQL = sql
return e
}

// String returns string representation
func (e *ExpectedExec) String() string {
msg := "ExpectedExec => expecting call to Exec():\n"
Expand All @@ -221,7 +246,7 @@ func (e *ExpectedExec) String() string {
msg += fmt.Sprintf("\t\t%d - %+v\n", i, arg)
}
}
if e.result.String() > "" {
if e.result.String() != "" {
msg += fmt.Sprintf("\t- returns result: %s\n", e.result)
}

Expand Down Expand Up @@ -255,7 +280,8 @@ func (e *ExpectedPrepare) WillReturnCloseError(err error) *ExpectedPrepare {
}

// WillBeClosed is for backward compatibility only and will be removed soon.
// One should use WillBeDeallocated() instead
//
// Deprecated: One should use WillBeDeallocated() instead.
func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
return e.WillBeDeallocated()
}
Expand Down Expand Up @@ -324,6 +350,13 @@ func (e *ExpectedQuery) WithArgs(args ...interface{}) *ExpectedQuery {
return e
}

// WithRewrittenSQL will match given expected expression to a rewritten SQL statement by
// an pgx.QueryRewriter argument
func (e *ExpectedQuery) WithRewrittenSQL(sql string) *ExpectedQuery {
e.expectRewrittenSQL = sql
return e
}

// RowsWillBeClosed expects this query rows to be closed.
func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery {
e.rowsMustBeClosed = true
Expand Down
31 changes: 31 additions & 0 deletions expectations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,34 @@ func TestMissingWithArgs(t *testing.T) {
t.Error("expectation was not matched error was expected")
}
}

func TestWithRewrittenSQL(t *testing.T) {
t.Parallel()
mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual))
a := assert.New(t)
a.NoError(err)

mock.ExpectQuery(`INSERT INTO users(username) VALUES (@user)`).
WithArgs(pgx.NamedArgs{"user": "John"}).
WithRewrittenSQL(`INSERT INTO users(username) VALUES ($1)`).
WillReturnRows()

_, err = mock.Query(context.Background(),
"INSERT INTO users(username) VALUES (@user)",
pgx.NamedArgs{"user": "John"},
)
a.NoError(err)
a.NoError(mock.ExpectationsWereMet())

mock.ExpectQuery(`INSERT INTO users(username, password) VALUES (@user, @password)`).
WithArgs(pgx.NamedArgs{"user": "John", "password": "strong"}).
WithRewrittenSQL(`INSERT INTO users(username, password) VALUES ($1)`).
WillReturnRows()

_, err = mock.Query(context.Background(),
"INSERT INTO users(username) VALUES (@user)",
pgx.NamedArgs{"user": "John", "password": "strong"},
)
a.Error(err)
a.Error(mock.ExpectationsWereMet())
}
9 changes: 0 additions & 9 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
package pgxmock

// ValueConverterOption allows to create a pgxmock connection
// with a custom ValueConverter to support drivers with special data types.
// func ValueConverterOption(converter driver.ValueConverter) func(*pgxmock) error {
// return func(s *pgxmock) error {
// s.converter = converter
// return nil
// }
// }

// QueryMatcherOption allows to customize SQL query matcher
// and match SQL query strings in more sophisticated ways.
// The default QueryMatcher is QueryMatcherRegexp.
Expand Down
15 changes: 13 additions & 2 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func (c *pgxmock) ExpectationsWereMet() error {
func (c *pgxmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
e := &ExpectedQuery{}
e.expectSQL = expectedSQL
e.expectRewrittenSQL = expectedSQL
c.expectations = append(c.expectations, e)
return e
}
Expand Down Expand Up @@ -234,6 +235,7 @@ func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin {
func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec {
e := &ExpectedExec{}
e.expectSQL = expectedSQL
e.expectRewrittenSQL = expectedSQL
c.expectations = append(c.expectations, e)
return e
}
Expand Down Expand Up @@ -430,8 +432,12 @@ func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (p
if err := c.queryMatcher.Match(queryExp.expectSQL, sql); err != nil {
return err
}
if err := queryExp.argsMatches(args); err != nil {
if rewrittenSQL, err := queryExp.argsMatches(sql, args); err != nil {
return err
} else if rewrittenSQL != "" {
if err := c.queryMatcher.Match(queryExp.expectRewrittenSQL, rewrittenSQL); err != nil {
return err
}
}
if queryExp.err == nil && queryExp.rows == nil {
return fmt.Errorf("Query must return a result rows or raise an error: %v", queryExp)
Expand Down Expand Up @@ -466,8 +472,13 @@ func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) (
if err := c.queryMatcher.Match(execExp.expectSQL, query); err != nil {
return err
}
if err := execExp.argsMatches(args); err != nil {
if rewrittenSQL, err := execExp.argsMatches(query, args); err != nil {
return err
} else if rewrittenSQL != "" {
if err := c.queryMatcher.Match(execExp.expectRewrittenSQL, rewrittenSQL); err != nil {
//pgx support QueryRewriter for arguments, now we can check if the query was actually rewriten
return err
}
}
if execExp.result.String() == "" && execExp.err == nil {
return fmt.Errorf("Exec must return a result or raise an error: %s", execExp)
Expand Down
2 changes: 1 addition & 1 deletion pgxmock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ func TestPreparedStatementCloseExpectation(t *testing.T) {
mock, _ := NewConn()
a := assert.New(t)

ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS").WillBeDeallocated()
ep := mock.ExpectPrepare("foo", "INSERT INTO ORDERS").WillBeClosed()
ep.ExpectExec().WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1))

stmt, err := mock.Prepare(context.Background(), "foo", "INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)")
Expand Down

0 comments on commit 9d28bb3

Please sign in to comment.