Compare commits

...

4 Commits

Author SHA1 Message Date
Nise Void 9d4454628f
Use fs.FS 2021-02-24 14:34:14 +01:00
Nise Void f39c1c640a
Support non-PostgreSQL databases 2020-03-11 10:42:00 +01:00
Nise Void de9e7d9a41
Add filename to error message 2019-10-28 13:01:29 +01:00
Nise Void 0e43d5ed69
Remove newer database check 2018-07-04 13:41:59 +02:00
4 changed files with 110 additions and 36 deletions

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.fuyu.moe/Fuyu/migrate/v2
go 1.16

View File

@ -5,14 +5,14 @@ import (
"database/sql"
"errors"
"fmt"
"io/fs"
"strconv"
)
// Options contains all settings
type Options struct {
TableName string // Name used for version info table; defaults to DefaultTableName if not set
Schema string // Schema used for version info table; For PostgreSQL, ignored if not set
AssetPrefix string
TableName string // Name used for version info table; defaults to DefaultTableName if not set
Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here
}
// DefaultTableName is the name used when no TableName is specified in Options
@ -21,34 +21,20 @@ const DefaultTableName = `version`
// ErrUpdatesMissing indicates an update is missing, making it impossible to execute the migration
var ErrUpdatesMissing = errors.New(`Missing migration files`)
// ErrDatabaseNewer indicates that the database version is newer than the requested version. We throw an error because downgrades might cause dataloss
var ErrDatabaseNewer = errors.New(`Current version is newer than the requested version`)
const fileFormat = `%04d.sql`
// AssetFunc is a function that returns the data for the given name
type AssetFunc func(string) ([]byte, error)
// Migrate executes migrations to get to the desired version
// Migrate executes all migrations
// Filenames need to have incrementing numbers
// Downgrading is not supported as it could result in data loss
func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
if o.TableName == `` {
o.TableName = DefaultTableName
}
var err error
searchPath := `public`
_ = db.QueryRow(`SHOW search_path`).Scan(&searchPath)
if o.Schema != `` {
_, _ = db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema)
_, err = db.Exec(`SET search_path TO ` + o.Schema)
if err != nil {
return err
}
func Migrate(db *sql.DB, o Options, assets fs.FS) error {
entries, err := fs.ReadDir(assets, `.`)
if err != nil {
panic(`failed to read list of files`)
}
version := len(entries)
tx, err := db.Begin()
if err != nil {
@ -56,12 +42,34 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
}
defer tx.Rollback()
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + o.TableName + ` (Version integer NOT NULL PRIMARY KEY)`)
dbType := getDbType(db)
searchPath := `public`
if o.TableName == `` {
o.TableName = DefaultTableName
}
if o.Schema != `` {
err = createSchemaIfNotExists(tx, o.Schema)
if err != nil {
return err
}
if dbType == `pq` {
_ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath)
_, err = tx.Exec(`SET search_path TO ` + o.Schema)
if err != nil {
return err
}
}
}
versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName)
if err != nil {
return err
}
row := tx.QueryRow(`SELECT Version FROM ` + o.TableName + ` ORDER BY Version DESC`)
row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`)
var v int
err = row.Scan(&v)
@ -69,28 +77,27 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
return err
}
if v > version {
return ErrDatabaseNewer
}
for i := v + 1; i <= version; i++ {
script, err := asset(fmt.Sprintf(o.AssetPrefix+fileFormat, i))
fileName := fmt.Sprintf(fileFormat, i)
errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) }
script, err := fs.ReadFile(assets, fileName)
if err != nil {
return ErrUpdatesMissing
return errorf(ErrUpdatesMissing)
}
_, err = tx.Exec(string(script))
if err != nil {
return err
return errorf(err)
}
_, err = tx.Exec(`INSERT INTO ` + o.TableName + ` VALUES (` + strconv.Itoa(i) + `)`)
_, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`)
if err != nil {
return err
return errorf(err)
}
}
if o.Schema != `` {
if dbType == `pq` && o.Schema != `` {
_, err = tx.Exec(`SET search_path TO ` + searchPath)
if err != nil {
return err

35
query.go Normal file
View File

@ -0,0 +1,35 @@
package migrate
import "database/sql"
func createSchemaIfNotExists(tx *sql.Tx, schema string) error {
row := tx.QueryRow(`SELECT 1 FROM information_schema.schemata WHERE schema_name = '` + schema + `'`)
err := row.Scan(new(int))
if err == sql.ErrNoRows {
_, err = tx.Exec(`CREATE SCHEMA ` + schema)
if err != nil {
return err
}
}
return err
}
func createTableIfNotExists(tx *sql.Tx, schema, table string) (string, error) {
versionTable := table
var schemaCond string
if schema != `` {
versionTable = schema + `.` + table
schemaCond = ` AND table_schema = '` + schema + `'`
}
row := tx.QueryRow(`SELECT 1 FROM information_schema.tables WHERE table_name = '` + table + `'` + schemaCond)
err := row.Scan(new(int))
if err == sql.ErrNoRows {
_, err = tx.Exec(`CREATE TABLE ` + versionTable + ` (version integer NOT NULL PRIMARY KEY)`)
if err != nil {
return ``, err
}
}
return versionTable, err
}

29
util.go Normal file
View File

@ -0,0 +1,29 @@
package migrate
import (
"database/sql"
"reflect"
"strings"
)
func getPkgName(db *sql.DB) string {
t := reflect.TypeOf(db.Driver())
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
path := t.PkgPath()
parts := strings.Split(path, `/vendor/`)
return parts[len(parts)-1]
}
func getDbType(db *sql.DB) string {
switch getPkgName(db) {
case `github.com/lib/pq`, `github.com/jackc/pgx/stdlib`:
return `pq`
case `github.com/go-sql-driver/mysql`, `github.com/ziutek/mymysql/godrv`:
return `my`
case `github.com/denisenkom/go-mssqldb`:
return `ms`
}
return ``
}