// Package migrate allows you to update your database from your application package migrate import ( "database/sql" "errors" "fmt" "io/fs" "strconv" ) // Options contains all settings type Options struct { TableName string // Name used for version info table; defaults to DefaultTableName if not set Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here } // DefaultTableName is the name used when no TableName is specified in Options const DefaultTableName = `version` // ErrUpdatesMissing indicates an update is missing, making it impossible to execute the migration var ErrUpdatesMissing = errors.New(`Missing migration files`) const fileFormat = `%04d.sql` // AssetFunc is a function that returns the data for the given name type AssetFunc func(string) ([]byte, error) // Migrate executes all migrations // Filenames need to have incrementing numbers // Downgrading is not supported as it could result in data loss func Migrate(db *sql.DB, o Options, assets fs.FS) error { entries, err := fs.ReadDir(assets, `.`) if err != nil { panic(`failed to read list of files`) } version := len(entries) tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() dbType := getDbType(db) searchPath := `public` if o.TableName == `` { o.TableName = DefaultTableName } if o.Schema != `` { err = createSchemaIfNotExists(tx, o.Schema) if err != nil { return err } if dbType == `pq` { _ = tx.QueryRow(`SHOW search_path`).Scan(&searchPath) _, err = tx.Exec(`SET search_path TO ` + o.Schema) if err != nil { return err } } } versionTable, err := createTableIfNotExists(tx, o.Schema, o.TableName) if err != nil { return err } row := tx.QueryRow(`SELECT version FROM ` + versionTable + ` ORDER BY Version DESC`) var v int err = row.Scan(&v) if err != sql.ErrNoRows && err != nil { return err } for i := v + 1; i <= version; i++ { fileName := fmt.Sprintf(fileFormat, i) errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) } script, err := fs.ReadFile(assets, fileName) if err != nil { return errorf(ErrUpdatesMissing) } _, err = tx.Exec(string(script)) if err != nil { return errorf(err) } _, err = tx.Exec(`INSERT INTO ` + versionTable + ` (version) VALUES (` + strconv.Itoa(i) + `)`) if err != nil { return errorf(err) } } if dbType == `pq` && o.Schema != `` { _, err = tx.Exec(`SET search_path TO ` + searchPath) if err != nil { return err } } err = tx.Commit() if err != nil { return err } return nil }