diff --git a/migrate.go b/migrate.go index dea2aca..790ac5d 100644 --- a/migrate.go +++ b/migrate.go @@ -36,31 +36,33 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { o.TableName = DefaultTableName } + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + table := o.TableName if o.Schema != `` { table = o.Schema + `.` + table - fmt.Println(`Switching to schema:`, o.Schema) - _, err := db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema) + _, err = tx.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema) if err != nil { return err } - _, err = db.Exec(`SET search_path TO ` + o.Schema + `,public`) + _, err = tx.Exec(`SET search_path TO ` + o.Schema + `,public`) if err != nil { return err } } - _, err := db.Exec(`CREATE TABLE IF NOT EXISTS ` + table + ` (Version integer NOT NULL PRIMARY KEY)`) + _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + table + ` (Version integer NOT NULL PRIMARY KEY)`) if err != nil { return err } - row := db.QueryRow(`SELECT Version FROM ` + table + ` ORDER BY Version DESC`) - if err != nil { - return err - } + row := tx.QueryRow(`SELECT Version FROM ` + table + ` ORDER BY Version DESC`) var v int err = row.Scan(&v) @@ -72,12 +74,6 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { return ErrDatabaseNewer } - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for i := v + 1; i <= version; i++ { script, err := asset(fmt.Sprintf(o.AssetPrefix+fileFormat, i)) if err != nil {