feat: add sms verification for phone numbers (#3649)

This commit is contained in:
Jonas Hungershausen 2023-12-28 16:35:57 +01:00 committed by GitHub
parent ae8cbdc27f
commit e3a3c4fe0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
65 changed files with 1164 additions and 789 deletions

View File

@ -0,0 +1,31 @@
{
"$id": "https://schemas.ory.sh/presets/kratos/quickstart/email-password/identity.schema.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "Person",
"type": "object",
"properties": {
"traits": {
"type": "object",
"properties": {
"phone": {
"type": "string",
"format": "tel",
"title": "Phone number",
"minLength": 3,
"ory.sh/kratos": {
"credentials": {
"password": {
"identifier": true
}
},
"verification": {
"via": "phone"
}
}
}
},
"required": ["phone"],
"additionalProperties": false
}
}
}

View File

@ -0,0 +1,113 @@
version: v0.13.0
dsn: memory
serve:
public:
base_url: http://127.0.0.1:4433/
cors:
enabled: true
admin:
base_url: http://kratos:4434/
selfservice:
default_browser_return_url: http://127.0.0.1:4455/
allowed_return_urls:
- http://127.0.0.1:4455
- http://localhost:19006/Callback
- exp://localhost:8081/--/Callback
methods:
password:
enabled: true
totp:
config:
issuer: Kratos
enabled: true
lookup_secret:
enabled: true
link:
enabled: true
code:
enabled: true
flows:
error:
ui_url: http://127.0.0.1:4455/error
settings:
ui_url: http://127.0.0.1:4455/settings
privileged_session_max_age: 15m
required_aal: highest_available
recovery:
enabled: true
ui_url: http://127.0.0.1:4455/recovery
use: code
verification:
enabled: true
ui_url: http://127.0.0.1:4455/verification
use: code
after:
default_browser_return_url: http://127.0.0.1:4455/
logout:
after:
default_browser_return_url: http://127.0.0.1:4455/login
login:
ui_url: http://127.0.0.1:4455/login
lifespan: 10m
registration:
lifespan: 10m
ui_url: http://127.0.0.1:4455/registration
after:
password:
hooks:
- hook: session
- hook: show_verification_ui
log:
level: debug
format: text
leak_sensitive_values: true
secrets:
cookie:
- PLEASE-CHANGE-ME-I-AM-VERY-INSECURE
cipher:
- 32-LONG-SECRET-NOT-SECURE-AT-ALL
ciphers:
algorithm: xchacha20-poly1305
hashers:
algorithm: bcrypt
bcrypt:
cost: 8
identity:
default_schema_id: default
schemas:
- id: default
url: file:///etc/config/kratos/identity.schema.json
courier:
channels:
- id: phone
request_config:
url: https://api.twilio.com/2010-04-01/Accounts/AXXXXXXXXXXXXXX/Messages.json
method: POST
body: base64://ZnVuY3Rpb24oY3R4KSB7CkJvZHk6IGN0eC5ib2R5LApUbzogY3R4LnRvLEZyb206IGN0eC5mcm9tCn0=
headers:
Content-Type: application/x-www-form-urlencoded
auth:
type: basic_auth
config:
user: AXXXXXXX
password: XXXX
feature_flags:
use_continue_with_transitions: true

13
courier/channel.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package courier
import (
"context"
)
type Channel interface {
ID() string
Dispatch(ctx context.Context, msg Message) error
}

View File

@ -7,16 +7,15 @@ import (
"context"
"time"
"github.com/ory/kratos/courier/template"
"github.com/ory/x/jsonnetsecure"
"github.com/cenkalti/backoff"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
gomail "github.com/ory/mail/v3"
)
type (
@ -33,11 +32,8 @@ type (
Work(ctx context.Context) error
QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error)
QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error)
SmtpDialer() *gomail.Dialer
DispatchQueue(ctx context.Context) error
DispatchMessage(ctx context.Context, msg Message) error
SetGetEmailTemplateType(f func(t EmailTemplate) (TemplateType, error))
SetNewEmailTemplateFromMessage(f func(d template.Dependencies, msg Message) (EmailTemplate, error))
UseBackoff(b backoff.BackOff)
FailOnDispatchError()
}
@ -51,9 +47,7 @@ type (
}
courier struct {
smsClient *smsClient
smtpClient *smtpClient
httpClient *httpClient
courierChannels map[string]Channel
deps Dependencies
failOnDispatchError bool
backoff backoff.BackOff
@ -61,16 +55,34 @@ type (
)
func NewCourier(ctx context.Context, deps Dependencies) (Courier, error) {
smtp, err := newSMTP(ctx, deps)
return NewCourierWithCustomTemplates(ctx, deps, NewEmailTemplateFromMessage)
}
func NewCourierWithCustomTemplates(ctx context.Context, deps Dependencies, newEmailTemplateFromMessage func(d template.Dependencies, msg Message) (EmailTemplate, error)) (Courier, error) {
cs, err := deps.CourierConfig().CourierChannels(ctx)
if err != nil {
return nil, err
}
channels := make(map[string]Channel, len(cs))
for _, c := range cs {
switch c.Type {
case "smtp":
ch, err := NewSMTPChannelWithCustomTemplates(deps, c.SMTPConfig, newEmailTemplateFromMessage)
if err != nil {
return nil, err
}
channels[ch.ID()] = ch
case "http":
channels[c.ID] = newHttpChannel(c.ID, c.RequestConfig, deps)
default:
return nil, errors.Errorf("unknown courier channel type: %s", c.Type)
}
}
return &courier{
smsClient: newSMS(ctx, deps),
smtpClient: smtp,
httpClient: newHTTP(ctx, deps),
deps: deps,
backoff: backoff.NewExponentialBackOff(),
deps: deps,
backoff: backoff.NewExponentialBackOff(),
courierChannels: channels,
}, nil
}

View File

@ -19,17 +19,13 @@ func (c *courier) DispatchMessage(ctx context.Context, msg Message) error {
return err
}
switch msg.Type {
case MessageTypeEmail:
if err := c.dispatchEmail(ctx, msg); err != nil {
return err
}
case MessageTypePhone:
if err := c.dispatchSMS(ctx, msg); err != nil {
return err
}
default:
return errors.Errorf("received unexpected message type: %d", msg.Type)
channel, ok := c.courierChannels[msg.Channel.String()]
if !ok {
return errors.Errorf("message %s has unknown channel %q", msg.ID.String(), msg.Channel)
}
if err := channel.Dispatch(ctx, msg); err != nil {
return err
}
if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil {
@ -37,6 +33,7 @@ func (c *courier) DispatchMessage(ctx context.Context, msg Message) error {
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
WithField("channel", channel.ID()).
Error(`Unable to set the message status to "sent".`)
return err
}
@ -47,6 +44,7 @@ func (c *courier) DispatchMessage(ctx context.Context, msg Message) error {
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
WithField("channel", channel.ID()).
Debug("Courier sent out message.")
return nil

View File

@ -16,6 +16,7 @@ import (
templates "github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/internal/testhelpers"
)
func queueNewMessage(t *testing.T, ctx context.Context, c courier.Courier, d template.Dependencies) uuid.UUID {
@ -58,6 +59,33 @@ func TestDispatchMessageWithInvalidSMTP(t *testing.T) {
})
}
func TestDispatchMessage(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewRegistryDefaultWithDSN(t, "")
conf.MustSet(ctx, config.ViperKeyCourierMessageRetries, 5)
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, "http://foo.url")
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
c, err := reg.Courier(ctx)
require.NoError(t, err)
t.Run("case=invalid channel", func(t *testing.T) {
message := courier.Message{
Channel: "invalid-channel",
Status: courier.MessageStatusQueued,
Type: courier.MessageTypeEmail,
Recipient: testhelpers.RandomEmail(),
Subject: "test-subject-1",
Body: "test-body-1",
TemplateType: "stub",
}
require.NoError(t, reg.CourierPersister().AddMessage(ctx, &message))
require.Error(t, c.DispatchMessage(ctx, message))
})
}
func TestDispatchQueue(t *testing.T) {
ctx := context.Background()

View File

@ -1,93 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package courier_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/internal"
)
func TestGetTemplateType(t *testing.T) {
for expectedType, tmpl := range map[courier.TemplateType]courier.EmailTemplate{
courier.TypeRecoveryInvalid: &email.RecoveryInvalid{},
courier.TypeRecoveryValid: &email.RecoveryValid{},
courier.TypeRecoveryCodeInvalid: &email.RecoveryCodeInvalid{},
courier.TypeRecoveryCodeValid: &email.RecoveryCodeValid{},
courier.TypeVerificationInvalid: &email.VerificationInvalid{},
courier.TypeVerificationValid: &email.VerificationValid{},
courier.TypeVerificationCodeInvalid: &email.VerificationCodeInvalid{},
courier.TypeVerificationCodeValid: &email.VerificationCodeValid{},
courier.TypeTestStub: &email.TestStub{},
courier.TypeLoginCodeValid: &email.LoginCodeValid{},
courier.TypeRegistrationCodeValid: &email.RegistrationCodeValid{},
} {
t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) {
actualType, err := courier.GetEmailTemplateType(tmpl)
require.NoError(t, err)
require.Equal(t, expectedType, actualType)
})
}
}
func TestNewEmailTemplateFromMessage(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
ctx := context.Background()
for tmplType, expectedTmpl := range map[courier.TemplateType]courier.EmailTemplate{
courier.TypeRecoveryInvalid: email.NewRecoveryInvalid(reg, &email.RecoveryInvalidModel{To: "foo"}),
courier.TypeRecoveryValid: email.NewRecoveryValid(reg, &email.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}),
courier.TypeRecoveryCodeValid: email.NewRecoveryCodeValid(reg, &email.RecoveryCodeValidModel{To: "bar", RecoveryCode: "12345678"}),
courier.TypeRecoveryCodeInvalid: email.NewRecoveryCodeInvalid(reg, &email.RecoveryCodeInvalidModel{To: "bar"}),
courier.TypeVerificationInvalid: email.NewVerificationInvalid(reg, &email.VerificationInvalidModel{To: "baz"}),
courier.TypeVerificationValid: email.NewVerificationValid(reg, &email.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}),
courier.TypeVerificationCodeInvalid: email.NewVerificationCodeInvalid(reg, &email.VerificationCodeInvalidModel{To: "baz"}),
courier.TypeVerificationCodeValid: email.NewVerificationCodeValid(reg, &email.VerificationCodeValidModel{To: "faz", VerificationURL: "http://bar.foo", VerificationCode: "123456678"}),
courier.TypeTestStub: email.NewTestStub(reg, &email.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}),
courier.TypeLoginCodeValid: email.NewLoginCodeValid(reg, &email.LoginCodeValidModel{To: "far", LoginCode: "123456"}),
courier.TypeRegistrationCodeValid: email.NewRegistrationCodeValid(reg, &email.RegistrationCodeValidModel{To: "far", RegistrationCode: "123456"}),
} {
t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) {
tmplData, err := json.Marshal(expectedTmpl)
require.NoError(t, err)
m := courier.Message{TemplateType: tmplType, TemplateData: tmplData}
actualTmpl, err := courier.NewEmailTemplateFromMessage(reg, m)
require.NoError(t, err)
require.IsType(t, expectedTmpl, actualTmpl)
expectedRecipient, err := expectedTmpl.EmailRecipient()
require.NoError(t, err)
actualRecipient, err := actualTmpl.EmailRecipient()
require.NoError(t, err)
require.Equal(t, expectedRecipient, actualRecipient)
expectedSubject, err := expectedTmpl.EmailSubject(ctx)
require.NoError(t, err)
actualSubject, err := actualTmpl.EmailSubject(ctx)
require.NoError(t, err)
require.Equal(t, expectedSubject, actualSubject)
expectedBody, err := expectedTmpl.EmailBody(ctx)
require.NoError(t, err)
actualBody, err := actualTmpl.EmailBody(ctx)
require.NoError(t, err)
require.Equal(t, expectedBody, actualBody)
expectedBodyPlaintext, err := expectedTmpl.EmailBodyPlaintext(ctx)
require.NoError(t, err)
actualBodyPlaintext, err := actualTmpl.EmailBodyPlaintext(ctx)
require.NoError(t, err)
require.Equal(t, expectedBodyPlaintext, actualBodyPlaintext)
})
}
}

