mirror of https://github.com/ory/hydra
153 lines
5.0 KiB
Go
153 lines
5.0 KiB
Go
// Copyright © 2023 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package sqlxx
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// keys returns the column names declared in the `db` struct tags of t, in
// field declaration order.
//
// t must be a struct or a pointer to a struct; exactly one level of pointer is
// dereferenced, and any other input makes the underlying reflect calls panic.
// Only the part of a tag before the first comma is used, so `db:"bar,omitempty"`
// yields "bar". Fields without a `db` tag, fields tagged `db:"-"`, and names
// listed in exclude are skipped. Only the direct fields of the struct are
// inspected; nested structs are not descended into.
func keys(t any, exclude []string) []string {
	typ := reflect.TypeOf(t)
	if typ.Kind() == reflect.Pointer {
		typ = typ.Elem()
	}

	fieldCount := typ.NumField()
	columns := make([]string, 0, fieldCount)
	for idx := 0; idx < fieldCount; idx++ {
		column, _, _ := strings.Cut(typ.Field(idx).Tag.Get("db"), ",")
		// Skip untagged fields, explicitly ignored fields, and caller exclusions.
		if column == "" || column == "-" || slices.Contains(exclude, column) {
			continue
		}
		columns = append(columns, column)
	}
	return columns
}
|
|
|
|
// NamedInsertArguments returns columns and arguments for SQL INSERT statements based on a struct's tags. Does
|
|
// not work with nested structs or maps!
|
|
//
|
|
// type st struct {
|
|
// Foo string `db:"foo"`
|
|
// Bar string `db:"bar,omitempty"`
|
|
// Baz string `db:"-"`
|
|
// Zab string
|
|
// }
|
|
// columns, arguments := NamedInsertArguments(new(st))
|
|
// query := fmt.Sprintf("INSERT INTO foo (%s) VALUES (%s)", columns, arguments)
|
|
// // INSERT INTO foo (foo, bar) VALUES (:foo, :bar)
|
|
func NamedInsertArguments(t any, exclude ...string) (columns string, arguments string) {
|
|
keys := keys(t, exclude)
|
|
return strings.Join(keys, ", "),
|
|
":" + strings.Join(keys, ", :")
|
|
}
|
|
|
|
// NamedUpdateArguments returns columns and arguments for SQL UPDATE statements based on a struct's tags. Does
|
|
// not work with nested structs or maps!
|
|
//
|
|
// type st struct {
|
|
// Foo string `db:"foo"`
|
|
// Bar string `db:"bar,omitempty"`
|
|
// Baz string `db:"-"`
|
|
// Zab string
|
|
// }
|
|
// query := fmt.Sprintf("UPDATE foo SET %s", NamedUpdateArguments(new(st)))
|
|
// // UPDATE foo SET foo=:foo, bar=:bar
|
|
func NamedUpdateArguments(t any, exclude ...string) string {
|
|
keys := keys(t, exclude)
|
|
statements := make([]string, len(keys))
|
|
|
|
for k, key := range keys {
|
|
statements[k] = fmt.Sprintf("%s=:%s", key, key)
|
|
}
|
|
|
|
return strings.Join(statements, ", ")
|
|
}
|
|
|
|
// OnConflictDoNothing returns an SQL fragment, padded with spaces on both
// sides, to append to an INSERT so that a conflicting row is left unchanged.
//
// For the "mysql" dialect it emits `ON DUPLICATE KEY UPDATE` assigning
// columnNoop to itself. columnNoop is interpolated verbatim between
// backticks and is not escaped. For every other dialect it emits
// `ON CONFLICT DO NOTHING` and columnNoop is ignored.
func OnConflictDoNothing(dialect string, columnNoop string) string {
	switch dialect {
	case "mysql":
		return fmt.Sprintf(" ON DUPLICATE KEY UPDATE `%s` = `%s` ", columnNoop, columnNoop)
	default:
		return ` ON CONFLICT DO NOTHING `
	}
}
|
|
|
|
// ExtractSchemeFromDSN splits a DSN at its first "://" separator and returns
// the scheme (e.g. `mysql`, `postgres`) and everything after the separator.
//
// An error is returned when the DSN contains no "://" or when the part before
// it is empty: a scheme is required to identify which database is in use.
func ExtractSchemeFromDSN(dsn string) (string, string, error) {
	scheme, rest, hasSeparator := strings.Cut(dsn, "://")
	switch {
	case !hasSeparator:
		return "", "", errors.New("invalid DSN: missing scheme separator")
	case scheme == "":
		return "", "", errors.New("invalid DSN: empty scheme")
	default:
		return scheme, rest, nil
	}
}
|
|
|
|
// ReplaceSchemeInDSN replaces the scheme (e.g. `mysql`, `postgres`, etc) in a DSN string with another one.
|
|
// This is necessary for example when using `cockroach` as a scheme, but using the postgres driver to connect to the database,
|
|
// and this driver only accepts `postgres` as a scheme.
|
|
func ReplaceSchemeInDSN(dsn string, newScheme string) (string, error) {
|
|
_, afterSchemeSeparator, err := ExtractSchemeFromDSN(dsn)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
return newScheme + "://" + afterSchemeSeparator, nil
|
|
}
|
|
|
|
// DSNRedacted parses a database DSN and returns a redacted form as a string.
|
|
// It replaces any password with "xxxxx" just like `url.Redacted()`.
|
|
// Only the password is redacted, not the username.
|
|
// This function is necessary because MySQL uses a DSN format not compatible with `url.Parse`.
|
|
// Additionally and as a consequence of the point above, the scheme is expected to be present and non-empty.
|
|
// This function is less strict that `url.Parse` in the case of MySQL.
|
|
// It also does not escape any characters in the username, whereas `url.String()`/`url.Redacted` does.
|
|
func DSNRedacted(dsn string) (string, error) {
|
|
scheme, afterSchemeSeparator, err := ExtractSchemeFromDSN(dsn)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
// If this is not MySQL, we simply delegate the work to `url.Parse`.
|
|
if scheme != "mysql" {
|
|
u, err := url.Parse(dsn)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
return u.Redacted(), nil
|
|
}
|
|
|
|
// MySQL has a weird DSN syntax not conforming to a standard URL, of the form:
|
|
// `[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]`
|
|
// We only need to parse up to `@` in order to redact the password. The rest is left as-is.
|
|
|
|
usernamePassword, afterUsernamePassword, usernamePasswordSeparatorFound := strings.Cut(afterSchemeSeparator, "@")
|
|
if !usernamePasswordSeparatorFound {
|
|
afterUsernamePassword = afterSchemeSeparator
|
|
}
|
|
|
|
username, password, hasPassword := strings.Cut(usernamePassword, ":")
|
|
// We only insert a redacted password in the final result if a password was provided in the input.
|
|
// This behavior matches the one of `url.Redacted()`.
|
|
if hasPassword {
|
|
password = ":xxxxx"
|
|
}
|
|
|
|
res := scheme + "://"
|
|
if usernamePasswordSeparatorFound {
|
|
res += username + password + "@"
|
|
}
|
|
res += afterUsernamePassword
|
|
return res, nil
|
|
}
|