migrate/migrate.go

114 lines
2.6 KiB
Go
Raw Permalink Normal View History

2016-10-21 20:01:00 +02:00
// Package migrate allows you to update your database from your application
package migrate
import (
"database/sql"
"errors"
"fmt"
2021-02-24 14:27:48 +01:00
"io/fs"
2016-10-21 20:01:00 +02:00
"strconv"
)
// Options holds the settings that control where Migrate keeps its version
// bookkeeping.
type Options struct {
	// TableName is the name used for the version info table; Migrate falls
	// back to DefaultTableName when it is empty.
	TableName string

	// Schema is the schema used for the version info table. In PostgreSQL the
	// current schema (search_path) is changed to the one specified here while
	// the migrations run.
	Schema string
}
const (
	// DefaultTableName is the version table name applied when
	// Options.TableName is left empty.
	DefaultTableName = "version"

	// fileFormat is the fmt pattern for migration file names: the version
	// number zero-padded to four digits, followed by ".sql".
	fileFormat = "%04d.sql"
)

// ErrUpdatesMissing reports that a required migration file is absent, which
// makes it impossible to execute the migration.
var ErrUpdatesMissing = errors.New("Missing migration files")

// AssetFunc is the signature of a function that returns the data stored
// under the given name.
type AssetFunc func(name string) ([]byte, error)
2021-02-24 14:27:48 +01:00
// Migrate executes all migrations
// Filenames need to have incrementing numbers
// Downgrading is not supported as it could result in data loss
2021-02-24 14:27:48 +01:00
func Migrate(db *sql.DB, o Options, assets fs.FS) error {
entries, err := fs.ReadDir(assets, `.`)
if err != nil {
panic(`failed to read list of files`)
2016-10-21 20:01:00 +02:00
}
2021-02-24 14:27:48 +01:00
version := len(entries)
2019-11-12 13:01:36 +01:00
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
2017-04-19 14:45:11 +02:00
2021-02-24 14:27:48 +01:00
dbType := getDbType(db)
searchPath := `public`
if o.TableName == `` {
o.TableName = DefaultTableName
}
2019-11-12 13:01:36 +01:00
if o.Schema != `` {
err = createSchemaIfNotExists(tx, o.Schema)
2017-04-19 14:45:11 +02:00
if err != nil {
return err
}
2019-11-12 13:01:36 +01:00
if dbType == `pq` {
_ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath)
_, err = tx.Exec(`SET search_path TO ` + o.Schema)
if err != nil {
return err
}
}
2016-10-21 20:01:00 +02:00
}
2019-11-12 13:01:36 +01:00
versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName)
2016-10-21 20:01:00 +02:00
if err != nil {
return err
}
2019-11-12 13:01:36 +01:00
row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`)
2016-10-21 20:01:00 +02:00
var v int
err = row.Scan(&v)
if err != sql.ErrNoRows && err != nil {
return err
}
for i := v + 1; i <= version; i++ {
2021-02-24 14:27:48 +01:00
fileName := fmt.Sprintf(fileFormat, i)
2019-10-28 13:01:29 +01:00
errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) }
2021-02-24 14:27:48 +01:00
script, err := fs.ReadFile(assets, fileName)
2016-10-21 20:01:00 +02:00
if err != nil {
2019-10-28 13:01:29 +01:00
return errorf(ErrUpdatesMissing)
2016-10-21 20:01:00 +02:00
}
2017-04-06 14:03:54 +02:00
_, err = tx.Exec(string(script))
2016-10-21 20:01:00 +02:00
if err != nil {
2019-10-28 13:01:29 +01:00
return errorf(err)
2016-10-21 20:01:00 +02:00
}
2019-11-12 13:01:36 +01:00
_, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`)
2016-10-21 20:01:00 +02:00
if err != nil {
2019-10-28 13:01:29 +01:00
return errorf(err)
2016-10-21 20:01:00 +02:00
}
}
2019-11-12 13:01:36 +01:00
if dbType == `pq` && o.Schema != `` {
_, err = tx.Exec(`SET search_path TO ` + searchPath)
if err != nil {
return err
}
}
2016-10-21 20:01:00 +02:00
err = tx.Commit()
if err != nil {
return err
}
return nil
}