diff --git a/gormigrate.go b/gormigrate.go index 900bbc4..e54c7eb 100644 --- a/gormigrate.go +++ b/gormigrate.go @@ -8,7 +8,7 @@ import ( ) const ( - initSchemaMigrationId = "SCHEMA_INIT" + initSchemaMigrationID = "SCHEMA_INIT" ) // MigrateFunc is the func signature for migrating. @@ -164,11 +164,17 @@ func (g *Gormigrate) migrate(migrationID string) error { return err } - if g.initSchema != nil && g.canInitializeSchema() { - if err := g.runInitSchema(); err != nil { + if g.initSchema != nil { + canInitializeSchema, err := g.canInitializeSchema() + if err != nil { return err } - return g.commit() + if canInitializeSchema { + if err := g.runInitSchema(); err != nil { + return err + } + return g.commit() + } } for _, migration := range g.migrations { @@ -179,7 +185,6 @@ func (g *Gormigrate) migrate(migrationID string) error { break } } - return g.commit() } @@ -193,7 +198,7 @@ func (g *Gormigrate) hasMigrations() bool { // For now there's only have one reserved ID, but there may be more in the future. func (g *Gormigrate) checkReservedID() error { for _, m := range g.migrations { - if m.ID == initSchemaMigrationId { + if m.ID == initSchemaMigrationID { return &ReservedIDError{ID: m.ID} } } @@ -252,27 +257,36 @@ func (g *Gormigrate) RollbackTo(migrationID string) error { } g.begin() + defer g.rollback() for i := len(g.migrations) - 1; i >= 0; i-- { migration := g.migrations[i] if migration.ID == migrationID { break } - if g.migrationDidRun(migration) { + migrationRan, err := g.migrationRan(migration) + if err != nil { + return err + } + if migrationRan { if err := g.rollbackMigration(migration); err != nil { - g.rollback() return err } } } - return g.commit() } func (g *Gormigrate) getLastRunMigration() (*Migration, error) { for i := len(g.migrations) - 1; i >= 0; i-- { migration := g.migrations[i] - if g.migrationDidRun(migration) { + + migrationRan, err := g.migrationRan(migration) + if err != nil { + return nil, err + } + + if migrationRan { return migration, nil } } @@ -282,8 +296,9 @@ func (g *Gormigrate) getLastRunMigration() (*Migration, error) { // RollbackMigration undo a migration. func (g *Gormigrate) RollbackMigration(m *Migration) error { g.begin() + defer g.rollback() + if err := g.rollbackMigration(m); err != nil { - g.rollback() return err } return g.commit() @@ -299,17 +314,14 @@ func (g *Gormigrate) rollbackMigration(m *Migration) error { } sql := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", g.options.TableName, g.options.IDColumnName) - if err := g.tx.Exec(sql, m.ID).Error; err != nil { - return err - } - return nil + return g.tx.Exec(sql, m.ID).Error } func (g *Gormigrate) runInitSchema() error { if err := g.initSchema(g.tx); err != nil { return err } - if err := g.insertMigration(initSchemaMigrationId); err != nil { + if err := g.insertMigration(initSchemaMigrationID); err != nil { return err } @@ -327,7 +339,11 @@ func (g *Gormigrate) runMigration(migration *Migration) error { return ErrMissingID } - if !g.migrationDidRun(migration) { + migrationRan, err := g.migrationRan(migration) + if err != nil { + return err + } + if !migrationRan { if err := migration.Migrate(g.tx); err != nil { return err } @@ -345,34 +361,37 @@ func (g *Gormigrate) createMigrationTableIfNotExists() error { } sql := fmt.Sprintf("CREATE TABLE %s (%s VARCHAR(%d) PRIMARY KEY)", g.options.TableName, g.options.IDColumnName, g.options.IDColumnSize) - if err := g.tx.Exec(sql).Error; err != nil { - return err - } - return nil + return g.tx.Exec(sql).Error } -func (g *Gormigrate) migrationDidRun(m *Migration) bool { +func (g *Gormigrate) migrationRan(m *Migration) (bool, error) { var count int - g.tx. + err := g.tx. Table(g.options.TableName). Where(fmt.Sprintf("%s = ?", g.options.IDColumnName), m.ID). - Count(&count) - return count > 0 + Count(&count). + Error + return count > 0, err } // The schema can be initialised only if it hasn't been initialised yet // and no other migration has been applied already. -func (g *Gormigrate) canInitializeSchema() bool { - if g.migrationDidRun(&Migration{ID: initSchemaMigrationId}) { - return false +func (g *Gormigrate) canInitializeSchema() (bool, error) { + migrationRan, err := g.migrationRan(&Migration{ID: initSchemaMigrationID}) + if err != nil { + return false, err + } + if migrationRan { + return false, nil } // If the ID doesn't exist, we also want the list of migrations to be empty var count int - g.tx. + err = g.tx. Table(g.options.TableName). - Count(&count) - return count == 0 + Count(&count). + Error + return count == 0, err } func (g *Gormigrate) insertMigration(id string) error { @@ -390,9 +409,7 @@ func (g *Gormigrate) begin() { func (g *Gormigrate) commit() error { if g.options.UseTransaction { - if err := g.tx.Commit().Error; err != nil { - return err - } + return g.tx.Commit().Error } return nil }