migrate/migrate.go

104 lines
2.5 KiB
Go
Raw Normal View History

2016-10-21 20:01:00 +02:00
// Package migrate allows you to update your database from your application
package migrate
import (
"database/sql"
"errors"
"fmt"
"strconv"
)
// Options contains all settings for Migrate. The zero value is usable: an
// empty TableName falls back to DefaultTableName, and an empty Schema or
// AssetPrefix is simply not applied.
type Options struct {
TableName string // Name of the version info table; defaults to DefaultTableName when empty
Schema string // Schema for the version info table and the migration scripts (PostgreSQL only: used with CREATE SCHEMA and SET search_path); ignored when empty
AssetPrefix string // Prefix prepended to each migration file name (e.g. "0001.sql") when requesting it from the AssetFunc
}
const (
	// DefaultTableName is the version table name Migrate falls back to when
	// Options.TableName is left empty.
	DefaultTableName = `version`

	// fileFormat is the pattern for migration file names: the target version
	// zero-padded to four digits followed by ".sql" (e.g. "0007.sql").
	fileFormat = `%04d.sql`
)

var (
	// ErrUpdatesMissing indicates that the script for a pending version could
	// not be loaded, which makes it impossible to execute the migration.
	ErrUpdatesMissing = errors.New(`Missing migration files`)

	// ErrDatabaseNewer indicates that the database already records a version
	// newer than the requested one. Migrate refuses to continue because a
	// downgrade might cause data loss.
	ErrDatabaseNewer = errors.New(`Current version is newer than the requested version`)
)

// AssetFunc returns the contents of the asset with the given name.
type AssetFunc func(string) ([]byte, error)
// Migrate executes migrations to get to the desired version
// Downgrading is not supported as it could result in data loss
2016-10-21 20:01:00 +02:00
func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
if o.TableName == `` {
o.TableName = DefaultTableName
}
var err error
2017-04-19 15:39:48 +02:00
2016-10-21 20:01:00 +02:00
if o.Schema != `` {
_, _ = db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema)
2017-04-19 14:45:11 +02:00
_, err = db.Exec(`SET search_path TO ` + o.Schema)
2017-04-19 14:45:11 +02:00
if err != nil {
return err
}
}
2017-04-19 14:45:11 +02:00
tx, err := db.Begin()
if err != nil {
return err
2016-10-21 20:01:00 +02:00
}
defer tx.Rollback()
2016-10-21 20:01:00 +02:00
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + o.TableName + ` (Version integer NOT NULL PRIMARY KEY)`)
2016-10-21 20:01:00 +02:00
if err != nil {
return err
}
row := tx.QueryRow(`SELECT Version FROM ` + o.TableName + ` ORDER BY Version DESC`)
2016-10-21 20:01:00 +02:00
var v int
err = row.Scan(&v)
if err != sql.ErrNoRows && err != nil {
return err
}
if v > version {
return ErrDatabaseNewer
}
for i := v + 1; i <= version; i++ {
2017-04-06 14:03:54 +02:00
script, err := asset(fmt.Sprintf(o.AssetPrefix+fileFormat, i))
2016-10-21 20:01:00 +02:00
if err != nil {
return ErrUpdatesMissing
}
2017-04-06 14:03:54 +02:00
_, err = tx.Exec(string(script))
2016-10-21 20:01:00 +02:00
if err != nil {
return err
}
_, err = tx.Exec(`INSERT INTO ` + o.TableName + ` VALUES (` + strconv.Itoa(i) + `)`)
2016-10-21 20:01:00 +02:00
if err != nil {
return err
}
}
if o.Schema != `` {
_, err = tx.Exec(`SET search_path TO ` + o.Schema + `,public`)
if err != nil {
return err
}
}
2016-10-21 20:01:00 +02:00
err = tx.Commit()
if err != nil {
return err
}
return nil
}