Support non-PostgreSQL databases
This commit is contained in:
		
							parent
							
								
									de9e7d9a41
								
							
						
					
					
						commit
						f39c1c640a
					
				
					 3 changed files with 86 additions and 16 deletions
				
			
		
							
								
								
									
										38
									
								
								migrate.go
									
										
									
									
									
								
							
							
						
						
									
										38
									
								
								migrate.go
									
										
									
									
									
								
							|  | @ -11,7 +11,7 @@ import ( | |||
| // Options contains all settings | ||||
| type Options struct { | ||||
| 	TableName   string // Name used for version info table; defaults to DefaultTableName if not set | ||||
| 	Schema      string // Schema used for version info table; For PostgreSQL, ignored if not set | ||||
| 	Schema      string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here | ||||
| 	AssetPrefix string | ||||
| } | ||||
| 
 | ||||
|  | @ -29,23 +29,13 @@ type AssetFunc func(string) ([]byte, error) | |||
| // Migrate executes migrations to get to the desired version | ||||
| // Downgrading is not supported as it could result in data loss | ||||
| func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { | ||||
| 	dbType := getDbType(db) | ||||
| 	if o.TableName == `` { | ||||
| 		o.TableName = DefaultTableName | ||||
| 	} | ||||
| 
 | ||||
| 	var err error | ||||
| 
 | ||||
| 	searchPath := `public` | ||||
| 	_ = db.QueryRow(`SHOW search_path`).Scan(&searchPath) | ||||
| 
 | ||||
| 	if o.Schema != `` { | ||||
| 		_, _ = db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema) | ||||
| 
 | ||||
| 		_, err = db.Exec(`SET search_path TO ` + o.Schema) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	tx, err := db.Begin() | ||||
| 	if err != nil { | ||||
|  | @ -53,12 +43,28 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { | |||
| 	} | ||||
| 	defer tx.Rollback() | ||||
| 
 | ||||
| 	_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + o.TableName + ` (Version integer NOT NULL PRIMARY KEY)`) | ||||
| 	if o.Schema != `` { | ||||
| 		err = createSchemaIfNotExists(tx, o.Schema) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 	row := tx.QueryRow(`SELECT Version FROM ` + o.TableName + ` ORDER BY Version DESC`) | ||||
| 		if dbType == `pq` { | ||||
| 			_ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath) | ||||
| 
 | ||||
| 			_, err = tx.Exec(`SET search_path TO ` + o.Schema) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`) | ||||
| 
 | ||||
| 	var v int | ||||
| 	err = row.Scan(&v) | ||||
|  | @ -80,13 +86,13 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { | |||
| 			return errorf(err) | ||||
| 		} | ||||
| 
 | ||||
| 		_, err = tx.Exec(`INSERT INTO ` + o.TableName + ` VALUES (` + strconv.Itoa(i) + `)`) | ||||
| 		_, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`) | ||||
| 		if err != nil { | ||||
| 			return errorf(err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if o.Schema != `` { | ||||
| 	if dbType == `pq` && o.Schema != `` { | ||||
| 		_, err = tx.Exec(`SET search_path TO ` + searchPath) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
|  |  | |||
							
								
								
									
										35
									
								
								query.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								query.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,35 @@ | |||
| package migrate | ||||
| 
 | ||||
| import "database/sql" | ||||
| 
 | ||||
| func createSchemaIfNotExists(tx *sql.Tx, schema string) error { | ||||
| 	row := tx.QueryRow(`SELECT 1 FROM information_schema.schemata WHERE schema_name = '` + schema + `'`) | ||||
| 	err := row.Scan(new(int)) | ||||
| 	if err == sql.ErrNoRows { | ||||
| 		_, err = tx.Exec(`CREATE SCHEMA ` + schema) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return err | ||||
| } | ||||
| 
 | ||||
| func createTableIfNotExists(tx *sql.Tx, schema, table string) (string, error) { | ||||
| 	versionTable := table | ||||
| 	var schemaCond string | ||||
| 
 | ||||
| 	if schema != `` { | ||||
| 		versionTable = schema + `.` + table | ||||
| 		schemaCond = ` AND table_schema = '` + schema + `'` | ||||
| 	} | ||||
| 
 | ||||
| 	row := tx.QueryRow(`SELECT 1 FROM information_schema.tables WHERE table_name = '` + table + `'` + schemaCond) | ||||
| 	err := row.Scan(new(int)) | ||||
| 	if err == sql.ErrNoRows { | ||||
| 		_, err = tx.Exec(`CREATE TABLE ` + versionTable + ` (version integer NOT NULL PRIMARY KEY)`) | ||||
| 		if err != nil { | ||||
| 			return ``, err | ||||
| 		} | ||||
| 	} | ||||
| 	return versionTable, err | ||||
| } | ||||
							
								
								
									
										29
									
								
								util.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								util.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,29 @@ | |||
| package migrate | ||||
| 
 | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| func getPkgName(db *sql.DB) string { | ||||
| 	t := reflect.TypeOf(db.Driver()) | ||||
| 	if t.Kind() == reflect.Ptr { | ||||
| 		t = t.Elem() | ||||
| 	} | ||||
| 	path := t.PkgPath() | ||||
| 	parts := strings.Split(path, `/vendor/`) | ||||
| 	return parts[len(parts)-1] | ||||
| } | ||||
| 
 | ||||
| func getDbType(db *sql.DB) string { | ||||
| 	switch getPkgName(db) { | ||||
| 	case `github.com/lib/pq`, `github.com/jackc/pgx/stdlib`: | ||||
| 		return `pq` | ||||
| 	case `github.com/go-sql-driver/mysql`, `github.com/ziutek/mymysql/godrv`: | ||||
| 		return `my` | ||||
| 	case `github.com/denisenkom/go-mssqldb`: | ||||
| 		return `ms` | ||||
| 	} | ||||
| 	return `` | ||||
| } | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue