blob: 407dad9053b4d746721370e951fbd13acf5e85d7 [file] [log] [blame]
package migration
import (
	"context"
	"errors"
	"fmt"
	"sort"
	"sync"

	"gorm.io/gorm"
)
// m is the package-level migrator behind Init, Register and Execute. Its
// script registry starts out empty and is populated through Register.
var m = migrator{scripts: map[string]scriptWithComment{}}
// scriptWithComment pairs a registered Script with the comment passed to
// register. Script is embedded so its methods (Name, Version, Up) are
// callable directly on the wrapper; bookKeep copies comment into the
// MigrationHistory row.
type scriptWithComment struct {
Script
// comment is the free-text note supplied when the script was registered.
comment string
}
// migrator holds the registry of migration scripts and the database handle
// they are applied to.
type migrator struct {
// Mutex is embedded so Lock/Unlock are called on the migrator itself;
// register holds it while writing to scripts.
sync.Mutex
// db is set by Init; history reads/writes and every script's Up use it.
db *gorm.DB
// scripts maps the "name:version" key to the registered script.
scripts map[string]scriptWithComment
}
// Init sets the database handle that Execute uses to read and write the
// migration history and that is passed to each script's Up. It must be
// called before Execute.
//
// The migrator lock is taken so the write to m.db is synchronized with the
// other migrator methods that guard their state with the same mutex.
func Init(db *gorm.DB) {
	m.Lock()
	defer m.Unlock()
	m.db = db
}
// register stores every script in the registry under the key
// "name:version", attaching comment to each entry. A script whose name and
// version match an already-registered one replaces the earlier entry. The
// migrator lock is held for the whole update.
func (m *migrator) register(scripts []Script, comment string) {
	m.Lock()
	defer m.Unlock()
	for i := range scripts {
		s := scripts[i]
		key := fmt.Sprintf("%s:%d", s.Name(), s.Version())
		m.scripts[key] = scriptWithComment{Script: s, comment: comment}
	}
}
// bookKeep inserts a MigrationHistory row for script, recording its
// version, name and registration comment, and returns the insert's error.
func (m *migrator) bookKeep(script scriptWithComment) error {
	entry := MigrationHistory{
		ScriptVersion: script.Version(),
		ScriptName:    script.Name(),
		Comment:       script.comment,
	}
	if err := m.db.Create(&entry).Error; err != nil {
		return err
	}
	return nil
}
// execute applies every registered script that has no matching row in the
// migration history. Scripts run in ascending Version order, with ties broken
// by Name so the order does not depend on map iteration. Each script is
// recorded via bookKeep right after its Up succeeds. The first failure stops
// the run and is returned wrapped with the failing script's name and version.
//
// The migrator lock is held for the whole run: the registry is mutated below
// (already-applied keys are deleted), which must not race with register, and
// serializing runs keeps two concurrent Execute calls in this process from
// applying the same script twice.
//
// NOTE(review): Up and bookKeep are not wrapped in one transaction; if
// bookKeep fails after Up succeeded, the script will run again next time.
func (m *migrator) execute(ctx context.Context) error {
	m.Lock()
	defer m.Unlock()

	// Without Init, getExecuted would dereference a nil *gorm.DB and panic.
	if m.db == nil {
		return errors.New("migration: Init must be called before Execute")
	}

	executed, err := m.getExecuted()
	if err != nil {
		return err
	}
	// Drop already-applied scripts from the registry so later runs skip them.
	for key := range executed {
		delete(m.scripts, key)
	}

	pending := make([]scriptWithComment, 0, len(m.scripts))
	for _, script := range m.scripts {
		pending = append(pending, script)
	}
	// sort.Slice is not stable and map iteration order is random, so a Name
	// tie-break is needed for a reproducible order among equal versions.
	sort.Slice(pending, func(i, j int) bool {
		a, b := pending[i], pending[j]
		if a.Version() != b.Version() {
			return a.Version() < b.Version()
		}
		return a.Name() < b.Name()
	})

	for _, script := range pending {
		// Stop between scripts once the caller has cancelled.
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := script.Up(ctx, m.db); err != nil {
			return fmt.Errorf("running migration %s:%d: %w", script.Name(), script.Version(), err)
		}
		if err := m.bookKeep(script); err != nil {
			return fmt.Errorf("recording migration %s:%d: %w", script.Name(), script.Version(), err)
		}
	}
	return nil
}
// getExecuted ensures the MigrationHistory table exists (via AutoMigrate) and
// returns the set of scripts already recorded there, keyed as "name:version",
// which is the same key format register uses. Errors are wrapped with the
// step that failed.
func (m *migrator) getExecuted() (map[string]struct{}, error) {
	if err := m.db.Migrator().AutoMigrate(&MigrationHistory{}); err != nil {
		return nil, fmt.Errorf("migrating history table: %w", err)
	}

	var records []MigrationHistory
	if err := m.db.Find(&records).Error; err != nil {
		return nil, fmt.Errorf("reading migration history: %w", err)
	}

	executed := make(map[string]struct{}, len(records))
	for _, rec := range records {
		executed[fmt.Sprintf("%s:%d", rec.ScriptName, rec.ScriptVersion)] = struct{}{}
	}
	return executed, nil
}
// Register adds scripts to the package-level migrator, tagging each with
// comment. Scripts are keyed by name and version; registering the same pair
// again replaces the earlier entry.
func Register(scripts []Script, comment string) {
	target := &m
	target.register(scripts, comment)
}
// Execute runs all registered scripts that are not yet recorded in the
// migration history, using the database supplied to Init, and returns the
// first error encountered.
func Execute(ctx context.Context) error {
	if err := m.execute(ctx); err != nil {
		return err
	}
	return nil
}