View File

@ -1,88 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package courier
import (
"context"
"encoding/json"
"fmt"
"github.com/ory/kratos/request"
"github.com/ory/x/otelx"
)
type httpDataModel struct {
Recipient string
Subject string
Body string
TemplateType TemplateType
TemplateData EmailTemplate
}
type httpClient struct {
RequestConfig json.RawMessage
}
func newHTTP(ctx context.Context, deps Dependencies) *httpClient {
return &httpClient{
RequestConfig: deps.CourierConfig().CourierEmailRequestConfig(ctx),
}
}
func (c *courier) dispatchMailerEmail(ctx context.Context, msg Message) (err error) {
ctx, span := c.deps.Tracer(ctx).Tracer().Start(ctx, "courier.http.dispatchMailerEmail")
defer otelx.End(span, &err)
builder, err := request.NewBuilder(ctx, c.httpClient.RequestConfig, c.deps)
if err != nil {
return err
}
tmpl, err := c.smtpClient.NewTemplateFromMessage(c.deps, msg)
if err != nil {
return err
}
td := httpDataModel{
Recipient: msg.Recipient,
Subject: msg.Subject,
Body: msg.Body,
TemplateType: msg.TemplateType,
TemplateData: tmpl,
}
req, err := builder.BuildRequest(ctx, td)
if err != nil {
return err
}
res, err := c.deps.HTTPClient(ctx).Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode >= 200 && res.StatusCode < 300 {
c.deps.Logger().
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out mailer.")
return nil
}
err = fmt.Errorf(
"unable to dispatch mail delivery because upstream server replied with status code %d",
res.StatusCode,
)
c.deps.Logger().
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
WithError(err).
Error("sending mail via HTTP failed.")
return err
}

124
courier/http_channel.go Normal file
View File

@ -0,0 +1,124 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package courier
import (
"context"
"encoding/json"
"fmt"
"github.com/pkg/errors"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/request"
"github.com/ory/kratos/x"
"github.com/ory/x/jsonnetsecure"
"github.com/ory/x/otelx"
)
type (
httpChannel struct {
id string
requestConfig json.RawMessage
d channelDependencies
}
channelDependencies interface {
x.TracingProvider
x.LoggingProvider
x.HTTPClientProvider
jsonnetsecure.VMProvider
ConfigProvider
}
)
var _ Channel = new(httpChannel)
func newHttpChannel(id string, requestConfig json.RawMessage, d channelDependencies) *httpChannel {
return &httpChannel{
id: id,
requestConfig: requestConfig,
d: d,
}
}
func (c *httpChannel) ID() string {
return c.id
}
type httpDataModel struct {
Recipient string
Subject string
Body string
TemplateType template.TemplateType
TemplateData Template
MessageType string
}
func (c *httpChannel) Dispatch(ctx context.Context, msg Message) (err error) {
ctx, span := c.d.Tracer(ctx).Tracer().Start(ctx, "courier.httpChannel.Dispatch")
defer otelx.End(span, &err)
builder, err := request.NewBuilder(ctx, c.requestConfig, c.d)
if err != nil {
return err
}
tmpl, err := newTemplate(c.d, msg)
if err != nil {
return err
}
td := httpDataModel{
Recipient: msg.Recipient,
Subject: msg.Subject,
Body: msg.Body,
TemplateType: msg.TemplateType,
TemplateData: tmpl,
MessageType: msg.Type.String(),
}
req, err := builder.BuildRequest(ctx, td)
if err != nil {
return err
}
res, err := c.d.HTTPClient(ctx).Do(req)
if err != nil {
return err
}
if res.StatusCode >= 200 && res.StatusCode < 300 {
c.d.Logger().
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out mailer.")
return nil
}
err = errors.Errorf(
"unable to dispatch mail delivery because upstream server replied with status code %d",
res.StatusCode,
)
c.d.Logger().
WithField("message_id", msg.ID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
WithError(err).
Error("sending mail via HTTP failed.")
return err
}
func newTemplate(d template.Dependencies, msg Message) (Template, error) {
switch msg.Type {
case MessageTypeEmail:
return NewEmailTemplateFromMessage(d, msg)
case MessageTypeSMS:
return NewSMSTemplateFromMessage(d, msg)
default:
return nil, fmt.Errorf("received unexpected message type: %s", msg.Type)
}
}

View File

@ -12,7 +12,9 @@ import (
"github.com/pkg/errors"
"github.com/ory/herodot"
"github.com/ory/kratos/courier/template"
"github.com/ory/x/pagination/keysetpagination"
"github.com/ory/x/sqlxx"
"github.com/ory/x/stringsx"
)
@ -88,7 +90,6 @@ func (ms *MessageStatus) UnmarshalJSON(data []byte) error {
}
s, err := ToMessageStatus(str)
if err != nil {
return err
}
@ -106,12 +107,12 @@ type MessageType int
const (
MessageTypeEmail MessageType = iota + 1
MessageTypePhone
MessageTypeSMS
)
const (
messageTypeEmailText = "email"
messageTypePhoneText = "phone"
messageTypeSMSText = "sms"
)
// The format we need to use in the Page tokens, as it's the only format that is understood by all DBs
@ -121,8 +122,8 @@ func ToMessageType(str string) (MessageType, error) {
switch s := stringsx.SwitchExact(str); {
case s.AddCase(messageTypeEmailText):
return MessageTypeEmail, nil
case s.AddCase(messageTypePhoneText):
return MessageTypePhone, nil
case s.AddCase(messageTypeSMSText):
return MessageTypeSMS, nil
default:
return 0, errors.WithStack(herodot.ErrBadRequest.WithWrap(s.ToUnknownCaseErr()).WithReason("Message type is not valid"))
}
@ -132,8 +133,8 @@ func (mt MessageType) String() string {
switch mt {
case MessageTypeEmail:
return messageTypeEmailText
case MessageTypePhone:
return messageTypePhoneText
case MessageTypeSMS:
return messageTypeSMSText
default:
return ""
}
@ -141,7 +142,7 @@ func (mt MessageType) String() string {
func (mt MessageType) IsValid() error {
switch mt {
case MessageTypeEmail, MessageTypePhone:
case MessageTypeEmail, MessageTypeSMS:
return nil
default:
return errors.WithStack(herodot.ErrBadRequest.WithReason("Message type is not valid"))
@ -187,7 +188,9 @@ type Message struct {
// required: true
Subject string `json:"subject" db:"subject"`
// required: true
TemplateType TemplateType `json:"template_type" db:"template_type"`
TemplateType template.TemplateType `json:"template_type" db:"template_type"`
Channel sqlxx.NullString `json:"channel" db:"channel"`
TemplateData []byte `json:"-" db:"template_data"`
// required: true

View File

@ -46,7 +46,7 @@ func TestToMessageType(t *testing.T) {
t.Run("case=should return corresponding MessageType for given str", func(t *testing.T) {
for str, exp := range map[string]courier.MessageType{
"email": courier.MessageTypeEmail,
"phone": courier.MessageTypePhone,
"sms": courier.MessageTypeSMS,
} {
result, err := courier.ToMessageType(str)
require.NoError(t, err)

View File

@ -6,60 +6,34 @@ package courier
import (
"context"
"encoding/json"
"net/http"
"github.com/pkg/errors"
"github.com/ory/herodot"
"github.com/gofrs/uuid"
"github.com/ory/kratos/request"
)
type sendSMSRequestBody struct {
From string `json:"from"`
To string `json:"to"`
Body string `json:"body"`
}
type smsClient struct {
RequestConfig json.RawMessage
GetTemplateType func(t SMSTemplate) (TemplateType, error)
NewTemplateFromMessage func(d Dependencies, msg Message) (SMSTemplate, error)
}
func newSMS(ctx context.Context, deps Dependencies) *smsClient {
return &smsClient{
RequestConfig: deps.CourierConfig().CourierSMSRequestConfig(ctx),
GetTemplateType: SMSTemplateType,
NewTemplateFromMessage: NewSMSTemplateFromMessage,
}
}
func (c *courier) QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) {
recipient, err := t.PhoneNumber()
if err != nil {
return uuid.Nil, err
}
templateType, err := c.smsClient.GetTemplateType(t)
if err != nil {
return uuid.Nil, err
}
templateData, err := json.Marshal(t)
if err != nil {
return uuid.Nil, err
}
body, err := t.SMSBody(ctx)
if err != nil {
return uuid.Nil, err
}
message := &Message{
Status: MessageStatusQueued,
Type: MessageTypePhone,
Type: MessageTypeSMS,
Channel: "sms",
Recipient: recipient,
TemplateType: templateType,
TemplateType: t.TemplateType(),
TemplateData: templateData,
Body: body,
}
if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil {
return uuid.Nil, err
@ -67,49 +41,3 @@ func (c *courier) QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error
return message.ID, nil
}
func (c *courier) dispatchSMS(ctx context.Context, msg Message) error {
if !c.deps.CourierConfig().CourierSMSEnabled(ctx) {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an sms but courier.sms.enabled is set to false!"))
}
tmpl, err := c.smsClient.NewTemplateFromMessage(c.deps, msg)
if err != nil {
return err
}
body, err := tmpl.SMSBody(ctx)
if err != nil {
return err
}
builder, err := request.NewBuilder(ctx, c.smsClient.RequestConfig, c.deps)
if err != nil {
return err
}
req, err := builder.BuildRequest(ctx, &sendSMSRequestBody{
To: msg.Recipient,
From: c.deps.CourierConfig().CourierSMSFrom(ctx),
Body: body,
})
if err != nil {
return err
}
res, err := c.deps.HTTPClient(ctx).Do(req)
if err != nil {
return err
}
defer res.Body.Close()
switch res.StatusCode {
case http.StatusOK:
case http.StatusCreated:
default:
return errors.New(http.StatusText(res.StatusCode))
}
return nil
}

View File

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/sms"
)
@ -16,33 +17,24 @@ type SMSTemplate interface {
json.Marshaler
SMSBody(context.Context) (string, error)
PhoneNumber() (string, error)
TemplateType() template.TemplateType
}
func SMSTemplateType(t SMSTemplate) (TemplateType, error) {
switch t.(type) {
case *sms.OTPMessage:
return TypeOTP, nil
case *sms.TestStub:
return TypeTestStub, nil
default:
return "", errors.Errorf("unexpected template type")
}
}
func NewSMSTemplateFromMessage(d Dependencies, m Message) (SMSTemplate, error) {
func NewSMSTemplateFromMessage(d template.Dependencies, m Message) (SMSTemplate, error) {
switch m.TemplateType {
case TypeOTP:
var t sms.OTPMessageModel
case template.TypeVerificationCodeValid:
var t sms.VerificationCodeValidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
return nil, err
}
return sms.NewOTPMessage(d, &t), nil
case TypeTestStub:
return sms.NewVerificationCodeValid(d, &t), nil
case template.TypeTestStub:
var t sms.TestStubModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
return nil, err
}
return sms.NewTestStub(d, &t), nil
default:
return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType)
}

View File

@ -12,19 +12,18 @@ import (
"github.com/stretchr/testify/require"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/sms"
"github.com/ory/kratos/internal"
)
func TestSMSTemplateType(t *testing.T) {
for expectedType, tmpl := range map[courier.TemplateType]courier.SMSTemplate{
courier.TypeOTP: &sms.OTPMessage{},
courier.TypeTestStub: &sms.TestStub{},
for expectedType, tmpl := range map[template.TemplateType]courier.SMSTemplate{
template.TypeVerificationCodeValid: &sms.VerificationCodeValid{},
template.TypeTestStub: &sms.TestStub{},
} {
t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) {
actualType, err := courier.SMSTemplateType(tmpl)
require.NoError(t, err)
require.Equal(t, expectedType, actualType)
require.Equal(t, expectedType, tmpl.TemplateType())
})
}
}
@ -33,9 +32,9 @@ func TestNewSMSTemplateFromMessage(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
ctx := context.Background()
for tmplType, expectedTmpl := range map[courier.TemplateType]courier.SMSTemplate{
courier.TypeOTP: sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: "+12345678901"}),
courier.TypeTestStub: sms.NewTestStub(reg, &sms.TestStubModel{To: "+12345678901", Body: "test body"}),
for tmplType, expectedTmpl := range map[template.TemplateType]courier.SMSTemplate{
template.TypeVerificationCodeValid: sms.NewVerificationCodeValid(reg, &sms.VerificationCodeValidModel{To: "+12345678901"}),
template.TypeTestStub: sms.NewTestStub(reg, &sms.TestStubModel{To: "+12345678901", Body: "test body"}),
} {
t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) {
tmplData, err := json.Marshal(expectedTmpl)

View File

@ -14,7 +14,6 @@ import (
"time"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -22,7 +21,6 @@ import (
"github.com/ory/kratos/courier/template/sms"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/internal"
"github.com/ory/x/resilience"
)
func TestQueueSMS(t *testing.T) {
@ -80,9 +78,13 @@ func TestQueueSMS(t *testing.T) {
}`, srv.URL)
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(ctx, config.ViperKeyCourierSMSRequestConfig, requestConfig)
conf.MustSet(ctx, config.ViperKeyCourierSMSFrom, expectedSender)
conf.MustSet(ctx, config.ViperKeyCourierSMSEnabled, true)
conf.MustSet(ctx, config.ViperKeyCourierChannels, fmt.Sprintf(`[
{
"id": "sms",
"type": "http",
"request_config": %s
}
]`, requestConfig))
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, "http://foo.url")
reg.Logger().Level = logrus.TraceLevel
@ -98,16 +100,11 @@ func TestQueueSMS(t *testing.T) {
require.NotEqual(t, uuid.Nil, id)
}
go func() {
require.NoError(t, c.Work(ctx))
}()
require.NoError(t, c.DispatchQueue(ctx))
require.NoError(t, resilience.Retry(reg.Logger(), time.Millisecond*250, time.Second*10, func() error {
if len(actual) == len(expectedSMS) {
return nil
}
return errors.New("capacity not reached")
}))
require.Eventually(t, func() bool {
return len(actual) == len(expectedSMS)
}, 10*time.Second, 250*time.Millisecond)
for i, message := range actual {
expected := expectedSMS[i]
@ -123,15 +120,19 @@ func TestDisallowedInternalNetwork(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(ctx, config.ViperKeyCourierSMSRequestConfig, `{
"url": "http://127.0.0.1/",
"method": "GET",
"body": "file://./stub/request.config.twilio.jsonnet"
}`)
conf.MustSet(ctx, config.ViperKeyCourierSMSEnabled, true)
conf.MustSet(ctx, config.ViperKeyCourierChannels, `[
{
"id": "sms",
"type": "http",
"request_config": {
"url": "http://127.0.0.1/",
"method": "GET",
"body": "file://./stub/request.config.twilio.jsonnet"
}
}
]`)
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, "http://foo.url")
conf.MustSet(ctx, config.ViperKeyClientHTTPNoPrivateIPRanges, true)
reg.Logger().Level = logrus.TraceLevel
c, err := reg.Courier(ctx)
require.NoError(t, err)

View File

@ -7,42 +7,33 @@ import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/mail"
"net/textproto"
"net/url"
"strconv"
"time"
"github.com/ory/kratos/courier/template"
"github.com/ory/herodot"
"github.com/ory/kratos/driver/config"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/ory/herodot"
gomail "github.com/ory/mail/v3"
)
type smtpClient struct {
type SMTPClient struct {
*gomail.Dialer
GetTemplateType func(t EmailTemplate) (TemplateType, error)
NewTemplateFromMessage func(d template.Dependencies, msg Message) (EmailTemplate, error)
}
func newSMTP(ctx context.Context, deps Dependencies) (*smtpClient, error) {
uri, err := deps.CourierConfig().CourierSMTPURL(ctx)
func NewSMTPClient(deps Dependencies, cfg *config.SMTPConfig) (*SMTPClient, error) {
uri, err := url.Parse(cfg.ConnectionURI)
if err != nil {
return nil, err
return nil, herodot.ErrInternalServerError.WithError(err.Error())
}
var tlsCertificates []tls.Certificate
clientCertPath := deps.CourierConfig().CourierSMTPClientCertPath(ctx)
clientKeyPath := deps.CourierConfig().CourierSMTPClientKeyPath(ctx)
if clientCertPath != "" && clientKeyPath != "" {
clientCert, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
if cfg.ClientCertPath != "" && cfg.ClientKeyPath != "" {
clientCert, err := tls.LoadX509KeyPair(cfg.ClientCertPath, cfg.ClientKeyPath)
if err == nil {
tlsCertificates = append(tlsCertificates, clientCert)
} else {
@ -52,7 +43,6 @@ func newSMTP(ctx context.Context, deps Dependencies) (*smtpClient, error) {
}
}
localName := deps.CourierConfig().CourierSMTPLocalName(ctx)
password, _ := uri.User.Password()
port, _ := strconv.ParseInt(uri.Port(), 10, 0)
@ -61,7 +51,7 @@ func newSMTP(ctx context.Context, deps Dependencies) (*smtpClient, error) {
Port: int(port),
Username: uri.User.Username(),
Password: password,
LocalName: localName,
LocalName: cfg.LocalName,
Timeout: time.Second * 10,
RetryFailure: true,
@ -94,26 +84,11 @@ func newSMTP(ctx context.Context, deps Dependencies) (*smtpClient, error) {
dialer.SSL = true
}
return &smtpClient{
return &SMTPClient{
Dialer: dialer,
GetTemplateType: GetEmailTemplateType,
NewTemplateFromMessage: NewEmailTemplateFromMessage,
}, nil
}
func (c *courier) SetGetEmailTemplateType(f func(t EmailTemplate) (TemplateType, error)) {
c.smtpClient.GetTemplateType = f
}
func (c *courier) SetNewEmailTemplateFromMessage(f func(d template.Dependencies, msg Message) (EmailTemplate, error)) {
c.smtpClient.NewTemplateFromMessage = f
}
func (c *courier) SmtpDialer() *gomail.Dialer {
return c.smtpClient.Dialer
}
func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) {
recipient, err := t.EmailRecipient()
if err != nil {
@ -133,11 +108,6 @@ func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e
return uuid.Nil, err
}
templateType, err := c.smtpClient.GetTemplateType(t)
if err != nil {
return uuid.Nil, err
}
templateData, err := json.Marshal(t)
if err != nil {
return uuid.Nil, err
@ -146,10 +116,11 @@ func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e
message := &Message{
Status: MessageStatusQueued,
Type: MessageTypeEmail,
Channel: "email",
Recipient: recipient,
Body: bodyPlaintext,
Subject: subject,
TemplateType: templateType,
TemplateType: t.TemplateType(),
TemplateData: templateData,
}
@ -159,95 +130,3 @@ func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e
return message.ID, nil
}
func (c *courier) dispatchEmail(ctx context.Context, msg Message) error {
if c.deps.CourierConfig().CourierEmailStrategy(ctx) == "http" {
return c.dispatchMailerEmail(ctx, msg)
}
if c.smtpClient.Host == "" {
return errors.WithStack(herodot.ErrInternalServerError.WithErrorf("Courier tried to deliver an email but %s is not set!", config.ViperKeyCourierSMTPURL))
}
from := c.deps.CourierConfig().CourierSMTPFrom(ctx)
fromName := c.deps.CourierConfig().CourierSMTPFromName(ctx)
gm := gomail.NewMessage()
if fromName == "" {
gm.SetHeader("From", from)
} else {
gm.SetAddressHeader("From", from, fromName)
}
gm.SetHeader("To", msg.Recipient)
gm.SetHeader("Subject", msg.Subject)
headers := c.deps.CourierConfig().CourierSMTPHeaders(ctx)
for k, v := range headers {
gm.SetHeader(k, v)
}
gm.SetBody("text/plain", msg.Body)
tmpl, err := c.smtpClient.NewTemplateFromMessage(c.deps, msg)
if err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to get email template from message.`)
} else {
htmlBody, err := tmpl.EmailBody(ctx)
if err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to get email body from template.`)
} else {
gm.AddAlternative("text/html", htmlBody)
}
}
if err := c.smtpClient.DialAndSend(ctx, gm); err != nil {
c.deps.Logger().
WithError(err).
WithField("smtp_server", fmt.Sprintf("%s:%d", c.smtpClient.Host, c.smtpClient.Port)).
WithField("smtp_ssl_enabled", c.smtpClient.SSL).
WithField("message_from", from).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error("Unable to send email using SMTP connection.")
var protoErr *textproto.Error
var mailErr *gomail.SendError
switch {
case errors.As(err, &mailErr) && errors.As(mailErr.Cause, &protoErr) && protoErr.Code >= 500:
fallthrough
case errors.As(err, &protoErr) && protoErr.Code >= 500:
// See https://en.wikipedia.org/wiki/List_of_SMTP_server_return_codes
// If the SMTP server responds with 5xx, sending the message should not be retried (without changing something about the request)
if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusAbandoned); err != nil {
c.deps.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to reset the retried message's status to "abandoned".`)
return err
}
}
return errors.WithStack(herodot.ErrInternalServerError.
WithError(err.Error()).WithReason("failed to send email via smtp"))
}
c.deps.Logger().
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out message.")
return nil
}

144
courier/smtp_channel.go Normal file
View File

@ -0,0 +1,144 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package courier
import (
"context"
"fmt"
"net/textproto"
"github.com/pkg/errors"
"github.com/ory/herodot"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver/config"
"github.com/ory/mail/v3"
)
type (
SMTPChannel struct {
smtpClient *SMTPClient
d Dependencies
newEmailTemplateFromMessage func(d template.Dependencies, msg Message) (EmailTemplate, error)
}
)
var _ Channel = new(SMTPChannel)
func NewSMTPChannel(deps Dependencies, cfg *config.SMTPConfig) (*SMTPChannel, error) {
return NewSMTPChannelWithCustomTemplates(deps, cfg, NewEmailTemplateFromMessage)
}
func NewSMTPChannelWithCustomTemplates(deps Dependencies, cfg *config.SMTPConfig, newEmailTemplateFromMessage func(d template.Dependencies, msg Message) (EmailTemplate, error)) (*SMTPChannel, error) {
smtpClient, err := NewSMTPClient(deps, cfg)
if err != nil {
return nil, err
}
return &SMTPChannel{
smtpClient: smtpClient,
d: deps,
newEmailTemplateFromMessage: newEmailTemplateFromMessage,
}, nil
}
func (c *SMTPChannel) ID() string {
return "email"
}
func (c *SMTPChannel) Dispatch(ctx context.Context, msg Message) error {
if c.smtpClient.Host == "" {
return errors.WithStack(herodot.ErrInternalServerError.WithErrorf("Courier tried to deliver an email but %s is not set!", config.ViperKeyCourierSMTPURL))
}
channels, err := c.d.CourierConfig().CourierChannels(ctx)
if err != nil {
return err
}
var cfg *config.SMTPConfig
for _, channel := range channels {
if channel.ID == "email" && channel.SMTPConfig != nil {
cfg = channel.SMTPConfig
break
}
}
gm := mail.NewMessage()
if cfg.FromName == "" {
gm.SetHeader("From", cfg.FromAddress)
} else {
gm.SetAddressHeader("From", cfg.FromAddress, cfg.FromName)
}
gm.SetHeader("To", msg.Recipient)
gm.SetHeader("Subject", msg.Subject)
headers := cfg.Headers
for k, v := range headers {
gm.SetHeader(k, v)
}
gm.SetBody("text/plain", msg.Body)
tmpl, err := c.newEmailTemplateFromMessage(c.d, msg)
if err != nil {
c.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to get email template from message.`)
} else if htmlBody, err := tmpl.EmailBody(ctx); err != nil {
c.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to get email body from template.`)
} else {
gm.AddAlternative("text/html", htmlBody)
}
if err := c.smtpClient.DialAndSend(ctx, gm); err != nil {
c.d.Logger().
WithError(err).
WithField("smtp_server", fmt.Sprintf("%s:%d", c.smtpClient.Host, c.smtpClient.Port)).
WithField("smtp_ssl_enabled", c.smtpClient.SSL).
WithField("message_from", cfg.FromAddress).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error("Unable to send email using SMTP connection.")
var protoErr *textproto.Error
var mailErr *mail.SendError
switch {
case errors.As(err, &mailErr) && errors.As(mailErr.Cause, &protoErr) && protoErr.Code >= 500:
fallthrough
case errors.As(err, &protoErr) && protoErr.Code >= 500:
// See https://en.wikipedia.org/wiki/List_of_SMTP_server_return_codes
// If the SMTP server responds with 5xx, sending the message should not be retried (without changing something about the request)
if err := c.d.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusAbandoned); err != nil {
c.d.Logger().
WithError(err).
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
Error(`Unable to reset the retried message's status to "abandoned".`)
return err
}
}
return errors.WithStack(herodot.ErrInternalServerError.
WithError(err.Error()).WithReason("failed to send email via smtp"))
}
c.d.Logger().
WithField("message_id", msg.ID).
WithField("message_nid", msg.NID).
WithField("message_type", msg.Type).
WithField("message_template_type", msg.TemplateType).
WithField("message_subject", msg.Subject).
Debug("Courier sent out message.")
return nil
}

View File

@ -39,13 +39,13 @@ func TestNewSMTP(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
setupCourier := func(stringURL string) courier.Courier {
setupSMTPClient := func(stringURL string) *courier.SMTPClient {
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, stringURL)
u, err := conf.CourierSMTPURL(ctx)
require.NoError(t, err)
t.Logf("SMTP URL: %s", u.String())
c, err := courier.NewCourier(ctx, reg)
channels, err := conf.CourierChannels(ctx)
require.NoError(t, err)
require.Len(t, channels, 1)
c, err := courier.NewSMTPClient(reg, channels[0].SMTPConfig)
require.NoError(t, err)
return c
}
@ -55,18 +55,18 @@ func TestNewSMTP(t *testing.T) {
}
// Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false
smtp := setupCourier("smtp://foo:bar@my-server:1234/")
assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced")
assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled")
smtp := setupSMTPClient("smtp://foo:bar@my-server:1234/")
assert.Equal(t, smtp.StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced")
assert.Equal(t, smtp.SSL, false, "Implicit TLS should not be enabled")
// Should enforce TLS => dialer.SSL = true
smtp = setupCourier("smtps://foo:bar@my-server:1234/")
assert.Equal(t, smtp.SmtpDialer().SSL, true, "Implicit TLS should be enabled")
smtp = setupSMTPClient("smtps://foo:bar@my-server:1234/")
assert.Equal(t, smtp.SSL, true, "Implicit TLS should be enabled")
// Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false
smtp = setupCourier("smtp://foo:bar@my-server:1234/?disable_starttls=true")
assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced")
assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled")
smtp = setupSMTPClient("smtp://foo:bar@my-server:1234/?disable_starttls=true")
assert.Equal(t, smtp.StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced")
assert.Equal(t, smtp.SSL, false, "Implicit TLS should not be enabled")
// Test cert based SMTP client auth
clientCert, clientKey, err := generateTestClientCert()
@ -80,17 +80,17 @@ func TestNewSMTP(t *testing.T) {
clientPEM, err := tls.LoadX509KeyPair(clientCert.Name(), clientKey.Name())
require.NoError(t, err)
smtpWithCert := setupCourier("smtps://subdomain.my-server:1234/?server_name=my-server")
assert.Equal(t, smtpWithCert.SmtpDialer().SSL, true, "Implicit TLS should be enabled")
assert.Equal(t, smtpWithCert.SmtpDialer().Host, "subdomain.my-server", "SMTP Dialer host should match")
assert.Equal(t, smtpWithCert.SmtpDialer().TLSConfig.ServerName, "my-server", "TLS config server name should match")
assert.Equal(t, smtpWithCert.SmtpDialer().TLSConfig.ServerName, "my-server", "TLS config server name should match")
assert.Contains(t, smtpWithCert.SmtpDialer().TLSConfig.Certificates, clientPEM, "TLS config should contain client pem")
smtpWithCert := setupSMTPClient("smtps://subdomain.my-server:1234/?server_name=my-server")
assert.Equal(t, smtpWithCert.SSL, true, "Implicit TLS should be enabled")
assert.Equal(t, smtpWithCert.Host, "subdomain.my-server", "SMTP Dialer host should match")
assert.Equal(t, smtpWithCert.TLSConfig.ServerName, "my-server", "TLS config server name should match")
assert.Equal(t, smtpWithCert.TLSConfig.ServerName, "my-server", "TLS config server name should match")
assert.Contains(t, smtpWithCert.TLSConfig.Certificates, clientPEM, "TLS config should contain client pem")
// error case: invalid client key
conf.Set(ctx, config.ViperKeyCourierSMTPClientKeyPath, clientCert.Name()) // mixup client key and client cert
smtpWithCert = setupCourier("smtps://subdomain.my-server:1234/?server_name=my-server")
assert.Equal(t, len(smtpWithCert.SmtpDialer().TLSConfig.Certificates), 0, "TLS config certificates should be empty")
smtpWithCert = setupSMTPClient("smtps://subdomain.my-server:1234/?server_name=my-server")
assert.Equal(t, len(smtpWithCert.TLSConfig.Certificates), 0, "TLS config certificates should be empty")
}
func TestQueueEmail(t *testing.T) {

View File

@ -1,5 +1,5 @@
function(ctx) {
from: ctx.from,
to: ctx.to,
body: ctx.body
from: "Kratos Test",
to: ctx.Recipient,
body: ctx.Body
}

View File

@ -0,0 +1 @@
Your verification code is: {{ .VerificationCode }}

View File

@ -49,3 +49,7 @@ func (t *LoginCodeValid) EmailBodyPlaintext(ctx context.Context) (string, error)
func (t *LoginCodeValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.model)
}
func (t *LoginCodeValid) TemplateType() template.TemplateType {
return template.TypeLoginCodeValid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestLoginCodeValid(t *testing.T) {
})
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/login_code/valid", courier.TypeLoginCodeValid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/login_code/valid", template.TypeLoginCodeValid)
})
}

View File

@ -50,3 +50,7 @@ func (t *RecoveryCodeInvalid) EmailBodyPlaintext(ctx context.Context) (string, e
func (t *RecoveryCodeInvalid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.model)
}
func (t *RecoveryCodeInvalid) TemplateType() template.TemplateType {
return template.TypeRecoveryCodeInvalid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestRecoveryCodeInvalid(t *testing.T) {
})
t.Run("case=test remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery_code/invalid", courier.TypeRecoveryCodeInvalid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery_code/invalid", template.TypeRecoveryCodeInvalid)
})
}

View File

@ -49,3 +49,7 @@ func (t *RecoveryCodeValid) EmailBodyPlaintext(ctx context.Context) (string, err
func (t *RecoveryCodeValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.model)
}
func (t *RecoveryCodeValid) TemplateType() template.TemplateType {
return template.TypeRecoveryCodeValid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestRecoveryCodeValid(t *testing.T) {
})
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery_code/valid", courier.TypeRecoveryCodeValid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery_code/valid", template.TypeRecoveryCodeValid)
})
}

View File

@ -47,3 +47,7 @@ func (t *RecoveryInvalid) EmailBodyPlaintext(ctx context.Context) (string, error
func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *RecoveryInvalid) TemplateType() template.TemplateType {
return template.TypeRecoveryInvalid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestRecoverInvalid(t *testing.T) {
})
t.Run("case=test remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/invalid", courier.TypeRecoveryInvalid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/invalid", template.TypeRecoveryInvalid)
})
}

View File

@ -49,3 +49,7 @@ func (t *RecoveryValid) EmailBodyPlaintext(ctx context.Context) (string, error)
func (t *RecoveryValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *RecoveryValid) TemplateType() template.TemplateType {
return template.TypeRecoveryValid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestRecoverValid(t *testing.T) {
})
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/valid", courier.TypeRecoveryValid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/valid", template.TypeRecoveryValid)
})
}

View File

@ -49,3 +49,7 @@ func (t *RegistrationCodeValid) EmailBodyPlaintext(ctx context.Context) (string,
func (t *RegistrationCodeValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.model)
}
func (t *RegistrationCodeValid) TemplateType() template.TemplateType {
return template.TypeRegistrationCodeValid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestRegistrationCodeValid(t *testing.T) {
})
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/registration_code/valid", courier.TypeRegistrationCodeValid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/registration_code/valid", template.TypeRegistrationCodeValid)
})
}

View File

@ -49,3 +49,7 @@ func (t *TestStub) EmailBodyPlaintext(ctx context.Context) (string, error) {
func (t *TestStub) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *TestStub) TemplateType() template.TemplateType {
return template.TypeTestStub
}

View File

@ -71,3 +71,7 @@ func (t *VerificationCodeInvalid) EmailBodyPlaintext(ctx context.Context) (strin
func (t *VerificationCodeInvalid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *VerificationCodeInvalid) TemplateType() template.TemplateType {
return template.TypeVerificationCodeInvalid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestVerifyCodeInvalid(t *testing.T) {
})
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification_code/invalid", courier.TypeVerificationCodeInvalid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification_code/invalid", template.TypeVerificationCodeInvalid)
})
}

View File

@ -72,3 +72,7 @@ func (t *VerificationCodeValid) EmailBodyPlaintext(ctx context.Context) (string,
func (t *VerificationCodeValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *VerificationCodeValid) TemplateType() template.TemplateType {
return template.TypeVerificationCodeValid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestVerifyCodeValid(t *testing.T) {
})
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification_code/valid", courier.TypeVerificationCodeValid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification_code/valid", template.TypeVerificationCodeValid)
})
}

View File

@ -47,3 +47,7 @@ func (t *VerificationInvalid) EmailBodyPlaintext(ctx context.Context) (string, e
func (t *VerificationInvalid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *VerificationInvalid) TemplateType() template.TemplateType {
return template.TypeVerificationInvalid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -26,7 +26,7 @@ func TestVerifyInvalid(t *testing.T) {
t.Run("test=with remote resources", func(t *testing.T) {
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/invalid", courier.TypeVerificationInvalid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/invalid", template.TypeVerificationInvalid)
})
})
}

View File

@ -49,3 +49,7 @@ func (t *VerificationValid) EmailBodyPlaintext(ctx context.Context) (string, err
func (t *VerificationValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *VerificationValid) TemplateType() template.TemplateType {
return template.TypeVerificationValid
}

View File

@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/testhelpers"
"github.com/ory/kratos/internal"
@ -25,6 +25,6 @@ func TestVerifyValid(t *testing.T) {
})
t.Run("test=with remote resources", func(t *testing.T) {
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/valid", courier.TypeVerificationValid)
testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/valid", template.TypeVerificationValid)
})
}

