Support non-PostgreSQL databases
This commit is contained in:
parent
de9e7d9a41
commit
f39c1c640a
38
migrate.go
38
migrate.go
@ -11,7 +11,7 @@ import (
|
|||||||
// Options contains all settings
|
// Options contains all settings
|
||||||
type Options struct {
|
type Options struct {
|
||||||
TableName string // Name used for version info table; defaults to DefaultTableName if not set
|
TableName string // Name used for version info table; defaults to DefaultTableName if not set
|
||||||
Schema string // Schema used for version info table; For PostgreSQL, ignored if not set
|
Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here
|
||||||
AssetPrefix string
|
AssetPrefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,23 +29,13 @@ type AssetFunc func(string) ([]byte, error)
|
|||||||
// Migrate executes migrations to get to the desired version
|
// Migrate executes migrations to get to the desired version
|
||||||
// Downgrading is not supported as it could result in data loss
|
// Downgrading is not supported as it could result in data loss
|
||||||
func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
|
func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
|
||||||
|
dbType := getDbType(db)
|
||||||
if o.TableName == `` {
|
if o.TableName == `` {
|
||||||
o.TableName = DefaultTableName
|
o.TableName = DefaultTableName
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
searchPath := `public`
|
searchPath := `public`
|
||||||
_ = db.QueryRow(`SHOW search_path`).Scan(&searchPath)
|
|
||||||
|
|
||||||
if o.Schema != `` {
|
|
||||||
_, _ = db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema)
|
|
||||||
|
|
||||||
_, err = db.Exec(`SET search_path TO ` + o.Schema)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -53,12 +43,28 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
|
|||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + o.TableName + ` (Version integer NOT NULL PRIMARY KEY)`)
|
if o.Schema != `` {
|
||||||
|
err = createSchemaIfNotExists(tx, o.Schema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbType == `pq` {
|
||||||
|
_ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath)
|
||||||
|
|
||||||
|
_, err = tx.Exec(`SET search_path TO ` + o.Schema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
row := tx.QueryRow(`SELECT Version FROM ` + o.TableName + ` ORDER BY Version DESC`)
|
row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`)
|
||||||
|
|
||||||
var v int
|
var v int
|
||||||
err = row.Scan(&v)
|
err = row.Scan(&v)
|
||||||
@ -80,13 +86,13 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
|
|||||||
return errorf(err)
|
return errorf(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec(`INSERT INTO ` + o.TableName + ` VALUES (` + strconv.Itoa(i) + `)`)
|
_, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorf(err)
|
return errorf(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.Schema != `` {
|
if dbType == `pq` && o.Schema != `` {
|
||||||
_, err = tx.Exec(`SET search_path TO ` + searchPath)
|
_, err = tx.Exec(`SET search_path TO ` + searchPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
35
query.go
Normal file
35
query.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import "database/sql"
|
||||||
|
|
||||||
|
func createSchemaIfNotExists(tx *sql.Tx, schema string) error {
|
||||||
|
row := tx.QueryRow(`SELECT 1 FROM information_schema.schemata WHERE schema_name = '` + schema + `'`)
|
||||||
|
err := row.Scan(new(int))
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
_, err = tx.Exec(`CREATE SCHEMA ` + schema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTableIfNotExists(tx *sql.Tx, schema, table string) (string, error) {
|
||||||
|
versionTable := table
|
||||||
|
var schemaCond string
|
||||||
|
|
||||||
|
if schema != `` {
|
||||||
|
versionTable = schema + `.` + table
|
||||||
|
schemaCond = ` AND table_schema = '` + schema + `'`
|
||||||
|
}
|
||||||
|
|
||||||
|
row := tx.QueryRow(`SELECT 1 FROM information_schema.tables WHERE table_name = '` + table + `'` + schemaCond)
|
||||||
|
err := row.Scan(new(int))
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
_, err = tx.Exec(`CREATE TABLE ` + versionTable + ` (version integer NOT NULL PRIMARY KEY)`)
|
||||||
|
if err != nil {
|
||||||
|
return ``, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return versionTable, err
|
||||||
|
}
|
29
util.go
Normal file
29
util.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getPkgName(db *sql.DB) string {
|
||||||
|
t := reflect.TypeOf(db.Driver())
|
||||||
|
if t.Kind() == reflect.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
path := t.PkgPath()
|
||||||
|
parts := strings.Split(path, `/vendor/`)
|
||||||
|
return parts[len(parts)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDbType(db *sql.DB) string {
|
||||||
|
switch getPkgName(db) {
|
||||||
|
case `github.com/lib/pq`, `github.com/jackc/pgx/stdlib`:
|
||||||
|
return `pq`
|
||||||
|
case `github.com/go-sql-driver/mysql`, `github.com/ziutek/mymysql/godrv`:
|
||||||
|
return `my`
|
||||||
|
case `github.com/denisenkom/go-mssqldb`:
|
||||||
|
return `ms`
|
||||||
|
}
|
||||||
|
return ``
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user