diff --git a/migrate.go b/migrate.go index b2c377a..9e902a8 100644 --- a/migrate.go +++ b/migrate.go @@ -12,7 +12,7 @@ import ( // Options contains all settings type Options struct { TableName string // Name used for version info table; defaults to DefaultTableName if not set - Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here + Schema string // Schema used for version info table; For PostgreSQL, ignored if not set } // DefaultTableName is the name used when no TableName is specified in Options @@ -21,6 +21,9 @@ const DefaultTableName = `version` // ErrUpdatesMissing indicates an update is missing, making it impossible to execute the migration var ErrUpdatesMissing = errors.New(`Missing migration files`) +// ErrDatabaseNewer indicates that the database version is newer than the requested version. We throw an error because downgrades might cause dataloss +var ErrDatabaseNewer = errors.New(`Current version is newer than the requested version`) + const fileFormat = `%04d.sql` // AssetFunc is a function that returns the data for the given name @@ -36,40 +39,34 @@ func Migrate(db *sql.DB, o Options, assets fs.FS) error { } version := len(entries) + if o.TableName == `` { + o.TableName = DefaultTableName + } + + searchPath := `public` + _ = db.QueryRow(`SHOW search_path`).Scan(&searchPath) + + if o.Schema != `` { + _, _ = db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema) + + _, err = db.Exec(`SET search_path TO ` + o.Schema) + if err != nil { + return err + } + } + tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() - dbType := getDbType(db) - searchPath := `public` - if o.TableName == `` { - o.TableName = DefaultTableName - } - - if o.Schema != `` { - err = createSchemaIfNotExists(tx, o.Schema) - if err != nil { - return err - } - - if dbType == `pq` { - _ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath) - - _, err = tx.Exec(`SET search_path TO ` + o.Schema) - if err != nil { - return err - } - } - } - - versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName) + _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + o.TableName + ` (Version integer NOT NULL PRIMARY KEY)`) if err != nil { return err } - row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`) + row := tx.QueryRow(`SELECT Version FROM ` + o.TableName + ` ORDER BY Version DESC`) var v int err = row.Scan(&v) @@ -77,27 +74,28 @@ func Migrate(db *sql.DB, o Options, assets fs.FS) error { return err } - for i := v + 1; i <= version; i++ { - fileName := fmt.Sprintf(fileFormat, i) - errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) } + if v > version { + return ErrDatabaseNewer + } - script, err := fs.ReadFile(assets, fileName) + for i := v + 1; i <= version; i++ { + script, err := fs.ReadFile(assets, fmt.Sprintf(fileFormat, i)) if err != nil { - return errorf(ErrUpdatesMissing) + return ErrUpdatesMissing } _, err = tx.Exec(string(script)) if err != nil { - return errorf(err) + return err } - _, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`) + _, err = tx.Exec(`INSERT INTO ` + o.TableName + ` VALUES (` + strconv.Itoa(i) + `)`) if err != nil { - return errorf(err) + return err } } - if dbType == `pq` && o.Schema != `` { + if o.Schema != `` { _, err = tx.Exec(`SET search_path TO ` + searchPath) if err != nil { return err diff --git a/query.go b/query.go deleted file mode 100644 index b95419f..0000000 --- a/query.go +++ /dev/null @@ -1,35 +0,0 @@ -package migrate - -import "database/sql" - -func createSchemaIfNotExists(tx *sql.Tx, schema string) error { - row := tx.QueryRow(`SELECT 1 FROM information_schema.schemata WHERE schema_name = '` + schema + `'`) - err := row.Scan(new(int)) - if err == sql.ErrNoRows { - _, err = tx.Exec(`CREATE SCHEMA ` + schema) - if err != nil { - return err - } - } - return err -} - -func createTableIfNotExists(tx *sql.Tx, schema, table string) (string, error) { - versionTable := table - var schemaCond string - - if schema != `` { - versionTable = schema + `.` + table - schemaCond = ` AND table_schema = '` + schema + `'` - } - - row := tx.QueryRow(`SELECT 1 FROM information_schema.tables WHERE table_name = '` + table + `'` + schemaCond) - err := row.Scan(new(int)) - if err == sql.ErrNoRows { - _, err = tx.Exec(`CREATE TABLE ` + versionTable + ` (version integer NOT NULL PRIMARY KEY)`) - if err != nil { - return ``, err - } - } - return versionTable, err -} diff --git a/util.go b/util.go deleted file mode 100644 index c4bc8b0..0000000 --- a/util.go +++ /dev/null @@ -1,29 +0,0 @@ -package migrate - -import ( - "database/sql" - "reflect" - "strings" -) - -func getPkgName(db *sql.DB) string { - t := reflect.TypeOf(db.Driver()) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - path := t.PkgPath() - parts := strings.Split(path, `/vendor/`) - return parts[len(parts)-1] -} - -func getDbType(db *sql.DB) string { - switch getPkgName(db) { - case `github.com/lib/pq`, `github.com/jackc/pgx/stdlib`: - return `pq` - case `github.com/go-sql-driver/mysql`, `github.com/ziutek/mymysql/godrv`: - return `my` - case `github.com/denisenkom/go-mssqldb`: - return `ms` - } - return `` -}