View File

@ -1,41 +0,0 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package sms
import (
"context"
"encoding/json"
"os"
"github.com/ory/kratos/courier/template"
)
type (
OTPMessage struct {
d template.Dependencies
m *OTPMessageModel
}
OTPMessageModel struct {
To string
Code string
Identity map[string]interface{}
}
)
func NewOTPMessage(d template.Dependencies, m *OTPMessageModel) *OTPMessage {
return &OTPMessage{d: d, m: m}
}
func (t *OTPMessage) PhoneNumber() (string, error) {
return t.m.To, nil
}
func (t *OTPMessage) SMSBody(ctx context.Context) (string, error) {
return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig().CourierTemplatesRoot(ctx)), "otp/sms.body.gotmpl", "otp/sms.body*", t.m, "")
}
func (t *OTPMessage) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}

View File

@ -39,3 +39,7 @@ func (t *TestStub) SMSBody(ctx context.Context) (string, error) {
func (t *TestStub) MarshalJSON() ([]byte, error) {
return json.Marshal(t.m)
}
func (t *TestStub) TemplateType() template.TemplateType {
return template.TypeTestStub
}

View File

@ -0,0 +1,53 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package sms
import (
"context"
"encoding/json"
"os"
"github.com/ory/kratos/courier/template"
)
type (
VerificationCodeValid struct {
deps template.Dependencies
model *VerificationCodeValidModel
}
VerificationCodeValidModel struct {
To string
VerificationCode string
Identity map[string]interface{}
}
)
func NewVerificationCodeValid(d template.Dependencies, m *VerificationCodeValidModel) *VerificationCodeValid {
return &VerificationCodeValid{deps: d, model: m}
}
func (t *VerificationCodeValid) PhoneNumber() (string, error) {
return t.model.To, nil
}
func (t *VerificationCodeValid) SMSBody(ctx context.Context) (string, error) {
return template.LoadText(
ctx,
t.deps,
os.DirFS(t.deps.CourierConfig().CourierTemplatesRoot(ctx)),
"verification_code/valid/sms.body.gotmpl",
"verification_code/valid/sms.body*",
t.model,
t.deps.CourierConfig().CourierSMSTemplatesVerificationCodeValid(ctx).Body.PlainText,
)
}
func (t *VerificationCodeValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.model)
}
func (t *VerificationCodeValid) TemplateType() template.TemplateType {
return template.TypeVerificationCodeValid
}

View File

@ -23,7 +23,7 @@ func TestNewOTPMessage(t *testing.T) {
otp = "012345"
)
tpl := sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: expectedPhone, Code: otp})
tpl := sms.NewVerificationCodeValid(reg, &sms.VerificationCodeValidModel{To: expectedPhone, VerificationCode: otp})
expectedBody := fmt.Sprintf("Your verification code is: %s\n", otp)

View File

@ -12,19 +12,7 @@ import (
"github.com/ory/x/httpx"
)
type (
Config interface {
CourierTemplatesRoot() string
CourierTemplatesVerificationInvalid() *config.CourierEmailTemplate
CourierTemplatesVerificationValid() *config.CourierEmailTemplate
CourierTemplatesRecoveryInvalid() *config.CourierEmailTemplate
CourierTemplatesRecoveryValid() *config.CourierEmailTemplate
CourierTemplatesLoginValid() *config.CourierEmailTemplate
CourierTemplatesRegistrationValid() *config.CourierEmailTemplate
}
Dependencies interface {
CourierConfig() config.CourierConfigs
HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client
}
)
type Dependencies interface {
CourierConfig() config.CourierConfigs
HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client
}

View File

