Skip to content

Commit

Permalink
Merge pull request #34 from j16r/add_test_and_fix_transactional_migra…
Browse files Browse the repository at this point in the history
…tions

Make sure table creation and schema check happens in tx, add tx tests
  • Loading branch information
andreynering authored Apr 30, 2019
2 parents f45874c + d6b4215 commit 8fd2dac
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 18 deletions.
24 changes: 13 additions & 11 deletions gormigrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,23 +157,22 @@ func (g *Gormigrate) migrate(migrationID string) error {
return err
}

g.begin()
defer g.rollback()

if err := g.createMigrationTableIfNotExists(); err != nil {
return err
}

g.begin()

if g.initSchema != nil && g.canInitializeSchema() {
if err := g.runInitSchema(); err != nil {
g.rollback()
return err
}
return g.commit()
}

for _, migration := range g.migrations {
if err := g.runMigration(migration); err != nil {
g.rollback()
return err
}
if migrationID != "" && migration.ID == migrationID {
Expand Down Expand Up @@ -227,15 +226,18 @@ func (g *Gormigrate) RollbackLast() error {
return ErrNoMigrationDefined
}

g.begin()
defer g.rollback()

lastRunMigration, err := g.getLastRunMigration()
if err != nil {
return err
}

if err := g.RollbackMigration(lastRunMigration); err != nil {
if err := g.rollbackMigration(lastRunMigration); err != nil {
return err
}
return nil
return g.commit()
}

// RollbackTo undoes migrations up to the given migration that matches the `migrationID`.
Expand Down Expand Up @@ -297,7 +299,7 @@ func (g *Gormigrate) rollbackMigration(m *Migration) error {
}

sql := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", g.options.TableName, g.options.IDColumnName)
if err := g.db.Exec(sql, m.ID).Error; err != nil {
if err := g.tx.Exec(sql, m.ID).Error; err != nil {
return err
}
return nil
Expand Down Expand Up @@ -338,20 +340,20 @@ func (g *Gormigrate) runMigration(migration *Migration) error {
}

func (g *Gormigrate) createMigrationTableIfNotExists() error {
if g.db.HasTable(g.options.TableName) {
if g.tx.HasTable(g.options.TableName) {
return nil
}

sql := fmt.Sprintf("CREATE TABLE %s (%s VARCHAR(%d) PRIMARY KEY)", g.options.TableName, g.options.IDColumnName, g.options.IDColumnSize)
if err := g.db.Exec(sql).Error; err != nil {
if err := g.tx.Exec(sql).Error; err != nil {
return err
}
return nil
}

func (g *Gormigrate) migrationDidRun(m *Migration) bool {
var count int
g.db.
g.tx.
Table(g.options.TableName).
Where(fmt.Sprintf("%s = ?", g.options.IDColumnName), m.ID).
Count(&count)
Expand All @@ -367,7 +369,7 @@ func (g *Gormigrate) canInitializeSchema() bool {

// If the ID doesn't exist, we also want the list of migrations to be empty
var count int
g.db.
g.tx.
Table(g.options.TableName).
Count(&count)
return count == 0
Expand Down
90 changes: 83 additions & 7 deletions gormigrate_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package gormigrate

import (
"errors"
"fmt"
"os"
"testing"

"github.com/jinzhu/gorm"
_ "github.com/joho/godotenv/autoload"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var databases []database
Expand Down Expand Up @@ -47,6 +50,21 @@ var extendedMigrations = append(migrations, &Migration{
},
})

var failingMigration = []*Migration{
{
ID: "201904231300",
Migrate: func(tx *gorm.DB) error {
if err := tx.AutoMigrate(&Book{}).Error; err != nil {
return err
}
return errors.New("this transaction should be rolled back")
},
Rollback: func(tx *gorm.DB) error {
return nil
},
},
}

type Person struct {
gorm.Model
Name string
Expand Down Expand Up @@ -306,25 +324,83 @@ func TestEmptyMigrationList(t *testing.T) {
})
}

func TestMigration_WithUseTransactions(t *testing.T) {
options := DefaultOptions
options.UseTransaction = true

forEachDatabase(t, func(db *gorm.DB) {
m := New(db, options, migrations)

err := m.Migrate()
require.NoError(t, err)
assert.True(t, db.HasTable(&Person{}))
assert.True(t, db.HasTable(&Pet{}))
assert.Equal(t, 2, tableCount(t, db, "migrations"))

err = m.RollbackLast()
require.NoError(t, err)
assert.True(t, db.HasTable(&Person{}))
assert.False(t, db.HasTable(&Pet{}))
assert.Equal(t, 1, tableCount(t, db, "migrations"))

err = m.RollbackLast()
require.NoError(t, err)
assert.False(t, db.HasTable(&Person{}))
assert.False(t, db.HasTable(&Pet{}))
assert.Equal(t, 0, tableCount(t, db, "migrations"))
}, "postgres", "sqlite3", "mssql")
}

func TestMigration_WithUseTransactionsShouldRollback(t *testing.T) {
options := DefaultOptions
options.UseTransaction = true

forEachDatabase(t, func(db *gorm.DB) {
assert.True(t, true)
m := New(db, options, failingMigration)

// Migration should return an error and not leave around a Book table
err := m.Migrate()
assert.Error(t, err)
assert.False(t, db.HasTable(&Book{}))
}, "postgres", "sqlite3", "mssql")
}

func tableCount(t *testing.T, db *gorm.DB, tableName string) (count int) {
assert.NoError(t, db.Table(tableName).Count(&count).Error)
return
}

func forEachDatabase(t *testing.T, fn func(database *gorm.DB)) {
func forEachDatabase(t *testing.T, fn func(database *gorm.DB), dialects ...string) {
if len(databases) == 0 {
panic("No database choosen for testing!")
}

for _, database := range databases {
db, err := gorm.Open(database.name, os.Getenv(database.connEnv))
assert.NoError(t, err, "Could not connect to database %s, %v", database.name, err)
if len(dialects) > 0 && !contains(dialects, database.name) {
t.Skip(fmt.Sprintf("test is not supported by [%s] dialect", database.name))
}

defer db.Close()
// Ensure defers are not stacked up for each DB
func() {
db, err := gorm.Open(database.name, os.Getenv(database.connEnv))
require.NoError(t, err, "Could not connect to database %s, %v", database.name, err)

// ensure tables do not exists
assert.NoError(t, db.DropTableIfExists("migrations", "people", "pets").Error)
defer db.Close()

fn(db)
// ensure tables do not exists
assert.NoError(t, db.DropTableIfExists("migrations", "people", "pets").Error)

fn(db)
}()
}
}

func contains(haystack []string, needle string) bool {
for _, straw := range haystack {
if straw == needle {
return true
}
}
return false
}

0 comments on commit 8fd2dac

Please sign in to comment.