Support non-PostgreSQL databases
This commit is contained in:
		
							parent
							
								
									de9e7d9a41
								
							
						
					
					
						commit
						f39c1c640a
					
				
					 3 changed files with 86 additions and 16 deletions
				
			
		
							
								
								
									
										38
									
								
								migrate.go
									
										
									
									
									
								
							
							
						
						
									
										38
									
								
								migrate.go
									
										
									
									
									
								
							|  | @ -11,7 +11,7 @@ import ( | ||||||
| // Options contains all settings | // Options contains all settings | ||||||
| type Options struct { | type Options struct { | ||||||
| 	TableName   string // Name used for version info table; defaults to DefaultTableName if not set | 	TableName   string // Name used for version info table; defaults to DefaultTableName if not set | ||||||
| 	Schema      string // Schema used for version info table; For PostgreSQL, ignored if not set | 	Schema      string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here | ||||||
| 	AssetPrefix string | 	AssetPrefix string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -29,23 +29,13 @@ type AssetFunc func(string) ([]byte, error) | ||||||
| // Migrate executes migrations to get to the desired version | // Migrate executes migrations to get to the desired version | ||||||
| // Downgrading is not supported as it could result in data loss | // Downgrading is not supported as it could result in data loss | ||||||
| func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { | func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { | ||||||
|  | 	dbType := getDbType(db) | ||||||
| 	if o.TableName == `` { | 	if o.TableName == `` { | ||||||
| 		o.TableName = DefaultTableName | 		o.TableName = DefaultTableName | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var err error | 	var err error | ||||||
| 
 |  | ||||||
| 	searchPath := `public` | 	searchPath := `public` | ||||||
| 	_ = db.QueryRow(`SHOW search_path`).Scan(&searchPath) |  | ||||||
| 
 |  | ||||||
| 	if o.Schema != `` { |  | ||||||
| 		_, _ = db.Exec(`CREATE SCHEMA IF NOT EXISTS ` + o.Schema) |  | ||||||
| 
 |  | ||||||
| 		_, err = db.Exec(`SET search_path TO ` + o.Schema) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	tx, err := db.Begin() | 	tx, err := db.Begin() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -53,12 +43,28 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { | ||||||
| 	} | 	} | ||||||
| 	defer tx.Rollback() | 	defer tx.Rollback() | ||||||
| 
 | 
 | ||||||
| 	_, err = tx.Exec(`CREATE TABLE IF NOT EXISTS ` + o.TableName + ` (Version integer NOT NULL PRIMARY KEY)`) | 	if o.Schema != `` { | ||||||
|  | 		err = createSchemaIfNotExists(tx, o.Schema) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 	row := tx.QueryRow(`SELECT Version FROM ` + o.TableName + ` ORDER BY Version DESC`) | 		if dbType == `pq` { | ||||||
|  | 			_ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath) | ||||||
|  | 
 | ||||||
|  | 			_, err = tx.Exec(`SET search_path TO ` + o.Schema) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`) | ||||||
| 
 | 
 | ||||||
| 	var v int | 	var v int | ||||||
| 	err = row.Scan(&v) | 	err = row.Scan(&v) | ||||||
|  | @ -80,13 +86,13 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { | ||||||
| 			return errorf(err) | 			return errorf(err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		_, err = tx.Exec(`INSERT INTO ` + o.TableName + ` VALUES (` + strconv.Itoa(i) + `)`) | 		_, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return errorf(err) | 			return errorf(err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if o.Schema != `` { | 	if dbType == `pq` && o.Schema != `` { | ||||||
| 		_, err = tx.Exec(`SET search_path TO ` + searchPath) | 		_, err = tx.Exec(`SET search_path TO ` + searchPath) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
|  |  | ||||||
							
								
								
									
										35
									
								
								query.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								query.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,35 @@ | ||||||
|  | package migrate | ||||||
|  | 
 | ||||||
|  | import "database/sql" | ||||||
|  | 
 | ||||||
|  | func createSchemaIfNotExists(tx *sql.Tx, schema string) error { | ||||||
|  | 	row := tx.QueryRow(`SELECT 1 FROM information_schema.schemata WHERE schema_name = '` + schema + `'`) | ||||||
|  | 	err := row.Scan(new(int)) | ||||||
|  | 	if err == sql.ErrNoRows { | ||||||
|  | 		_, err = tx.Exec(`CREATE SCHEMA ` + schema) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func createTableIfNotExists(tx *sql.Tx, schema, table string) (string, error) { | ||||||
|  | 	versionTable := table | ||||||
|  | 	var schemaCond string | ||||||
|  | 
 | ||||||
|  | 	if schema != `` { | ||||||
|  | 		versionTable = schema + `.` + table | ||||||
|  | 		schemaCond = ` AND table_schema = '` + schema + `'` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	row := tx.QueryRow(`SELECT 1 FROM information_schema.tables WHERE table_name = '` + table + `'` + schemaCond) | ||||||
|  | 	err := row.Scan(new(int)) | ||||||
|  | 	if err == sql.ErrNoRows { | ||||||
|  | 		_, err = tx.Exec(`CREATE TABLE ` + versionTable + ` (version integer NOT NULL PRIMARY KEY)`) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return ``, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return versionTable, err | ||||||
|  | } | ||||||
							
								
								
									
										29
									
								
								util.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								util.go
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,29 @@ | ||||||
|  | package migrate | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func getPkgName(db *sql.DB) string { | ||||||
|  | 	t := reflect.TypeOf(db.Driver()) | ||||||
|  | 	if t.Kind() == reflect.Ptr { | ||||||
|  | 		t = t.Elem() | ||||||
|  | 	} | ||||||
|  | 	path := t.PkgPath() | ||||||
|  | 	parts := strings.Split(path, `/vendor/`) | ||||||
|  | 	return parts[len(parts)-1] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getDbType(db *sql.DB) string { | ||||||
|  | 	switch getPkgName(db) { | ||||||
|  | 	case `github.com/lib/pq`, `github.com/jackc/pgx/stdlib`: | ||||||
|  | 		return `pq` | ||||||
|  | 	case `github.com/go-sql-driver/mysql`, `github.com/ziutek/mymysql/godrv`: | ||||||
|  | 		return `my` | ||||||
|  | 	case `github.com/denisenkom/go-mssqldb`: | ||||||
|  | 		return `ms` | ||||||
|  | 	} | ||||||
|  | 	return `` | ||||||
|  | } | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue