mirror of https://github.com/ory/kratos
253 lines
8.3 KiB
Go
253 lines
8.3 KiB
Go
// Copyright © 2023 Ory Corp
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package courier_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net/http"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gofrs/uuid"
|
|
"github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tidwall/gjson"
|
|
|
|
"github.com/ory/kratos/courier"
|
|
templates "github.com/ory/kratos/courier/template/email"
|
|
"github.com/ory/kratos/driver/config"
|
|
"github.com/ory/kratos/internal"
|
|
"github.com/ory/kratos/x"
|
|
gomail "github.com/ory/mail/v3"
|
|
)
|
|
|
|
func TestNewSMTPClientPreventLeak(t *testing.T) {
|
|
// Test for https://hackerone.com/reports/2384028
|
|
|
|
ctx := context.Background()
|
|
conf, reg := internal.NewFastRegistryWithMocks(t)
|
|
|
|
invalidURL := "sm<>t>p://f%oo::bar:baz@my-server:1234:122/"
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, invalidURL)
|
|
channels, err := conf.CourierChannels(ctx)
|
|
require.NoError(t, err)
|
|
require.Len(t, channels, 1)
|
|
|
|
_, err = courier.NewSMTPClient(reg, channels[0].SMTPConfig)
|
|
require.Error(t, err)
|
|
assert.NotContains(t, err.Error(), invalidURL)
|
|
}
|
|
|
|
func TestNewSMTP(t *testing.T) {
|
|
ctx := context.Background()
|
|
conf, reg := internal.NewFastRegistryWithMocks(t)
|
|
|
|
setupSMTPClient := func(stringURL string) *courier.SMTPClient {
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, stringURL)
|
|
|
|
channels, err := conf.CourierChannels(ctx)
|
|
require.NoError(t, err)
|
|
require.Len(t, channels, 1)
|
|
c, err := courier.NewSMTPClient(reg, channels[0].SMTPConfig)
|
|
require.NoError(t, err)
|
|
return c
|
|
}
|
|
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
// Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false
|
|
smtp := setupSMTPClient("smtp://foo:bar@my-server:1234/")
|
|
assert.Equal(t, smtp.StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced")
|
|
assert.Equal(t, smtp.SSL, false, "Implicit TLS should not be enabled")
|
|
|
|
// Should enforce TLS => dialer.SSL = true
|
|
smtp = setupSMTPClient("smtps://foo:bar@my-server:1234/")
|
|
assert.Equal(t, smtp.SSL, true, "Implicit TLS should be enabled")
|
|
|
|
// Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false
|
|
smtp = setupSMTPClient("smtp://foo:bar@my-server:1234/?disable_starttls=true")
|
|
assert.Equal(t, smtp.StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced")
|
|
assert.Equal(t, smtp.SSL, false, "Implicit TLS should not be enabled")
|
|
|
|
// Test cert based SMTP client auth
|
|
clientCert, clientKey, err := generateTestClientCert(t)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { _ = os.Remove(clientCert.Name()) })
|
|
t.Cleanup(func() { _ = os.Remove(clientKey.Name()) })
|
|
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPClientCertPath, clientCert.Name())
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPClientKeyPath, clientKey.Name())
|
|
|
|
clientPEM, err := tls.LoadX509KeyPair(clientCert.Name(), clientKey.Name())
|
|
require.NoError(t, err)
|
|
|
|
smtpWithCert := setupSMTPClient("smtps://subdomain.my-server:1234/?server_name=my-server")
|
|
assert.Equal(t, smtpWithCert.SSL, true, "Implicit TLS should be enabled")
|
|
assert.Equal(t, smtpWithCert.Host, "subdomain.my-server", "SMTP Dialer host should match")
|
|
assert.Equal(t, smtpWithCert.TLSConfig.ServerName, "my-server", "TLS config server name should match")
|
|
assert.Equal(t, smtpWithCert.TLSConfig.ServerName, "my-server", "TLS config server name should match")
|
|
assert.Contains(t, smtpWithCert.TLSConfig.Certificates, clientPEM, "TLS config should contain client pem")
|
|
|
|
// error case: invalid client key
|
|
require.NoError(t, conf.Set(ctx, config.ViperKeyCourierSMTPClientKeyPath, clientCert.Name())) // mixup client key and client cert
|
|
smtpWithCert = setupSMTPClient("smtps://subdomain.my-server:1234/?server_name=my-server")
|
|
assert.Equal(t, len(smtpWithCert.TLSConfig.Certificates), 0, "TLS config certificates should be empty")
|
|
}
|
|
|
|
func TestQueueEmail(t *testing.T) {
|
|
if testing.Short() {
|
|
t.SkipNow()
|
|
}
|
|
|
|
smtp, api, err := x.RunTestSMTP()
|
|
require.NoError(t, err)
|
|
t.Logf("SMTP URL: %s", smtp)
|
|
t.Logf("API URL: %s", api)
|
|
|
|
ctx := context.Background()
|
|
|
|
conf, reg := internal.NewRegistryDefaultWithDSN(t, "")
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPURL, smtp)
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh")
|
|
reg.Logger().Level = logrus.TraceLevel
|
|
|
|
c, err := reg.Courier(ctx)
|
|
require.NoError(t, err)
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
_, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{
|
|
To: "invalid-email",
|
|
Subject: "test-subject-1",
|
|
Body: "test-body-1",
|
|
}))
|
|
require.Error(t, err)
|
|
|
|
id, err := c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{
|
|
To: "test-recipient-1@example.org",
|
|
Subject: "test-subject-1",
|
|
Body: "test-body-1",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, uuid.Nil, id)
|
|
|
|
id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{
|
|
To: "test-recipient-2@example.org",
|
|
Subject: "test-subject-2",
|
|
Body: "test-body-2",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, uuid.Nil, id)
|
|
|
|
// The third email contains a sender name and custom headers
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPFromName, "Bob")
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPHeaders+".test-stub-header1", "foo")
|
|
conf.MustSet(ctx, config.ViperKeyCourierSMTPHeaders+".test-stub-header2", "bar")
|
|
customerHeaders := conf.CourierSMTPHeaders(ctx)
|
|
require.Len(t, customerHeaders, 2)
|
|
|
|
id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{
|
|
To: "test-recipient-3@example.org",
|
|
Subject: "test-subject-3",
|
|
Body: "test-body-3",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.NotEqual(t, uuid.Nil, id)
|
|
|
|
go func() {
|
|
require.NoError(t, c.Work(ctx))
|
|
}()
|
|
|
|
var body []byte
|
|
for k := 0; k < 30; k++ {
|
|
time.Sleep(time.Second)
|
|
err = func() error {
|
|
res, err := http.Get(api + "/api/v2/messages")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defer func() { _ = res.Body.Close() }()
|
|
body, err = io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if http.StatusOK != res.StatusCode {
|
|
return errors.Errorf("expected status code 200 but got %d with body: %s", res.StatusCode, body)
|
|
}
|
|
|
|
if total := gjson.GetBytes(body, "total").Int(); total != 3 {
|
|
return errors.Errorf("expected to have delivered exactly 3 messages but got count %d with body: %s", total, body)
|
|
}
|
|
|
|
return nil
|
|
}()
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
require.NoError(t, err)
|
|
|
|
for k := 1; k <= 3; k++ {
|
|
assert.Contains(t, string(body), fmt.Sprintf("test-subject-%d", k))
|
|
assert.Contains(t, string(body), fmt.Sprintf("test-body-%d", k))
|
|
assert.Contains(t, string(body), fmt.Sprintf("test-recipient-%d@example.org", k))
|
|
assert.Contains(t, string(body), "test-stub@ory.sh")
|
|
}
|
|
|
|
// Assertion for the third email with sender name and headers
|
|
assert.Contains(t, string(body), "Bob")
|
|
assert.Contains(t, string(body), `"test-stub-header1":["foo"]`)
|
|
assert.Contains(t, string(body), `"test-stub-header2":["bar"]`)
|
|
}
|
|
|
|
func generateTestClientCert(t *testing.T) (clientCert *os.File, clientKey *os.File, err error) {
|
|
hostName := flag.String("host", "127.0.0.1", "Hostname to certify")
|
|
priv, err := rsa.GenerateKey(rand.Reader, 1024) // #nosec G403 -- test code
|
|
require.NoError(t, err)
|
|
now := time.Now()
|
|
certTemplate := x509.Certificate{
|
|
SerialNumber: big.NewInt(1234),
|
|
Subject: pkix.Name{
|
|
CommonName: *hostName,
|
|
Organization: []string{"myorg"},
|
|
},
|
|
NotBefore: now.Add(-300 * time.Second),
|
|
NotAfter: now.Add(24 * time.Hour),
|
|
SubjectKeyId: []byte{1, 2, 3, 4},
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
}
|
|
cert, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
|
|
require.NoError(t, err)
|
|
clientCert, err = os.CreateTemp("./test", "testCert")
|
|
require.NoError(t, err)
|
|
defer func() { _ = clientCert.Close() }()
|
|
|
|
require.NoError(t, pem.Encode(clientCert, &pem.Block{Type: "CERTIFICATE", Bytes: cert}))
|
|
|
|
clientKey, err = os.CreateTemp("./test", "testKey")
|
|
require.NoError(t, err)
|
|
defer func() { _ = clientKey.Close() }()
|
|
require.NoError(t, pem.Encode(clientKey, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}))
|
|
|
|
return clientCert, clientKey, nil
|
|
}
|