mirror of https://github.com/ory/hydra
337 lines
8.6 KiB
Go
337 lines
8.6 KiB
Go
// Copyright © 2022 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package cli
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/ory/x/popx"
|
|
"github.com/ory/x/servicelocatorx"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/ory/x/configx"
|
|
|
|
"github.com/ory/x/cmdx"
|
|
|
|
"github.com/spf13/cobra"
|
|
|
|
"github.com/ory/hydra/v2/driver"
|
|
"github.com/ory/hydra/v2/driver/config"
|
|
"github.com/ory/hydra/v2/persistence"
|
|
"github.com/ory/x/flagx"
|
|
)
|
|
|
|
type MigrateHandler struct {
|
|
slOpts []servicelocatorx.Option
|
|
dOpts []driver.OptionsModifier
|
|
cOpts []configx.OptionModifier
|
|
}
|
|
|
|
func newMigrateHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *MigrateHandler {
|
|
return &MigrateHandler{
|
|
slOpts: slOpts,
|
|
dOpts: dOpts,
|
|
cOpts: cOpts,
|
|
}
|
|
}
|
|
|
|
// genericDialectKey is the dialect assigned to migration files whose names
// carry no dialect segment; readMigrations treats such files as the
// fallback up/down migrations of their group.
const genericDialectKey = "any"
|
|
|
|
// fragmentHeader returns the SQL comment prepended to every generated
// fragment: a DO NOT EDIT notice followed by the command line (os.Args
// joined by spaces) that produced the file.
func fragmentHeader() []byte {
	invocation := strings.Join(os.Args, " ")
	header := fmt.Sprintf("-- Migration generated by the command below; DO NOT EDIT.\n-- %s\n", invocation)
	return []byte(header)
}
|
|
|
|
// blankFragment returns the content written for padding fragments: a SQL
// comment explaining why the empty migration exists, followed by the command
// line (os.Args joined by spaces) that produced it.
func blankFragment() []byte {
	invocation := strings.Join(os.Args, " ")
	body := fmt.Sprintf("-- This is a blank migration. It is generated to ensure that all dialects are represented in the migration files.\n-- %s\n", invocation)
	return []byte(body)
}
|
|
|
|
// mrx matches migration source file names of the form
//
//	<14 digits>000000_<name>[.<dialect>].<up|down>.sql
//
// Capture groups: 1 = the 14-digit ID, 2 = the name, 3 = the optional
// dialect including its leading dot, 4 = the direction.
var (
	mrx = regexp.MustCompile(`^(\d{14})000000_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.sql$`)
)
|
|
|
|
// migration describes a single migration source file, as parsed from its
// file name by parseMigration.
type migration struct {
	// Path is the file name within the source fs.FS.
	Path string
	// ID is the 14-digit prefix of the file name.
	ID string
	// Name is the file-name segment between the ID and the dialect/direction.
	Name string
	// Dialect is the dialect segment, or genericDialectKey when there is none.
	Dialect string
	// Direction is "up" or "down".
	Direction string
}
|
|
|
|
type migrationGroup struct {
|
|
ID string
|
|
Name string
|
|
Children []*migration
|
|
fallbackUpMigration *migration
|
|
fallbackDownMigration *migration
|
|
}
|
|
|
|
// ReadSource returns the full contents of the migration's file, opened at
// m.Path inside fsys. Both open and read failures are returned with a stack
// trace attached.
//
// The parameter is named fsys (not fs) so it does not shadow the io/fs package.
func (m *migration) ReadSource(fsys fs.FS) ([]byte, error) {
	f, err := fsys.Open(m.Path)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	defer f.Close()

	source, err := io.ReadAll(f)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	return source, nil
}
|
|
|
|
// generateMigrationFragments splits source at every "--split" marker and
// prefixes each resulting chunk with fragmentHeader().
//
// bytes.Split with a non-empty separator always returns at least one chunk,
// so source without any marker (including empty source) yields exactly one
// fragment. The error result is therefore always nil; it is kept so the
// signature stays unchanged for callers.
func (m migration) generateMigrationFragments(source []byte) ([][]byte, error) {
	chunks := bytes.Split(source, []byte("--split"))
	for i, chunk := range chunks {
		// fragmentHeader returns a fresh slice, so appending cannot clobber
		// a neighbouring chunk.
		chunks[i] = append(fragmentHeader(), chunk...)
	}
	return chunks, nil
}
|
|
|
|
// fragmentName returns the output file name for fragment i of m:
// <ID><i as 6 digits>_<group name>[.<dialect>].<direction>.sql, where the
// dialect segment is omitted for genericDialectKey migrations.
func (mg migrationGroup) fragmentName(m *migration, i int) string {
	if m.Dialect != genericDialectKey {
		return fmt.Sprintf("%s%06d_%s.%s.%s.sql", mg.ID, i, mg.Name, m.Dialect, m.Direction)
	}
	return fmt.Sprintf("%s%06d_%s.%s.sql", mg.ID, i, mg.Name, m.Direction)
}
|
|
|
|
// generateSQL splits every migration of the group (dialect-specific children
// plus the generic fallbacks) into fragments and writes them to target.
// Fragment i of migration m is written to mg.fragmentName(m, i).
//
// Afterwards every migration is padded with blankFragment() files up to the
// largest fragment count found in the group. Each migration therefore ends
// up with the same set of fragment indices, in every dialect and direction.
func (mg migrationGroup) generateSQL(sourceFS fs.FS, target string) error {
	// Build a fresh slice so appending the fallbacks can never write into
	// spare capacity of mg.Children's backing array.
	ms := make([]*migration, 0, len(mg.Children)+2)
	ms = append(ms, mg.Children...)
	if mg.fallbackDownMigration != nil {
		ms = append(ms, mg.fallbackDownMigration)
	}
	if mg.fallbackUpMigration != nil {
		ms = append(ms, mg.fallbackUpMigration)
	}

	// fragmentCounts[idx] is the number of fragments written for ms[idx].
	// Counts must be tracked per migration, not per dialect: the up and down
	// files of one dialect can split into different numbers of fragments. A
	// shared per-dialect count made the padding pass overwrite real fragments
	// with blank ones.
	fragmentCounts := make([]int, len(ms))
	maxFragmentCount := 0

	for idx, m := range ms {
		source, err := m.ReadSource(sourceFS)
		if err != nil {
			return errors.WithStack(err)
		}

		fragments, err := m.generateMigrationFragments(source)
		if err != nil {
			return errors.Wrapf(err, "failed to process %s", m.Path)
		}

		for i, fragment := range fragments {
			dst := filepath.Join(target, mg.fragmentName(m, i))
			if err := os.WriteFile(dst, fragment, 0600); err != nil {
				return errors.Wrapf(err, "failed to write file %s", dst)
			}
		}

		fragmentCounts[idx] = len(fragments)
		if len(fragments) > maxFragmentCount {
			maxFragmentCount = len(fragments)
		}
	}

	for idx, m := range ms {
		for i := fragmentCounts[idx]; i < maxFragmentCount; i++ {
			dst := filepath.Join(target, mg.fragmentName(m, i))
			if err := os.WriteFile(dst, blankFragment(), 0600); err != nil {
				return errors.Wrapf(err, "failed to write file %s", dst)
			}
		}
	}

	return nil
}
|
|
|
|
// parseMigration parses a migration source file name of the form
// <14 digits>000000_<name>[.<dialect>].<up|down>.sql (see mrx) into a
// migration. A name without a dialect segment gets genericDialectKey.
// It returns an error if filename does not match mrx.
func parseMigration(filename string) (*migration, error) {
	// mrx is anchored at both ends, so there is at most one match, and a
	// successful match always carries all four capture groups.
	match := mrx.FindStringSubmatch(filename)
	if match == nil {
		return nil, errors.Errorf("failed to parse migration filename %s; it does not match pattern %s", filename, mrx.String())
	}

	// Group 3 is either empty or the dialect with its leading dot.
	dialect := strings.TrimPrefix(match[3], ".")
	if dialect == "" {
		dialect = genericDialectKey
	}

	return &migration{
		Path:      filename,
		ID:        match[1],
		Name:      match[2],
		Dialect:   dialect,
		Direction: match[4],
	}, nil
}
|
|
|
|
// readMigrations scans the top level of migrationSourceFS, parses each file
// name with parseMigration and groups the results by migration ID.
//
// Files with the generic dialect become the group's fallback up/down
// migration; all other files are appended to the group's Children. Nested
// files are skipped with a warning. Walk errors are printed as warnings and
// otherwise ignored. A file name that fails to parse aborts the scan, and
// that error is returned.
//
// When expectedDialects is non-empty, every group is also validated:
//   - each dialect-specific file must use one of the expected dialects;
//   - each expected dialect needs an up and a down migration, unless the
//     group has the corresponding fallback.
func readMigrations(migrationSourceFS fs.FS, expectedDialects []string) (map[string]*migrationGroup, error) {
	groups := make(map[string]*migrationGroup)

	err := fs.WalkDir(migrationSourceFS, ".", func(name string, entry fs.DirEntry, walkErr error) error {
		// Cases are evaluated in order, so entry is never touched when
		// walkErr is set (entry may be nil in that case).
		switch {
		case walkErr != nil:
			fmt.Println("Warning: unexpected error " + walkErr.Error())
			return nil
		case entry.IsDir():
			return nil
		case name != filepath.Base(name):
			fmt.Println("Warning: ignoring nested file " + name)
			return nil
		}

		mig, parseErr := parseMigration(name)
		if parseErr != nil {
			return parseErr
		}

		group, seen := groups[mig.ID]
		if !seen {
			// The first file seen for an ID decides the group's name.
			group = &migrationGroup{ID: mig.ID, Name: mig.Name}
			groups[mig.ID] = group
		}

		generic := mig.Dialect == genericDialectKey
		switch {
		case generic && mig.Direction == "up":
			group.fallbackUpMigration = mig
		case generic && mig.Direction == "down":
			group.fallbackDownMigration = mig
		default:
			group.Children = append(group.Children, mig)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	if len(expectedDialects) == 0 {
		return groups, nil
	}

	allowed := make(map[string]struct{}, len(expectedDialects))
	for _, dialect := range expectedDialects {
		allowed[dialect] = struct{}{}
	}

	for _, group := range groups {
		// present records "<dialect>.<direction>" for every child file.
		present := make(map[string]struct{})
		for _, child := range group.Children {
			if _, ok := allowed[child.Dialect]; !ok {
				return nil, errors.Errorf("unexpected dialect %s in filename %s", child.Dialect, child.Path)
			}
			present[child.Dialect+"."+child.Direction] = struct{}{}
		}

		for _, dialect := range expectedDialects {
			if _, hasUp := present[dialect+".up"]; !hasUp && group.fallbackUpMigration == nil {
				return nil, errors.Errorf("dialect %s not found for up migration %s; use --dialects=\"\" to disable dialect validation", dialect, group.ID)
			}
			if _, hasDown := present[dialect+".down"]; !hasDown && group.fallbackDownMigration == nil {
				return nil, errors.Errorf("dialect %s not found for down migration %s; use --dialects=\"\" to disable dialect validation", dialect, group.ID)
			}
		}
	}

	return groups, nil
}
|
|
|
|
// MigrateGen is the cobra handler that reads the migration sources in
// args[0] and writes the split, per-dialect SQL fragments into args[1].
//
// The --dialects flag is forwarded to readMigrations for dialect validation.
// On the first error, the message is written to the command's error stream
// (consistent with makePersister) and the process exits with status 1.
// Otherwise it exits with status 0.
func (h *MigrateHandler) MigrateGen(cmd *cobra.Command, args []string) {
	cmdx.ExactArgs(cmd, args, 2)
	expectedDialects := flagx.MustGetStringSlice(cmd, "dialects")

	sourceDir, targetDir := args[0], args[1]
	sourceFS := os.DirFS(sourceDir)

	mgs, err := readMigrations(sourceFS, expectedDialects)
	if err != nil {
		_, _ = fmt.Fprintln(cmd.ErrOrStderr(), err.Error())
		os.Exit(1)
	}

	for _, mg := range mgs {
		if err := mg.generateSQL(sourceFS, targetDir); err != nil {
			_, _ = fmt.Fprintln(cmd.ErrOrStderr(), err.Error())
			os.Exit(1)
		}
	}

	os.Exit(0)
}
|
|
|
|
func (h *MigrateHandler) makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) {
|
|
var d driver.Registry
|
|
|
|
if flagx.MustGetBool(cmd, "read-from-env") {
|
|
d, err = driver.New(
|
|
cmd.Context(),
|
|
servicelocatorx.NewOptions(),
|
|
append([]driver.OptionsModifier{
|
|
driver.WithOptions(
|
|
configx.SkipValidation(),
|
|
configx.WithFlags(cmd.Flags())),
|
|
driver.DisableValidation(),
|
|
driver.DisablePreloading(),
|
|
driver.SkipNetworkInit(),
|
|
}, h.dOpts...))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(d.Config().DSN()) == 0 {
|
|
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "When using flag -e, environment variable DSN must be set.")
|
|
return nil, cmdx.FailSilently(cmd)
|
|
}
|
|
} else {
|
|
if len(args) != 1 {
|
|
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Please provide the database URL.")
|
|
return nil, cmdx.FailSilently(cmd)
|
|
}
|
|
d, err = driver.New(
|
|
cmd.Context(),
|
|
servicelocatorx.NewOptions(),
|
|
append([]driver.OptionsModifier{
|
|
driver.WithOptions(
|
|
configx.WithFlags(cmd.Flags()),
|
|
configx.SkipValidation(),
|
|
configx.WithValue(config.KeyDSN, args[0]),
|
|
),
|
|
driver.DisableValidation(),
|
|
driver.DisablePreloading(),
|
|
driver.SkipNetworkInit(),
|
|
}, h.dOpts...))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return d.Persister(), nil
|
|
}
|
|
|
|
func (h *MigrateHandler) MigrateSQLUp(cmd *cobra.Command, args []string) (err error) {
|
|
p, err := h.makePersister(cmd, args)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return popx.MigrateSQLUp(cmd, p)
|
|
}
|
|
|
|
func (h *MigrateHandler) MigrateSQLDown(cmd *cobra.Command, args []string) (err error) {
|
|
p, err := h.makePersister(cmd, args)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return popx.MigrateSQLDown(cmd, p)
|
|
}
|
|
|
|
func (h *MigrateHandler) MigrateStatus(cmd *cobra.Command, args []string) error {
|
|
p, err := h.makePersister(cmd, args)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return popx.MigrateStatus(cmd, p)
|
|
}
|