@ -18,7 +18,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/config"
@ -51,7 +50,7 @@ func TestRendered(t *testing.T, ctx context.Context, tpl interface {
assert.NotEmpty(t, rendered)
}
func TestRemoteTemplates(t *testing.T, basePath string, tmplType courier.TemplateType) {
func TestRemoteTemplates(t *testing.T, basePath string, tmplType template.TemplateType) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
@ -61,32 +60,32 @@ func TestRemoteTemplates(t *testing.T, basePath string, tmplType courier.Templat
return base64.StdEncoding.EncodeToString(f)
}
getTemplate := func(tmpl courier.TemplateType, d template.Dependencies) interface {
getTemplate := func(tmpl template.TemplateType, d template.Dependencies) interface {
EmailBody(context.Context) (string, error)
EmailSubject(context.Context) (string, error)
} {
switch tmpl {
case courier.TypeRecoveryInvalid:
case template.TypeRecoveryInvalid:
return email.NewRecoveryInvalid(d, &email.RecoveryInvalidModel{})
case courier.TypeRecoveryValid:
case template.TypeRecoveryValid:
return email.NewRecoveryValid(d, &email.RecoveryValidModel{})
case courier.TypeRecoveryCodeValid:
case template.TypeRecoveryCodeValid:
return email.NewRecoveryCodeValid(d, &email.RecoveryCodeValidModel{})
case courier.TypeRecoveryCodeInvalid:
case template.TypeRecoveryCodeInvalid:
return email.NewRecoveryCodeInvalid(d, &email.RecoveryCodeInvalidModel{})
case courier.TypeTestStub:
case template.TypeTestStub:
return email.NewTestStub(d, &email.TestStubModel{})
case courier.TypeVerificationInvalid:
case template.TypeVerificationInvalid:
return email.NewVerificationInvalid(d, &email.VerificationInvalidModel{})
case courier.TypeVerificationValid:
case template.TypeVerificationValid:
return email.NewVerificationValid(d, &email.VerificationValidModel{})
case courier.TypeVerificationCodeInvalid:
case template.TypeVerificationCodeInvalid:
return email.NewVerificationCodeInvalid(d, &email.VerificationCodeInvalidModel{})
case courier.TypeVerificationCodeValid:
case template.TypeVerificationCodeValid:
return email.NewVerificationCodeValid(d, &email.VerificationCodeValidModel{})
case courier.TypeLoginCodeValid:
case template.TypeLoginCodeValid:
return email.NewLoginCodeValid(d, &email.LoginCodeValidModel{})
case courier.TypeRegistrationCodeValid:
case template.TypeRegistrationCodeValid:
return email.NewRegistrationCodeValid(d, &email.RegistrationCodeValidModel{})
default:
return nil

23
courier/template/type.go Normal file
View File

@ -0,0 +1,23 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package template
// A Template's type
//
// swagger:enum TemplateType
type TemplateType string
const (
TypeRecoveryInvalid TemplateType = "recovery_invalid"
TypeRecoveryValid TemplateType = "recovery_valid"
TypeRecoveryCodeInvalid TemplateType = "recovery_code_invalid"
TypeRecoveryCodeValid TemplateType = "recovery_code_valid"
TypeVerificationInvalid TemplateType = "verification_invalid"
TypeVerificationValid TemplateType = "verification_valid"
TypeVerificationCodeInvalid TemplateType = "verification_code_invalid"
TypeVerificationCodeValid TemplateType = "verification_code_valid"
TypeTestStub TemplateType = "stub"
TypeLoginCodeValid TemplateType = "login_code_valid"
TypeRegistrationCodeValid TemplateType = "registration_code_valid"
)

View File

@ -15,8 +15,13 @@ import (
)
type (
EmailTemplate interface {
Template interface {
json.Marshaler
TemplateType() template.TemplateType
}
EmailTemplate interface {
Template
EmailSubject(context.Context) (string, error)
EmailBody(context.Context) (string, error)
EmailBodyPlaintext(context.Context) (string, error)
@ -24,118 +29,69 @@ type (
}
)
// A Template's type
//
// swagger:enum TemplateType
type TemplateType string
const (
TypeRecoveryInvalid TemplateType = "recovery_invalid"
TypeRecoveryValid TemplateType = "recovery_valid"
TypeRecoveryCodeInvalid TemplateType = "recovery_code_invalid"
TypeRecoveryCodeValid TemplateType = "recovery_code_valid"
TypeVerificationInvalid TemplateType = "verification_invalid"
TypeVerificationValid TemplateType = "verification_valid"
TypeVerificationCodeInvalid TemplateType = "verification_code_invalid"
TypeVerificationCodeValid TemplateType = "verification_code_valid"
TypeOTP TemplateType = "otp"
TypeTestStub TemplateType = "stub"
TypeLoginCodeValid TemplateType = "login_code_valid"
TypeRegistrationCodeValid TemplateType = "registration_code_valid"
)
func GetEmailTemplateType(t EmailTemplate) (TemplateType, error) {
switch t.(type) {
case *email.RecoveryInvalid:
return TypeRecoveryInvalid, nil
case *email.RecoveryValid:
return TypeRecoveryValid, nil
case *email.RecoveryCodeInvalid:
return TypeRecoveryCodeInvalid, nil
case *email.RecoveryCodeValid:
return TypeRecoveryCodeValid, nil
case *email.VerificationInvalid:
return TypeVerificationInvalid, nil
case *email.VerificationValid:
return TypeVerificationValid, nil
case *email.VerificationCodeInvalid:
return TypeVerificationCodeInvalid, nil
case *email.VerificationCodeValid:
return TypeVerificationCodeValid, nil
case *email.LoginCodeValid:
return TypeLoginCodeValid, nil
case *email.RegistrationCodeValid:
return TypeRegistrationCodeValid, nil
case *email.TestStub:
return TypeTestStub, nil
default:
return "", errors.Errorf("unexpected template type")
}
}
func NewEmailTemplateFromMessage(d template.Dependencies, msg Message) (EmailTemplate, error) {
switch msg.TemplateType {
case TypeRecoveryInvalid:
case template.TypeRecoveryInvalid:
var t email.RecoveryInvalidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewRecoveryInvalid(d, &t), nil
case TypeRecoveryValid:
case template.TypeRecoveryValid:
var t email.RecoveryValidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewRecoveryValid(d, &t), nil
case TypeRecoveryCodeInvalid:
case template.TypeRecoveryCodeInvalid:
var t email.RecoveryCodeInvalidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewRecoveryCodeInvalid(d, &t), nil
case TypeRecoveryCodeValid:
case template.TypeRecoveryCodeValid:
var t email.RecoveryCodeValidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewRecoveryCodeValid(d, &t), nil
case TypeVerificationInvalid:
case template.TypeVerificationInvalid:
var t email.VerificationInvalidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewVerificationInvalid(d, &t), nil
case TypeVerificationValid:
case template.TypeVerificationValid:
var t email.VerificationValidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewVerificationValid(d, &t), nil
case TypeVerificationCodeInvalid:
case template.TypeVerificationCodeInvalid:
var t email.VerificationCodeInvalidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewVerificationCodeInvalid(d, &t), nil
case TypeVerificationCodeValid:
case template.TypeVerificationCodeValid:
var t email.VerificationCodeValidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewVerificationCodeValid(d, &t), nil
case TypeTestStub:
case template.TypeTestStub:
var t email.TestStubModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewTestStub(d, &t), nil
case TypeLoginCodeValid:
case template.TypeLoginCodeValid:
var t email.LoginCodeValidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return email.NewLoginCodeValid(d, &t), nil
case TypeRegistrationCodeValid:
case template.TypeRegistrationCodeValid:
var t email.RegistrationCodeValidModel
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err

72
courier/templates_test.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package courier_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/ory/kratos/courier"
"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/internal"
)
func TestNewEmailTemplateFromMessage(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
ctx := context.Background()
for tmplType, expectedTmpl := range map[template.TemplateType]courier.EmailTemplate{
template.TypeRecoveryInvalid: email.NewRecoveryInvalid(reg, &email.RecoveryInvalidModel{To: "foo"}),
template.TypeRecoveryValid: email.NewRecoveryValid(reg, &email.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}),
template.TypeRecoveryCodeValid: email.NewRecoveryCodeValid(reg, &email.RecoveryCodeValidModel{To: "bar", RecoveryCode: "12345678"}),
template.TypeRecoveryCodeInvalid: email.NewRecoveryCodeInvalid(reg, &email.RecoveryCodeInvalidModel{To: "bar"}),
template.TypeVerificationInvalid: email.NewVerificationInvalid(reg, &email.VerificationInvalidModel{To: "baz"}),
template.TypeVerificationValid: email.NewVerificationValid(reg, &email.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}),
template.TypeVerificationCodeInvalid: email.NewVerificationCodeInvalid(reg, &email.VerificationCodeInvalidModel{To: "baz"}),
template.TypeVerificationCodeValid: email.NewVerificationCodeValid(reg, &email.VerificationCodeValidModel{To: "faz", VerificationURL: "http://bar.foo", VerificationCode: "123456678"}),
template.TypeTestStub: email.NewTestStub(reg, &email.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}),
template.TypeLoginCodeValid: email.NewLoginCodeValid(reg, &email.LoginCodeValidModel{To: "far", LoginCode: "123456"}),
template.TypeRegistrationCodeValid: email.NewRegistrationCodeValid(reg, &email.RegistrationCodeValidModel{To: "far", RegistrationCode: "123456"}),
} {
t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) {
tmplData, err := json.Marshal(expectedTmpl)
require.NoError(t, err)
m := courier.Message{TemplateType: tmplType, TemplateData: tmplData}
actualTmpl, err := courier.NewEmailTemplateFromMessage(reg, m)
require.NoError(t, err)
require.IsType(t, expectedTmpl, actualTmpl)
expectedRecipient, err := expectedTmpl.EmailRecipient()
require.NoError(t, err)
actualRecipient, err := actualTmpl.EmailRecipient()
require.NoError(t, err)
require.Equal(t, expectedRecipient, actualRecipient)
expectedSubject, err := expectedTmpl.EmailSubject(ctx)
require.NoError(t, err)
actualSubject, err := actualTmpl.EmailSubject(ctx)
require.NoError(t, err)
require.Equal(t, expectedSubject, actualSubject)
expectedBody, err := expectedTmpl.EmailBody(ctx)
require.NoError(t, err)
actualBody, err := actualTmpl.EmailBody(ctx)
require.NoError(t, err)
require.Equal(t, expectedBody, actualBody)
expectedBodyPlaintext, err := expectedTmpl.EmailBodyPlaintext(ctx)
require.NoError(t, err)
actualBodyPlaintext, err := actualTmpl.EmailBodyPlaintext(ctx)
require.NoError(t, err)
require.Equal(t, expectedBodyPlaintext, actualBodyPlaintext)
})
}
}

View File

@ -66,20 +66,20 @@ const (
ViperKeyCourierTemplatesVerificationValidEmail = "courier.templates.verification.valid.email"
ViperKeyCourierTemplatesVerificationCodeInvalidEmail = "courier.templates.verification_code.invalid.email"
ViperKeyCourierTemplatesVerificationCodeValidEmail = "courier.templates.verification_code.valid.email"
ViperKeyCourierTemplatesVerificationCodeValidSMS = "courier.templates.verification_code.valid.sms"
ViperKeyCourierDeliveryStrategy = "courier.delivery_strategy"
ViperKeyCourierHTTPRequestConfig = "courier.http.request_config"
ViperKeyCourierTemplatesLoginCodeValidEmail = "courier.templates.login_code.valid.email"
ViperKeyCourierTemplatesRegistrationCodeValidEmail = "courier.templates.registration_code.valid.email"
ViperKeyCourierSMTP = "courier.smtp"
ViperKeyCourierSMTPFrom = "courier.smtp.from_address"
ViperKeyCourierSMTPFromName = "courier.smtp.from_name"
ViperKeyCourierSMTPHeaders = "courier.smtp.headers"
ViperKeyCourierSMTPLocalName = "courier.smtp.local_name"
ViperKeyCourierSMSRequestConfig = "courier.sms.request_config"
ViperKeyCourierSMSEnabled = "courier.sms.enabled"
ViperKeyCourierSMSFrom = "courier.sms.from"
ViperKeyCourierMessageRetries = "courier.message_retries"
ViperKeyCourierWorkerPullCount = "courier.worker.pull_count"
ViperKeyCourierWorkerPullWait = "courier.worker.pull_wait"
ViperKeyCourierChannels = "courier.channels"
ViperKeySecretsDefault = "secrets.default"
ViperKeySecretsCookie = "secrets.cookie"
ViperKeySecretsCipher = "secrets.cipher"
@ -258,6 +258,28 @@ type (
Body *CourierEmailBodyTemplate `json:"body"`
Subject string `json:"subject"`
}
CourierSMSTemplate struct {
Body *CourierSMSTemplateBody `json:"body"`
}
CourierSMSTemplateBody struct {
PlainText string `json:"plaintext"`
}
CourierChannel struct {
ID string `json:"id" koanf:"id"`
Type string `json:"type" koanf:"type"`
SMTPConfig *SMTPConfig `json:"smtp_config" koanf:"smtp_config"`
RequestConfig json.RawMessage `json:"request_config" koanf:"-"`
RequestConfigRaw map[string]any `json:"-" koanf:"request_config"`
}
SMTPConfig struct {
ConnectionURI string `json:"connection_uri" koanf:"connection_uri"`
ClientCertPath string `json:"client_cert_path" koanf:"client_cert_path"`
ClientKeyPath string `json:"client_key_path" koanf:"client_key_path"`
FromAddress string `json:"from_address" koanf:"from_address"`
FromName string `json:"from_name" koanf:"from_name"`
Headers map[string]string `json:"headers" koanf:"headers"`
LocalName string `json:"local_name" koanf:"local_name"`
}
Config struct {
l *logrusx.Logger
p *configx.Provider
@ -269,18 +291,6 @@ type (
Config() *Config
}
CourierConfigs interface {
CourierEmailStrategy(ctx context.Context) string
CourierEmailRequestConfig(ctx context.Context) json.RawMessage
CourierSMTPURL(ctx context.Context) (*url.URL, error)
CourierSMTPClientCertPath(ctx context.Context) string
CourierSMTPClientKeyPath(ctx context.Context) string
CourierSMTPFrom(ctx context.Context) string
CourierSMTPFromName(ctx context.Context) string
CourierSMTPHeaders(ctx context.Context) map[string]string
CourierSMTPLocalName(ctx context.Context) string
CourierSMSEnabled(ctx context.Context) bool
CourierSMSFrom(ctx context.Context) string
CourierSMSRequestConfig(ctx context.Context) json.RawMessage
CourierTemplatesRoot(ctx context.Context) string
CourierTemplatesVerificationInvalid(ctx context.Context) *CourierEmailTemplate
CourierTemplatesVerificationValid(ctx context.Context) *CourierEmailTemplate
@ -292,9 +302,11 @@ type (
CourierTemplatesVerificationCodeValid(ctx context.Context) *CourierEmailTemplate
CourierTemplatesLoginCodeValid(ctx context.Context) *CourierEmailTemplate
CourierTemplatesRegistrationCodeValid(ctx context.Context) *CourierEmailTemplate
CourierSMSTemplatesVerificationCodeValid(ctx context.Context) *CourierSMSTemplate
CourierMessageRetries(ctx context.Context) int
CourierWorkerPullCount(ctx context.Context) int
CourierWorkerPullWait(ctx context.Context) time.Duration
CourierChannels(context.Context) ([]*CourierChannel, error)
}
)
@ -884,15 +896,6 @@ func (p *Config) SelfAdminURL(ctx context.Context) *url.URL {
return p.baseURL(ctx, ViperKeyAdminBaseURL, ViperKeyAdminHost, ViperKeyAdminPort, 4434)
}
func (p *Config) CourierSMTPURL(ctx context.Context) (*url.URL, error) {
source := p.GetProvider(ctx).String(ViperKeyCourierSMTPURL)
parsed, err := url.Parse(source)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to parse the project's SMTP URL. Please ensure that it is properly escaped: https://www.ory.sh/dr/3").WithDebugf("%s", err))
}
return parsed, nil
}
func (p *Config) OAuth2ProviderHeader(ctx context.Context) http.Header {
hh := map[string]string{}
if err := p.GetProvider(ctx).Unmarshal(ViperKeyOAuth2ProviderHeader, &hh); err != nil {
@ -1022,31 +1025,11 @@ func (p *Config) CourierEmailRequestConfig(ctx context.Context) json.RawMessage
return config
}
func (p *Config) CourierSMTPClientCertPath(ctx context.Context) string {
return p.GetProvider(ctx).StringF(ViperKeyCourierSMTPClientCertPath, "")
}
func (p *Config) CourierSMTPClientKeyPath(ctx context.Context) string {
return p.GetProvider(ctx).StringF(ViperKeyCourierSMTPClientKeyPath, "")
}
func (p *Config) CourierSMTPFrom(ctx context.Context) string {
return p.GetProvider(ctx).StringF(ViperKeyCourierSMTPFrom, "noreply@kratos.ory.sh")
}
func (p *Config) CourierSMTPFromName(ctx context.Context) string {
return p.GetProvider(ctx).StringF(ViperKeyCourierSMTPFromName, "")
}
func (p *Config) CourierSMTPLocalName(ctx context.Context) string {
return p.GetProvider(ctx).StringF(ViperKeyCourierSMTPLocalName, "localhost")
}
func (p *Config) CourierTemplatesRoot(ctx context.Context) string {
return p.GetProvider(ctx).StringF(ViperKeyCourierTemplatesPath, "courier/builtin/templates")
}
func (p *Config) CourierTemplatesHelper(ctx context.Context, key string) *CourierEmailTemplate {
func (p *Config) CourierEmailTemplatesHelper(ctx context.Context, key string) *CourierEmailTemplate {
courierTemplate := &CourierEmailTemplate{
Body: &CourierEmailBodyTemplate{
PlainText: "",
@ -1072,44 +1055,72 @@ func (p *Config) CourierTemplatesHelper(ctx context.Context, key string) *Courie
return courierTemplate
}
func (p *Config) CourierSMSTemplatesHelper(ctx context.Context, key string) *CourierSMSTemplate {
courierTemplate := &CourierSMSTemplate{
Body: &CourierSMSTemplateBody{
PlainText: "",
},
}
if !p.GetProvider(ctx).Exists(key) {
return courierTemplate
}
config, err := json.Marshal(p.GetProvider(ctx).Get(key))
if err != nil {
p.l.WithError(err).Fatalf("Unable to decode values from %s.", key)
return courierTemplate
}
if err := json.Unmarshal(config, courierTemplate); err != nil {
p.l.WithError(err).Fatalf("Unable to encode values from %s.", key)
return courierTemplate
}
return courierTemplate
}
func (p *Config) CourierTemplatesVerificationInvalid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationInvalidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationInvalidEmail)
}
func (p *Config) CourierTemplatesVerificationValid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationValidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationValidEmail)
}
func (p *Config) CourierTemplatesRecoveryInvalid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryInvalidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryInvalidEmail)
}
func (p *Config) CourierTemplatesRecoveryValid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryValidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryValidEmail)
}
func (p *Config) CourierTemplatesRecoveryCodeInvalid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryCodeInvalidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryCodeInvalidEmail)
}
func (p *Config) CourierTemplatesRecoveryCodeValid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryCodeValidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesRecoveryCodeValidEmail)
}
func (p *Config) CourierTemplatesVerificationCodeInvalid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationCodeInvalidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationCodeInvalidEmail)
}
func (p *Config) CourierTemplatesVerificationCodeValid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationCodeValidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationCodeValidEmail)
}
func (p *Config) CourierSMSTemplatesVerificationCodeValid(ctx context.Context) *CourierSMSTemplate {
return p.CourierSMSTemplatesHelper(ctx, ViperKeyCourierTemplatesVerificationCodeValidEmail)
}
func (p *Config) CourierTemplatesLoginCodeValid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesLoginCodeValidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesLoginCodeValidEmail)
}
func (p *Config) CourierTemplatesRegistrationCodeValid(ctx context.Context) *CourierEmailTemplate {
return p.CourierTemplatesHelper(ctx, ViperKeyCourierTemplatesRegistrationCodeValidEmail)
return p.CourierEmailTemplatesHelper(ctx, ViperKeyCourierTemplatesRegistrationCodeValidEmail)
}
func (p *Config) CourierMessageRetries(ctx context.Context) int {
@ -1128,25 +1139,40 @@ func (p *Config) CourierSMTPHeaders(ctx context.Context) map[string]string {
return p.GetProvider(ctx).StringMap(ViperKeyCourierSMTPHeaders)
}
func (p *Config) CourierSMSRequestConfig(ctx context.Context) json.RawMessage {
if !p.GetProvider(ctx).Bool(ViperKeyCourierSMSEnabled) {
return nil
func (p *Config) CourierChannels(ctx context.Context) (ccs []*CourierChannel, _ error) {
if err := p.GetProvider(ctx).Koanf.Unmarshal(ViperKeyCourierChannels, &ccs); err != nil {
return nil, errors.WithStack(err)
}
if len(ccs) != 0 {
for _, c := range ccs {
if c.RequestConfigRaw != nil {
var err error
c.RequestConfig, err = json.Marshal(c.RequestConfigRaw)
if err != nil {
return nil, errors.WithStack(err)
}
}
}
return ccs, nil
}
config, err := json.Marshal(p.GetProvider(ctx).Get(ViperKeyCourierSMSRequestConfig))
if err != nil {
p.l.WithError(err).Warn("Unable to marshal SMS request configuration.")
return json.RawMessage("{}")
// load legacy configs
channel := CourierChannel{
ID: "email",
Type: p.CourierEmailStrategy(ctx),
}
return config
}
func (p *Config) CourierSMSFrom(ctx context.Context) string {
return p.GetProvider(ctx).StringF(ViperKeyCourierSMSFrom, "Ory Kratos")
}
func (p *Config) CourierSMSEnabled(ctx context.Context) bool {
return p.GetProvider(ctx).Bool(ViperKeyCourierSMSEnabled)
if channel.Type == "smtp" {
if err := p.GetProvider(ctx).Koanf.Unmarshal(ViperKeyCourierSMTP, &channel.SMTPConfig); err != nil {
return nil, errors.WithStack(err)
}
} else {
var err error
channel.RequestConfig, err = json.Marshal(p.GetProvider(ctx).Get(ViperKeyCourierHTTPRequestConfig))
if err != nil {
return nil, errors.WithStack(err)
}
}
return []*CourierChannel{&channel}, nil
}
func splitUrlAndFragment(s string) (string, string) {

View File

@ -1149,53 +1149,47 @@ func TestCourierEmailHTTP(t *testing.T) {
})
}
func TestCourierSMS(t *testing.T) {
func TestCourierChannels(t *testing.T) {
t.Parallel()
ctx := context.Background()
t.Run("case=configs set", func(t *testing.T) {
conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr,
configx.WithConfigFiles("stub/.kratos.courier.sms.yaml"), configx.SkipValidation())
assert.True(t, conf.CourierSMSEnabled(ctx))
snapshotx.SnapshotTExcept(t, conf.CourierSMSRequestConfig(ctx), nil)
assert.Equal(t, "+49123456789", conf.CourierSMSFrom(ctx))
conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithConfigFiles("stub/.kratos.courier.channels.yaml"), configx.SkipValidation())
channelConfig, err := conf.CourierChannels(ctx)
require.NoError(t, err)
require.Len(t, channelConfig, 1)
assert.Equal(t, channelConfig[0].ID, "phone")
assert.NotEmpty(t, channelConfig[0].RequestConfig)
})
t.Run("case=defaults", func(t *testing.T) {
conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.SkipValidation())
assert.False(t, conf.CourierSMSEnabled(ctx))
snapshotx.SnapshotTExcept(t, conf.CourierSMSRequestConfig(ctx), nil)
assert.Equal(t, "Ory Kratos", conf.CourierSMSFrom(ctx))
})
}
func TestCourierSMTPUrl(t *testing.T) {
t.Parallel()
ctx := context.Background()
for _, tc := range []string{
"smtp://a:basdasdasda%2Fc@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://a:b$c@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://a/a:bc@email-smtp.eu-west-3.amazonaws.com:587",
"smtp://aa:b+c@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://user?name:password@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://username:pass%2Fword@email-smtp.eu-west-3.amazonaws.com:587/",
} {
t.Run("case="+tc, func(t *testing.T) {
conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithValue(config.ViperKeyCourierSMTPURL, tc), configx.SkipValidation())
require.NoError(t, err)
parsed, err := conf.CourierSMTPURL(ctx)
require.NoError(t, err)
assert.Equal(t, tc, parsed.String())
})
}
t.Run("invalid", func(t *testing.T) {
conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithValue(config.ViperKeyCourierSMTPURL, "smtp://a:b/c@email-smtp.eu-west-3.amazonaws.com:587/"), configx.SkipValidation())
channelConfig, err := conf.CourierChannels(ctx)
require.NoError(t, err)
_, err = conf.CourierSMTPURL(ctx)
require.Error(t, err)
assert.Len(t, channelConfig, 1)
assert.Equal(t, channelConfig[0].ID, "email")
assert.Equal(t, channelConfig[0].Type, "smtp")
})
t.Run("smtp urls", func(t *testing.T) {
for _, tc := range []string{
"smtp://a:basdasdasda%2Fc@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://a:b$c@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://a/a:bc@email-smtp.eu-west-3.amazonaws.com:587",
"smtp://aa:b+c@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://user?name:password@email-smtp.eu-west-3.amazonaws.com:587/",
"smtp://username:pass%2Fword@email-smtp.eu-west-3.amazonaws.com:587/",
} {
t.Run("case="+tc, func(t *testing.T) {
conf, err := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.WithValue(config.ViperKeyCourierSMTPURL, tc), configx.SkipValidation())
require.NoError(t, err)
cs, err := conf.CourierChannels(ctx)
require.NoError(t, err)
require.Len(t, cs, 1)
assert.Equal(t, tc, cs[0].SMTPConfig.ConnectionURI)
})
}
})
}
@ -1311,10 +1305,10 @@ func TestCourierTemplatesConfig(t *testing.T) {
Subject: "",
}
assert.Equal(t, courierTemplateConfig, c.CourierTemplatesHelper(ctx, config.ViperKeyCourierTemplatesVerificationInvalidEmail))
assert.Equal(t, courierTemplateConfig, c.CourierTemplatesHelper(ctx, config.ViperKeyCourierTemplatesVerificationValidEmail))
assert.Equal(t, courierTemplateConfig, c.CourierEmailTemplatesHelper(ctx, config.ViperKeyCourierTemplatesVerificationInvalidEmail))
assert.Equal(t, courierTemplateConfig, c.CourierEmailTemplatesHelper(ctx, config.ViperKeyCourierTemplatesVerificationValidEmail))
// this should return an empty courierEmailTemplate as the key does not exist
assert.Equal(t, courierTemplateConfig, c.CourierTemplatesHelper(ctx, "a_random_key"))
assert.Equal(t, courierTemplateConfig, c.CourierEmailTemplatesHelper(ctx, "a_random_key"))
courierTemplateConfig = &config.CourierEmailTemplate{
Body: &config.CourierEmailBodyTemplate{
@ -1323,7 +1317,7 @@ func TestCourierTemplatesConfig(t *testing.T) {
},
Subject: "base64://QWNjb3VudCBBY2Nlc3MgQXR0ZW1wdGVk",
}
assert.Equal(t, courierTemplateConfig, c.CourierTemplatesHelper(ctx, config.ViperKeyCourierTemplatesRecoveryInvalidEmail))
assert.Equal(t, courierTemplateConfig, c.CourierEmailTemplatesHelper(ctx, config.ViperKeyCourierTemplatesRecoveryInvalidEmail))
courierTemplateConfig = &config.CourierEmailTemplate{
Body: &config.CourierEmailBodyTemplate{
@ -1332,7 +1326,7 @@ func TestCourierTemplatesConfig(t *testing.T) {
},
Subject: "base64://UmVjb3ZlciBhY2Nlc3MgdG8geW91ciBhY2NvdW50",
}
assert.Equal(t, courierTemplateConfig, c.CourierTemplatesHelper(ctx, config.ViperKeyCourierTemplatesRecoveryValidEmail))
assert.Equal(t, courierTemplateConfig, c.CourierEmailTemplatesHelper(ctx, config.ViperKeyCourierTemplatesRecoveryValidEmail))
})
}

