Use fs.FS

This commit is contained in:
Nise Void 2021-02-24 14:27:48 +01:00
parent f39c1c640a
commit 9d4454628f
Signed by: NiseVoid
GPG Key ID: FBA14AC83EA602F3
2 changed files with 21 additions and 13 deletions

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.fuyu.moe/Fuyu/migrate/v2
go 1.16

View File

@ -5,14 +5,14 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"io/fs"
"strconv" "strconv"
) )
// Options contains all settings // Options contains all settings
type Options struct { type Options struct {
TableName string // Name used for version info table; defaults to DefaultTableName if not set TableName string // Name used for version info table; defaults to DefaultTableName if not set
Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here
AssetPrefix string
} }
// DefaultTableName is the name used when no TableName is specified in Options // DefaultTableName is the name used when no TableName is specified in Options
@ -26,16 +26,15 @@ const fileFormat = `%04d.sql`
// AssetFunc is a function that returns the data for the given name // AssetFunc is a function that returns the data for the given name
type AssetFunc func(string) ([]byte, error) type AssetFunc func(string) ([]byte, error)
// Migrate executes migrations to get to the desired version // Migrate executes all migrations
// Filenames need to have incrementing numbers
// Downgrading is not supported as it could result in data loss // Downgrading is not supported as it could result in data loss
func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { func Migrate(db *sql.DB, o Options, assets fs.FS) error {
dbType := getDbType(db) entries, err := fs.ReadDir(assets, `.`)
if o.TableName == `` { if err != nil {
o.TableName = DefaultTableName panic(`failed to read list of files`)
} }
version := len(entries)
var err error
searchPath := `public`
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
@ -43,6 +42,12 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
} }
defer tx.Rollback() defer tx.Rollback()
dbType := getDbType(db)
searchPath := `public`
if o.TableName == `` {
o.TableName = DefaultTableName
}
if o.Schema != `` { if o.Schema != `` {
err = createSchemaIfNotExists(tx, o.Schema) err = createSchemaIfNotExists(tx, o.Schema)
if err != nil { if err != nil {
@ -73,10 +78,10 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error {
} }
for i := v + 1; i <= version; i++ { for i := v + 1; i <= version; i++ {
fileName := fmt.Sprintf(o.AssetPrefix+fileFormat, i) fileName := fmt.Sprintf(fileFormat, i)
errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) } errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) }
script, err := asset(fileName) script, err := fs.ReadFile(assets, fileName)
if err != nil { if err != nil {
return errorf(ErrUpdatesMissing) return errorf(ErrUpdatesMissing)
} }