diff --git a/migrate.go b/migrate.go index 790ac5d..91be080 100644 --- a/migrate.go +++ b/migrate.go @@ -36,33 +36,29 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { o.TableName = DefaultTableName } + var err error + + if o.Schema != `` { + _, _ = db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema) + + _, err = db.Exec(`SET search_path TO ` + o.Schema) + if err != nil { + return err + } + } + tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() - table := o.TableName - if o.Schema != `` { - table = o.Schema + `.` + table - - _, err = tx.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema) - if err != nil { - return err - } - - _, err = tx.Exec(`SET search_path TO ` + o.Schema + `,public`) - if err != nil { - return err - } - } - - _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + table + ` (Version integer NOT NULL PRIMARY KEY)`) + _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + o.TableName + ` (Version integer NOT NULL PRIMARY KEY)`) if err != nil { return err } - row := tx.QueryRow(`SELECT Version FROM ` + table + ` ORDER BY Version DESC`) + row := tx.QueryRow(`SELECT Version FROM ` + o.TableName + ` ORDER BY Version DESC`) var v int err = row.Scan(&v) @@ -85,12 +81,17 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { return err } - _, err = tx.Exec(`INSERT INTO ` + table + ` VALUES (` + strconv.Itoa(i) + `)`) + _, err = tx.Exec(`INSERT INTO ` + o.TableName + ` VALUES (` + strconv.Itoa(i) + `)`) if err != nil { return err } } + _, err = tx.Exec(`SET search_path TO ` + o.Schema + `,public`) + if err != nil { + return err + } + err = tx.Commit() if err != nil { return err