diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..909249c --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.fuyu.moe/Fuyu/migrate/v2 + +go 1.16 diff --git a/migrate.go b/migrate.go index ab7984e..b2c377a 100644 --- a/migrate.go +++ b/migrate.go @@ -5,14 +5,14 @@ import ( "database/sql" "errors" "fmt" + "io/fs" "strconv" ) // Options contains all settings type Options struct { - TableName string // Name used for version info table; defaults to DefaultTableName if not set - Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here - AssetPrefix string + TableName string // Name used for version info table; defaults to DefaultTableName if not set + Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here } // DefaultTableName is the name used when no TableName is specified in Options @@ -26,16 +26,15 @@ const fileFormat = `%04d.sql` // AssetFunc is a function that returns the data for the given name type AssetFunc func(string) ([]byte, error) -// Migrate executes migrations to get to the desired version +// Migrate executes all migrations +// Filenames need to have incrementing numbers // Downgrading is not supported as it could result in data loss -func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { - dbType := getDbType(db) - if o.TableName == `` { - o.TableName = DefaultTableName +func Migrate(db *sql.DB, o Options, assets fs.FS) error { + entries, err := fs.ReadDir(assets, `.`) + if err != nil { + panic(`failed to read list of files`) } - - var err error - searchPath := `public` + version := len(entries) tx, err := db.Begin() if err != nil { @@ -43,6 +42,12 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { } defer tx.Rollback() + dbType := getDbType(db) + searchPath := `public` + if o.TableName == `` { + o.TableName = DefaultTableName + } + if o.Schema != `` { err = createSchemaIfNotExists(tx, o.Schema) if err != nil { @@ -73,10 +78,10 @@ func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { } for i := v + 1; i <= version; i++ { - fileName := fmt.Sprintf(o.AssetPrefix+fileFormat, i) + fileName := fmt.Sprintf(fileFormat, i) errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) } - script, err := asset(fileName) + script, err := fs.ReadFile(assets, fileName) if err != nil { return errorf(ErrUpdatesMissing) }