View File

@ -0,0 +1,14 @@
courier:
channels:
- id: phone
request_config:
url: https://ory.sh
method: GET
body: base64://ZnVuY3Rpb24oY3R4KSB7CkJvZHk6IGN0eC5ib2R5LApUbzogY3R4LnRvLEZyb206IGN0eC5mcm9tCn0=
headers:
Content-Type: application/x-www-form-urlencoded
auth:
type: basic_auth
config:
user: ABC
password: DEF

View File

@ -1034,12 +1034,36 @@
"properties": {
"email": {
"$ref": "#/definitions/emailCourierTemplate"
},
"sms": {
"$ref": "#/definitions/smsCourierTemplate"
}
},
"required": ["email"]
}
}
},
"smsCourierTemplate": {
"additionalProperties": false,
"type": "object",
"properties": {
"body": {
"additionalProperties": false,
"type": "object",
"properties": {
"plaintext": {
"type": "string",
"description": "A template send to the SMS provider.",
"format": "uri",
"examples": [
"file://path/to/body.plaintext.gotmpl",
"https://foo.bar.com/path/to/body.plaintext.gotmpl"
]
}
}
}
}
},
"emailCourierTemplate": {
"additionalProperties": false,
"type": "object",
@ -1963,9 +1987,29 @@
}
},
"additionalProperties": false
},
"channels": {
"type": "array",
"items": {
"title": "Courier channel configuration",
"type": "object",
"properties": {
"id": {
"type": "string",
"title": "Channel id",
"description": "The channel id. Corresponds to the .via property of the identity schema for recovery, verification, etc. Currently only phone is supported.",
"maxLength": 32,
"enum": ["phone"]
},
"request_config": {
"$ref": "#/definitions/httpRequestConfig"
}
},
"required": ["id", "request_config"],
"additionalProperties": false
}
}
},
"required": ["smtp"],
"additionalProperties": false
},
"oauth2_provider": {

View File

@ -63,13 +63,13 @@ func (r *SchemaExtensionCredentials) Run(ctx jsonschema.ValidationContext, s sch
return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via)
}
r.setIdentifier(CredentialsTypeCodeAuth, value, AddressTypeEmail)
r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressTypeEmail)
// case f.AddCase(AddressTypePhone):
// if !jsonschema.Formats["tel"](value) {
// return ctx.Error("format", "%q is not a valid %q", value, s.Credentials.Code.Via)
// }
//
// r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressType(AddressTypePhone))
// r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressTypePhone)
default:
return ctx.Error("", "credentials.code.via has unknown value %q", s.Credentials.Code.Via)
}

View File

