108 lines
		
	
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			108 lines
		
	
	
	
		
			2.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Package migrate upgrades a database from within your application by applying numbered SQL migration scripts in order.
 | |
| package migrate
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"strconv"
 | |
| )
 | |
| 
 | |
// Options configures where Migrate records applied versions and how it
// locates migration scripts.
type Options struct {
	// TableName is the name of the table that records applied versions.
	// When empty, DefaultTableName is used.
	TableName string

	// Schema is the schema used for the version table. In PostgreSQL the
	// current schema is changed to the one specified here.
	Schema string

	// AssetPrefix is prepended to each migration script name (for example
	// "0001.sql") before the name is handed to the AssetFunc.
	AssetPrefix string
}
 | |
| 
 | |
// DefaultTableName is the version table name Migrate falls back to when
// Options.TableName is left empty.
const DefaultTableName = "version"
 | |
| 
 | |
| // ErrUpdatesMissing indicates an update is missing, making it impossible to execute the migration
 | |
| var ErrUpdatesMissing = errors.New(`Missing migration files`)
 | |
| 
 | |
| const fileFormat = `%04d.sql`
 | |
| 
 | |
| // AssetFunc is a function that returns the data for the given name
 | |
| type AssetFunc func(string) ([]byte, error)
 | |
| 
 | |
| // Migrate executes migrations to get to the desired version
 | |
| // Downgrading is not supported as it could result in data loss
 | |
| func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
 | |
| 	dbType := getDbType(db)
 | |
| 	if o.TableName == `` {
 | |
| 		o.TableName = DefaultTableName
 | |
| 	}
 | |
| 
 | |
| 	var err error
 | |
| 	searchPath := `public`
 | |
| 
 | |
| 	tx, err := db.Begin()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	defer tx.Rollback()
 | |
| 
 | |
| 	if o.Schema != `` {
 | |
| 		err = createSchemaIfNotExists(tx, o.Schema)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if dbType == `pq` {
 | |
| 			_ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath)
 | |
| 
 | |
| 			_, err = tx.Exec(`SET search_path TO ` + o.Schema)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`)
 | |
| 
 | |
| 	var v int
 | |
| 	err = row.Scan(&v)
 | |
| 	if err != sql.ErrNoRows && err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	for i := v + 1; i <= version; i++ {
 | |
| 		fileName := fmt.Sprintf(o.AssetPrefix+fileFormat, i)
 | |
| 		errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) }
 | |
| 
 | |
| 		script, err := asset(fileName)
 | |
| 		if err != nil {
 | |
| 			return errorf(ErrUpdatesMissing)
 | |
| 		}
 | |
| 
 | |
| 		_, err = tx.Exec(string(script))
 | |
| 		if err != nil {
 | |
| 			return errorf(err)
 | |
| 		}
 | |
| 
 | |
| 		_, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`)
 | |
| 		if err != nil {
 | |
| 			return errorf(err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if dbType == `pq` && o.Schema != `` {
 | |
| 		_, err = tx.Exec(`SET search_path TO ` + searchPath)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	err = tx.Commit()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 |