From f0c76742212d3fc992d49099e0a43edeb8456c53 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Tue, 11 Jun 2024 13:35:13 +0200 Subject: [PATCH] [-] fix unordered batch expectations, fixes #207 (#208) --- batch_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ pgxmock.go | 3 +++ 2 files changed, 44 insertions(+) diff --git a/batch_test.go b/batch_test.go index a3f07cd..6dea88b 100644 --- a/batch_test.go +++ b/batch_test.go @@ -83,3 +83,44 @@ func TestExplicitBatch(t *testing.T) { a.NoError(mock.ExpectationsWereMet()) } + +func processBatch(db PgxPoolIface) error { + batch := &pgx.Batch{} + // Random order + batch.Queue("SELECT id FROM normalized_queries WHERE query = $1", "some query") + batch.Queue("INSERT INTO normalized_queries (query) VALUES ($1) RETURNING id", "some query") + + results := db.SendBatch(ctx, batch) + defer results.Close() + + for i := 0; i < batch.Len(); i++ { + var id int + err := results.QueryRow().Scan(&id) + if err != nil { + return err + } + } + + return nil +} + +func TestUnorderedBatchExpectations(t *testing.T) { + t.Parallel() + a := assert.New(t) + + mock, err := NewPool() + a.NoError(err) + defer mock.Close() + + mock.MatchExpectationsInOrder(false) + + expectedBatch := mock.ExpectBatch() + expectedBatch.ExpectQuery("INSERT INTO").WithArgs("some query"). + WillReturnRows(NewRows([]string{"id"}).AddRow(10)) + expectedBatch.ExpectQuery("SELECT id").WithArgs("some query"). + WillReturnRows(NewRows([]string{"id"}).AddRow(20)) + + err = processBatch(mock) + a.NoError(err) + a.NoError(mock.ExpectationsWereMet()) +} diff --git a/pgxmock.go b/pgxmock.go index c611a19..377c7a5 100644 --- a/pgxmock.go +++ b/pgxmock.go @@ -347,6 +347,9 @@ func (c *pgxmock) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults return fmt.Errorf("SendBatch: number of queries in batch '%d' was not expected, expected number of queries is '%d'", len(b.QueuedQueries), len(batchExp.expectedQueries)) } + if !c.ordered { // postpone the check of every query until/if it is called + return nil + } for i, query := range b.QueuedQueries { if err := c.queryMatcher.Match(batchExp.expectedQueries[i].expectSQL, query.SQL); err != nil { return err