@ -18,7 +18,8 @@ import (
// Message struct for Message
type Message struct {
Body string `json:"body"`
Body string `json:"body"`
Channel *string `json:"channel,omitempty"`
// CreatedAt is a helper struct field for gobuffalo.pop.
CreatedAt time.Time `json:"created_at"`
// Dispatches store information about the attempts of delivering a message May contain an error if any happened, or just the `success` state.
@ -28,7 +29,7 @@ type Message struct {
SendCount int64 `json:"send_count"`
Status CourierMessageStatus `json:"status"`
Subject string `json:"subject"`
// recovery_invalid TypeRecoveryInvalid recovery_valid TypeRecoveryValid recovery_code_invalid TypeRecoveryCodeInvalid recovery_code_valid TypeRecoveryCodeValid verification_invalid TypeVerificationInvalid verification_valid TypeVerificationValid verification_code_invalid TypeVerificationCodeInvalid verification_code_valid TypeVerificationCodeValid otp TypeOTP stub TypeTestStub login_code_valid TypeLoginCodeValid registration_code_valid TypeRegistrationCodeValid
// recovery_invalid TypeRecoveryInvalid recovery_valid TypeRecoveryValid recovery_code_invalid TypeRecoveryCodeInvalid recovery_code_valid TypeRecoveryCodeValid verification_invalid TypeVerificationInvalid verification_valid TypeVerificationValid verification_code_invalid TypeVerificationCodeInvalid verification_code_valid TypeVerificationCodeValid stub TypeTestStub login_code_valid TypeLoginCodeValid registration_code_valid TypeRegistrationCodeValid
TemplateType string `json:"template_type"`
Type CourierMessageType `json:"type"`
// UpdatedAt is a helper struct field for gobuffalo.pop.
@ -86,6 +87,38 @@ func (o *Message) SetBody(v string) {
o.Body = v
}
// GetChannel returns the Channel field value if set, zero value otherwise.
func (o *Message) GetChannel() string {
if o == nil || o.Channel == nil {
var ret string
return ret
}
return *o.Channel
}
// GetChannelOk returns a tuple with the Channel field value if set, nil otherwise
// and a boolean to check if the value has been set.
func (o *Message) GetChannelOk() (*string, bool) {
if o == nil || o.Channel == nil {
return nil, false
}
return o.Channel, true
}
// HasChannel returns a boolean if a field has been set.
func (o *Message) HasChannel() bool {
if o != nil && o.Channel != nil {
return true
}
return false
}
// SetChannel gets a reference to the given string and assigns it to the Channel field.
func (o *Message) SetChannel(v string) {
o.Channel = &v
}
// GetCreatedAt returns the CreatedAt field value
func (o *Message) GetCreatedAt() time.Time {
if o == nil {
@ -339,6 +372,9 @@ func (o Message) MarshalJSON() ([]byte, error) {
if true {
toSerialize["body"] = o.Body
}
if o.Channel != nil {
toSerialize["channel"] = o.Channel
}
if true {
toSerialize["created_at"] = o.CreatedAt
}

View File

@ -18,7 +18,8 @@ import (
// Message struct for Message
type Message struct {
Body string `json:"body"`
Body string `json:"body"`
Channel *string `json:"channel,omitempty"`
// CreatedAt is a helper struct field for gobuffalo.pop.
CreatedAt time.Time `json:"created_at"`
// Dispatches store information about the attempts of delivering a message May contain an error if any happened, or just the `success` state.
@ -28,7 +29,7 @@ type Message struct {
SendCount int64 `json:"send_count"`
Status CourierMessageStatus `json:"status"`
Subject string `json:"subject"`
// recovery_invalid TypeRecoveryInvalid recovery_valid TypeRecoveryValid recovery_code_invalid TypeRecoveryCodeInvalid recovery_code_valid TypeRecoveryCodeValid verification_invalid TypeVerificationInvalid verification_valid TypeVerificationValid verification_code_invalid TypeVerificationCodeInvalid verification_code_valid TypeVerificationCodeValid otp TypeOTP stub TypeTestStub login_code_valid TypeLoginCodeValid registration_code_valid TypeRegistrationCodeValid
// recovery_invalid TypeRecoveryInvalid recovery_valid TypeRecoveryValid recovery_code_invalid TypeRecoveryCodeInvalid recovery_code_valid TypeRecoveryCodeValid verification_invalid TypeVerificationInvalid verification_valid TypeVerificationValid verification_code_invalid TypeVerificationCodeInvalid verification_code_valid TypeVerificationCodeValid stub TypeTestStub login_code_valid TypeLoginCodeValid registration_code_valid TypeRegistrationCodeValid
TemplateType string `json:"template_type"`
Type CourierMessageType `json:"type"`
// UpdatedAt is a helper struct field for gobuffalo.pop.
@ -86,6 +87,38 @@ func (o *Message) SetBody(v string) {
o.Body = v
}
// GetChannel returns the Channel field value if set, zero value otherwise.
func (o *Message) GetChannel() string {
if o == nil || o.Channel == nil {
var ret string
return ret
}
return *o.Channel
}
// GetChannelOk returns a tuple with the Channel field value if set, nil otherwise
// and a boolean to check if the value has been set.
func (o *Message) GetChannelOk() (*string, bool) {
if o == nil || o.Channel == nil {
return nil, false
}
return o.Channel, true
}
// HasChannel returns a boolean if a field has been set.
func (o *Message) HasChannel() bool {
if o != nil && o.Channel != nil {
return true
}
return false
}
// SetChannel gets a reference to the given string and assigns it to the Channel field.
func (o *Message) SetChannel(v string) {
o.Channel = &v
}
// GetCreatedAt returns the CreatedAt field value
func (o *Message) GetCreatedAt() time.Time {
if o == nil {
@ -339,6 +372,9 @@ func (o Message) MarshalJSON() ([]byte, error) {
if true {
toSerialize["body"] = o.Body
}
if o.Channel != nil {
toSerialize["channel"] = o.Channel
}
if true {
toSerialize["created_at"] = o.CreatedAt
}

View File

@ -0,0 +1,2 @@
ALTER TABLE
courier_messages DROP column channel;

View File

@ -0,0 +1,4 @@
ALTER TABLE
courier_messages
ADD
column channel VARCHAR(32) NULL;

View File

@ -13,7 +13,7 @@ import (
"github.com/gofrs/uuid"
"github.com/laher/mergefs"
"github.com/pkg/errors"
"github.com/sirupsen/logrus/hooks/test"
"github.com/sirupsen/logrus"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
@ -24,7 +24,6 @@ import (
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
"github.com/ory/x/contextx"
"github.com/ory/x/logrusx"
"github.com/ory/x/networkx"
"github.com/ory/x/popx"
)
@ -82,8 +81,7 @@ func NewPersister(ctx context.Context, r persisterDependencies, c *pop.Connectio
}
logger := r.Logger()
if o.disableLogging {
inner, _ := test.NewNullLogger()
logger = logrusx.New("kratos", "", logrusx.UseLogger(inner))
logger.Logrus().SetLevel(logrus.WarnLevel)
}
m, err := popx.NewMigrationBox(
mergefs.Merge(

View File

@ -13,6 +13,7 @@ import (
"github.com/ory/herodot"
"github.com/ory/kratos/courier/template/email"
"github.com/ory/kratos/courier/template/sms"
"github.com/ory/x/httpx"
"github.com/ory/x/sqlcon"
@ -312,20 +313,35 @@ func (s *Sender) SendVerificationCodeTo(ctx context.Context, f *verification.Flo
return err
}
if err := s.send(ctx, string(code.VerifiableAddress.Via), email.NewVerificationCodeValid(s.deps,
&email.VerificationCodeValidModel{
var t courier.Template
// TODO: this can likely be abstracted by making templates not specific to the channel they're using
switch code.VerifiableAddress.Via {
case identity.AddressTypeEmail:
t = email.NewVerificationCodeValid(s.deps, &email.VerificationCodeValidModel{
To: code.VerifiableAddress.Value,
VerificationURL: s.constructVerificationLink(ctx, f.ID, codeString),
Identity: model,
VerificationCode: codeString,
})); err != nil {
})
case identity.AddressTypePhone:
t = sms.NewVerificationCodeValid(s.deps, &sms.VerificationCodeValidModel{
To: code.VerifiableAddress.Value,
VerificationCode: codeString,
Identity: model,
})
default:
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected email or phone but got %s", code.VerifiableAddress.Via))
}
if err := s.send(ctx, string(code.VerifiableAddress.Via), t); err != nil {
return err
}
code.VerifiableAddress.Status = identity.VerifiableAddressStatusSent
return s.deps.PrivilegedIdentityPool().UpdateVerifiableAddress(ctx, code.VerifiableAddress)
}
func (s *Sender) send(ctx context.Context, via string, t courier.EmailTemplate) error {
func (s *Sender) send(ctx context.Context, via string, t courier.Template) error {
switch f := stringsx.SwitchExact(via); {
case f.AddCase(identity.AddressTypeEmail):
c, err := s.deps.Courier(ctx)
@ -333,8 +349,26 @@ func (s *Sender) send(ctx context.Context, via string, t courier.EmailTemplate)
return err
}
t, ok := t.(courier.EmailTemplate)
if !ok {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected email template but got %T", t))
}
_, err = c.QueueEmail(ctx, t)
return err
case f.AddCase(identity.AddressTypePhone):
c, err := s.deps.Courier(ctx)
if err != nil {
return err
}
t, ok := t.(courier.SMSTemplate)
if !ok {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected sms template but got %T", t))
}
_, err = c.QueueSMS(ctx, t)
return err
default:
return f.ToUnknownCaseErr()
}

View File

@ -1387,6 +1387,9 @@
"body": {
"type": "string"
},
"channel": {
"type": "string"
},
"created_at": {
"description": "CreatedAt is a helper struct field for gobuffalo.pop.",
"format": "date-time",
@ -1417,7 +1420,7 @@
"type": "string"
},
"template_type": {
"description": "\nrecovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\notp TypeOTP\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid",
"description": "\nrecovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid",
"enum": [
"recovery_invalid",
"recovery_valid",
@ -1427,13 +1430,12 @@
"verification_valid",
"verification_code_invalid",
"verification_code_valid",
"otp",
"stub",
"login_code_valid",
"registration_code_valid"
],
"type": "string",
"x-go-enum-desc": "recovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\notp TypeOTP\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid"
"x-go-enum-desc": "recovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid"
},
"type": {
"$ref": "#/components/schemas/courierMessageType"

View File

@ -4458,6 +4458,9 @@
"body": {
"type": "string"
},
"channel": {
"type": "string"
},
"created_at": {
"description": "CreatedAt is a helper struct field for gobuffalo.pop.",
"type": "string",
@ -4488,7 +4491,7 @@
"type": "string"
},
"template_type": {
"description": "\nrecovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\notp TypeOTP\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid",
"description": "\nrecovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid",
"type": "string",
"enum": [
"recovery_invalid",
@ -4499,12 +4502,11 @@
"verification_valid",
"verification_code_invalid",
"verification_code_valid",
"otp",
"stub",
"login_code_valid",
"registration_code_valid"
],
"x-go-enum-desc": "recovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\notp TypeOTP\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid"
"x-go-enum-desc": "recovery_invalid TypeRecoveryInvalid\nrecovery_valid TypeRecoveryValid\nrecovery_code_invalid TypeRecoveryCodeInvalid\nrecovery_code_valid TypeRecoveryCodeValid\nverification_invalid TypeVerificationInvalid\nverification_valid TypeVerificationValid\nverification_code_invalid TypeVerificationCodeInvalid\nverification_code_valid TypeVerificationCodeValid\nstub TypeTestStub\nlogin_code_valid TypeLoginCodeValid\nregistration_code_valid TypeRegistrationCodeValid"
},
"type": {
"$ref": "#/definitions/courierMessageType"