diff --git a/go.mod b/go.mod deleted file mode 100644 index 909249c..0000000 --- a/go.mod +++ /dev/null @@ -1,3 +0,0 @@ -module git.fuyu.moe/Fuyu/migrate/v2 - -go 1.16 diff --git a/migrate.go b/migrate.go index b2c377a..ab7984e 100644 --- a/migrate.go +++ b/migrate.go @@ -5,14 +5,14 @@ import ( "database/sql" "errors" "fmt" - "io/fs" "strconv" ) // Options contains all settings type Options struct { - TableName string // Name used for version info table; defaults to DefaultTableName if not set - Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here + TableName string // Name used for version info table; defaults to DefaultTableName if not set + Schema string // Schema used for version info table; In PostgreSQL the current schema is changed to the one specified here + AssetPrefix string } // DefaultTableName is the name used when no TableName is specified in Options @@ -26,15 +26,16 @@ const fileFormat = `%04d.sql` // AssetFunc is a function that returns the data for the given name type AssetFunc func(string) ([]byte, error) -// Migrate executes all migrations -// Filenames need to have incrementing numbers +// Migrate executes migrations to get to the desired version // Downgrading is not supported as it could result in data loss -func Migrate(db *sql.DB, o Options, assets fs.FS) error { - entries, err := fs.ReadDir(assets, `.`) - if err != nil { - panic(`failed to read list of files`) +func Migrate(db *sql.DB, version int, o Options, asset AssetFunc) error { + dbType := getDbType(db) + if o.TableName == `` { + o.TableName = DefaultTableName } - version := len(entries) + + var err error + searchPath := `public` tx, err := db.Begin() if err != nil { @@ -42,12 +43,6 @@ func Migrate(db *sql.DB, o Options, assets fs.FS) error { } defer tx.Rollback() - dbType := getDbType(db) - searchPath := `public` - if o.TableName == `` { - o.TableName = DefaultTableName - } - if o.Schema != `` { err = createSchemaIfNotExists(tx, o.Schema) if err != nil { @@ -78,10 +73,10 @@ func Migrate(db *sql.DB, o Options, assets fs.FS) error { } for i := v + 1; i <= version; i++ { - fileName := fmt.Sprintf(fileFormat, i) + fileName := fmt.Sprintf(o.AssetPrefix+fileFormat, i) errorf := func(e error) error { return fmt.Errorf(`migration "%s" failed: %w`, fileName, e) } - script, err := fs.ReadFile(assets, fileName) + script, err := asset(fileName) if err != nil { return errorf(ErrUpdatesMissing) }