mirror of https://github.com/ory/kratos
feat: speed up OIDC login+registration handling
GitOrigin-RevId: bd516d8f488b385e2968d56509123cde78d0f642
This commit is contained in:
parent
e7045c55a3
commit
6bfbaf5791
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc-gen-go v1.36.10
|
||||
// protoc (unknown)
|
||||
// source: oidc/v1/state.proto
|
||||
|
||||
|
|
@ -9,6 +9,7 @@ package oidcv1
|
|||
import (
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
|
|
@ -21,24 +22,74 @@ const (
|
|||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type State struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
type FlowKind int32
|
||||
|
||||
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
|
||||
SessionTokenExchangeCodeSha512 []byte `protobuf:"bytes,2,opt,name=session_token_exchange_code_sha512,json=sessionTokenExchangeCodeSha512,proto3" json:"session_token_exchange_code_sha512,omitempty"`
|
||||
ProviderId string `protobuf:"bytes,3,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"`
|
||||
PkceVerifier string `protobuf:"bytes,4,opt,name=pkce_verifier,json=pkceVerifier,proto3" json:"pkce_verifier,omitempty"`
|
||||
const (
|
||||
FlowKind_FLOW_KIND_UNSPECIFIED FlowKind = 0
|
||||
FlowKind_FLOW_KIND_LOGIN FlowKind = 1
|
||||
FlowKind_FLOW_KIND_REGISTRATION FlowKind = 2
|
||||
FlowKind_FLOW_KIND_SETTINGS FlowKind = 3
|
||||
)
|
||||
|
||||
// Enum value maps for FlowKind.
|
||||
var (
|
||||
FlowKind_name = map[int32]string{
|
||||
0: "FLOW_KIND_UNSPECIFIED",
|
||||
1: "FLOW_KIND_LOGIN",
|
||||
2: "FLOW_KIND_REGISTRATION",
|
||||
3: "FLOW_KIND_SETTINGS",
|
||||
}
|
||||
FlowKind_value = map[string]int32{
|
||||
"FLOW_KIND_UNSPECIFIED": 0,
|
||||
"FLOW_KIND_LOGIN": 1,
|
||||
"FLOW_KIND_REGISTRATION": 2,
|
||||
"FLOW_KIND_SETTINGS": 3,
|
||||
}
|
||||
)
|
||||
|
||||
func (x FlowKind) Enum() *FlowKind {
|
||||
p := new(FlowKind)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
|
||||
func (x FlowKind) String() string {
|
||||
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
|
||||
}
|
||||
|
||||
func (FlowKind) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_oidc_v1_state_proto_enumTypes[0].Descriptor()
|
||||
}
|
||||
|
||||
func (FlowKind) Type() protoreflect.EnumType {
|
||||
return &file_oidc_v1_state_proto_enumTypes[0]
|
||||
}
|
||||
|
||||
func (x FlowKind) Number() protoreflect.EnumNumber {
|
||||
return protoreflect.EnumNumber(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use FlowKind.Descriptor instead.
|
||||
func (FlowKind) EnumDescriptor() ([]byte, []int) {
|
||||
return file_oidc_v1_state_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
type State struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
|
||||
SessionTokenExchangeCodeSha512 []byte `protobuf:"bytes,2,opt,name=session_token_exchange_code_sha512,json=sessionTokenExchangeCodeSha512,proto3" json:"session_token_exchange_code_sha512,omitempty"`
|
||||
ProviderId string `protobuf:"bytes,3,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"`
|
||||
PkceVerifier string `protobuf:"bytes,4,opt,name=pkce_verifier,json=pkceVerifier,proto3" json:"pkce_verifier,omitempty"`
|
||||
FlowKind FlowKind `protobuf:"varint,5,opt,name=flow_kind,json=flowKind,proto3,enum=oidc.v1.FlowKind" json:"flow_kind,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *State) Reset() {
|
||||
*x = State{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_oidc_v1_state_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_oidc_v1_state_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *State) String() string {
|
||||
|
|
@ -49,7 +100,7 @@ func (*State) ProtoMessage() {}
|
|||
|
||||
func (x *State) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_oidc_v1_state_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
|
|
@ -92,55 +143,58 @@ func (x *State) GetPkceVerifier() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (x *State) GetFlowKind() FlowKind {
|
||||
if x != nil {
|
||||
return x.FlowKind
|
||||
}
|
||||
return FlowKind_FLOW_KIND_UNSPECIFIED
|
||||
}
|
||||
|
||||
var File_oidc_v1_state_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_oidc_v1_state_proto_rawDesc = []byte{
|
||||
0x0a, 0x13, 0x6f, 0x69, 0x64, 0x63, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e,
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, 0x69, 0x64, 0x63, 0x2e, 0x76, 0x31, 0x22, 0xb2,
|
||||
0x01, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77,
|
||||
0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49,
|
||||
0x64, 0x12, 0x4a, 0x0a, 0x22, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b,
|
||||
0x65, 0x6e, 0x5f, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65,
|
||||
0x5f, 0x73, 0x68, 0x61, 0x35, 0x31, 0x32, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x1e, 0x73,
|
||||
0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x78, 0x63, 0x68, 0x61,
|
||||
0x6e, 0x67, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x53, 0x68, 0x61, 0x35, 0x31, 0x32, 0x12, 0x1f, 0x0a,
|
||||
0x0b, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x0a, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x49, 0x64, 0x12, 0x23,
|
||||
0x0a, 0x0d, 0x70, 0x6b, 0x63, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18,
|
||||
0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x6b, 0x63, 0x65, 0x56, 0x65, 0x72, 0x69, 0x66,
|
||||
0x69, 0x65, 0x72, 0x42, 0x7c, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x69, 0x64, 0x63, 0x2e,
|
||||
0x76, 0x31, 0x42, 0x0a, 0x53, 0x74, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01,
|
||||
0x5a, 0x24, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6f, 0x72, 0x79,
|
||||
0x2f, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0x2f, 0x76, 0x31, 0x3b,
|
||||
0x6f, 0x69, 0x64, 0x63, 0x76, 0x31, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f,
|
||||
0x69, 0x64, 0x63, 0x2e, 0x56, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x69, 0x64, 0x63, 0x5c, 0x56, 0x31,
|
||||
0xe2, 0x02, 0x13, 0x4f, 0x69, 0x64, 0x63, 0x5c, 0x56, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65,
|
||||
0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x08, 0x4f, 0x69, 0x64, 0x63, 0x3a, 0x3a, 0x56,
|
||||
0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
const file_oidc_v1_state_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x13oidc/v1/state.proto\x12\aoidc.v1\"\xe2\x01\n" +
|
||||
"\x05State\x12\x17\n" +
|
||||
"\aflow_id\x18\x01 \x01(\fR\x06flowId\x12J\n" +
|
||||
"\"session_token_exchange_code_sha512\x18\x02 \x01(\fR\x1esessionTokenExchangeCodeSha512\x12\x1f\n" +
|
||||
"\vprovider_id\x18\x03 \x01(\tR\n" +
|
||||
"providerId\x12#\n" +
|
||||
"\rpkce_verifier\x18\x04 \x01(\tR\fpkceVerifier\x12.\n" +
|
||||
"\tflow_kind\x18\x05 \x01(\x0e2\x11.oidc.v1.FlowKindR\bflowKind*n\n" +
|
||||
"\bFlowKind\x12\x19\n" +
|
||||
"\x15FLOW_KIND_UNSPECIFIED\x10\x00\x12\x13\n" +
|
||||
"\x0fFLOW_KIND_LOGIN\x10\x01\x12\x1a\n" +
|
||||
"\x16FLOW_KIND_REGISTRATION\x10\x02\x12\x16\n" +
|
||||
"\x12FLOW_KIND_SETTINGS\x10\x03B|\n" +
|
||||
"\vcom.oidc.v1B\n" +
|
||||
"StateProtoP\x01Z$github.com/ory/kratos/oidc/v1;oidcv1\xa2\x02\x03OXX\xaa\x02\aOidc.V1\xca\x02\aOidc\\V1\xe2\x02\x13Oidc\\V1\\GPBMetadata\xea\x02\bOidc::V1b\x06proto3"
|
||||
|
||||
var (
|
||||
file_oidc_v1_state_proto_rawDescOnce sync.Once
|
||||
file_oidc_v1_state_proto_rawDescData = file_oidc_v1_state_proto_rawDesc
|
||||
file_oidc_v1_state_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_oidc_v1_state_proto_rawDescGZIP() []byte {
|
||||
file_oidc_v1_state_proto_rawDescOnce.Do(func() {
|
||||
file_oidc_v1_state_proto_rawDescData = protoimpl.X.CompressGZIP(file_oidc_v1_state_proto_rawDescData)
|
||||
file_oidc_v1_state_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_oidc_v1_state_proto_rawDesc), len(file_oidc_v1_state_proto_rawDesc)))
|
||||
})
|
||||
return file_oidc_v1_state_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_oidc_v1_state_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_oidc_v1_state_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_oidc_v1_state_proto_goTypes = []any{
|
||||
(*State)(nil), // 0: oidc.v1.State
|
||||
(FlowKind)(0), // 0: oidc.v1.FlowKind
|
||||
(*State)(nil), // 1: oidc.v1.State
|
||||
}
|
||||
var file_oidc_v1_state_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
0, // 0: oidc.v1.State.flow_kind:type_name -> oidc.v1.FlowKind
|
||||
1, // [1:1] is the sub-list for method output_type
|
||||
1, // [1:1] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_oidc_v1_state_proto_init() }
|
||||
|
|
@ -148,36 +202,22 @@ func file_oidc_v1_state_proto_init() {
|
|||
if File_oidc_v1_state_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_oidc_v1_state_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*State); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_oidc_v1_state_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_oidc_v1_state_proto_rawDesc), len(file_oidc_v1_state_proto_rawDesc)),
|
||||
NumEnums: 1,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_oidc_v1_state_proto_goTypes,
|
||||
DependencyIndexes: file_oidc_v1_state_proto_depIdxs,
|
||||
EnumInfos: file_oidc_v1_state_proto_enumTypes,
|
||||
MessageInfos: file_oidc_v1_state_proto_msgTypes,
|
||||
}.Build()
|
||||
File_oidc_v1_state_proto = out.File
|
||||
file_oidc_v1_state_proto_rawDesc = nil
|
||||
file_oidc_v1_state_proto_goTypes = nil
|
||||
file_oidc_v1_state_proto_depIdxs = nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,17 @@ syntax = "proto3";
|
|||
|
||||
package oidc.v1;
|
||||
|
||||
enum FlowKind {
|
||||
FLOW_KIND_UNSPECIFIED = 0;
|
||||
FLOW_KIND_LOGIN = 1;
|
||||
FLOW_KIND_REGISTRATION = 2;
|
||||
FLOW_KIND_SETTINGS = 3;
|
||||
}
|
||||
|
||||
message State {
|
||||
bytes flow_id = 1;
|
||||
bytes session_token_exchange_code_sha512 = 2;
|
||||
string provider_id = 3;
|
||||
string pkce_verifier = 4;
|
||||
FlowKind flow_kind = 5;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
|
||||
"github.com/ory/kratos/internal"
|
||||
"github.com/ory/kratos/selfservice/flow/login"
|
||||
"github.com/ory/kratos/selfservice/flow/registration"
|
||||
"github.com/ory/kratos/selfservice/strategy/oidc"
|
||||
"github.com/ory/kratos/x"
|
||||
)
|
||||
|
|
@ -54,12 +57,17 @@ func TestPKCESupport(t *testing.T) {
|
|||
} {
|
||||
provider := oidc.NewProviderGenericOIDC(tc.c, reg)
|
||||
|
||||
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, x.NewUUID())
|
||||
flow := &login.Flow{
|
||||
ID: x.NewUUID(),
|
||||
}
|
||||
|
||||
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, flow)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, stateParam)
|
||||
|
||||
state, err := oidc.DecryptState(context.Background(), reg.Cipher(context.Background()), stateParam)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, oidcv1.FlowKind_FLOW_KIND_LOGIN, state.FlowKind)
|
||||
|
||||
if tc.pkce {
|
||||
require.NotEmpty(t, pkce)
|
||||
|
|
@ -76,7 +84,7 @@ func TestPKCESupport(t *testing.T) {
|
|||
oidc.NewProviderX(&oidc.Configuration{IssuerURL: supported.URL, PKCE: "never"}, reg),
|
||||
oidc.NewProviderX(&oidc.Configuration{IssuerURL: supported.URL, PKCE: "auto"}, reg),
|
||||
} {
|
||||
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, x.NewUUID())
|
||||
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, ®istration.Flow{ID: x.NewUUID()})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, stateParam)
|
||||
assert.Empty(t, pkce)
|
||||
|
|
@ -84,6 +92,7 @@ func TestPKCESupport(t *testing.T) {
|
|||
state, err := oidc.DecryptState(context.Background(), reg.Cipher(context.Background()), stateParam)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, oidc.PKCEVerifier(state))
|
||||
assert.Equal(t, oidcv1.FlowKind_FLOW_KIND_REGISTRATION, state.FlowKind)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,13 +8,16 @@ import (
|
|||
"crypto/sha512"
|
||||
"crypto/subtle"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ory/herodot"
|
||||
"github.com/ory/kratos/cipher"
|
||||
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
|
||||
"github.com/ory/kratos/selfservice/flow"
|
||||
"github.com/ory/kratos/selfservice/flow/login"
|
||||
"github.com/ory/kratos/selfservice/flow/registration"
|
||||
"github.com/ory/kratos/selfservice/flow/settings"
|
||||
"github.com/ory/kratos/x"
|
||||
)
|
||||
|
||||
|
|
@ -38,14 +41,26 @@ func DecryptState(ctx context.Context, c cipher.Cipher, ciphertext string) (*oid
|
|||
return &state, nil
|
||||
}
|
||||
|
||||
func (s *Strategy) GenerateState(ctx context.Context, p Provider, flowID uuid.UUID) (stateParam string, pkce []oauth2.AuthCodeOption, err error) {
|
||||
func (s *Strategy) GenerateState(ctx context.Context, p Provider, flow flow.Flow) (stateParam string, pkce []oauth2.AuthCodeOption, err error) {
|
||||
state := oidcv1.State{
|
||||
FlowId: flowID.Bytes(),
|
||||
FlowId: flow.GetID().Bytes(),
|
||||
SessionTokenExchangeCodeSha512: x.NewUUID().Bytes(),
|
||||
ProviderId: p.Config().ID,
|
||||
PkceVerifier: maybePKCE(ctx, s.d, p),
|
||||
}
|
||||
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flowID); hasCode {
|
||||
|
||||
switch flow.(type) {
|
||||
case *login.Flow:
|
||||
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_LOGIN
|
||||
case *registration.Flow:
|
||||
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_REGISTRATION
|
||||
case *settings.Flow:
|
||||
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_SETTINGS
|
||||
default:
|
||||
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_UNSPECIFIED
|
||||
}
|
||||
|
||||
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flow.GetID()); hasCode {
|
||||
sum := sha512.Sum512([]byte(code.InitCode))
|
||||
state.SessionTokenExchangeCodeSha512 = sum[:]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ory/kratos/cipher"
|
||||
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
|
||||
"github.com/ory/kratos/internal"
|
||||
"github.com/ory/kratos/selfservice/flow/registration"
|
||||
"github.com/ory/kratos/selfservice/strategy/oidc"
|
||||
"github.com/ory/kratos/x"
|
||||
)
|
||||
|
|
@ -25,18 +27,21 @@ func TestGenerateState(t *testing.T) {
|
|||
_, ok := ciph.(*cipher.Noop)
|
||||
require.False(t, ok)
|
||||
|
||||
flowID := x.NewUUID()
|
||||
flow := ®istration.Flow{
|
||||
ID: x.NewUUID(),
|
||||
}
|
||||
|
||||
stateParam, pkce, err := strat.GenerateState(ctx, &testProvider{}, flowID)
|
||||
stateParam, pkce, err := strat.GenerateState(ctx, &testProvider{}, flow)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, stateParam)
|
||||
assert.Empty(t, pkce)
|
||||
|
||||
state, err := oidc.DecryptState(ctx, ciph, stateParam)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, flowID.Bytes(), state.FlowId)
|
||||
assert.Equal(t, flow.GetID().Bytes(), state.FlowId)
|
||||
assert.Empty(t, oidc.PKCEVerifier(state))
|
||||
assert.Equal(t, "test-provider", state.ProviderId)
|
||||
assert.Equal(t, oidcv1.FlowKind_FLOW_KIND_REGISTRATION, state.FlowKind)
|
||||
}
|
||||
|
||||
type testProvider struct{}
|
||||
|
|
|
|||
|
|
@ -297,39 +297,47 @@ func (s *Strategy) ID() identity.CredentialsType {
|
|||
return s.credType
|
||||
}
|
||||
|
||||
func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.UUID) (flow.Flow, error) {
|
||||
func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.UUID, kind oidcv1.FlowKind) (f flow.Flow, err error) {
|
||||
if rid.IsNil() {
|
||||
return nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The session cookie contains invalid values and the flow could not be executed. Please try again."))
|
||||
}
|
||||
|
||||
if ar, err := s.d.RegistrationFlowPersister().GetRegistrationFlow(ctx, rid); err == nil {
|
||||
if err := ar.Valid(); err != nil {
|
||||
return ar, err
|
||||
}
|
||||
return ar, nil
|
||||
}
|
||||
switch kind {
|
||||
|
||||
if ar, err := s.d.LoginFlowPersister().GetLoginFlow(ctx, rid); err == nil {
|
||||
if err := ar.Valid(); err != nil {
|
||||
return ar, err
|
||||
case oidcv1.FlowKind_FLOW_KIND_LOGIN:
|
||||
lf, err := s.d.LoginFlowPersister().GetLoginFlow(ctx, rid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ar, nil
|
||||
}
|
||||
return lf, lf.Valid()
|
||||
|
||||
ar, err := s.d.SettingsFlowPersister().GetSettingsFlow(ctx, rid)
|
||||
if err == nil {
|
||||
case oidcv1.FlowKind_FLOW_KIND_REGISTRATION:
|
||||
rf, err := s.d.RegistrationFlowPersister().GetRegistrationFlow(ctx, rid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rf, rf.Valid()
|
||||
|
||||
case oidcv1.FlowKind_FLOW_KIND_SETTINGS:
|
||||
sf, err := s.d.SettingsFlowPersister().GetSettingsFlow(ctx, rid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sess, err := s.d.SessionManager().FetchFromRequest(ctx, r)
|
||||
if err != nil {
|
||||
return ar, err
|
||||
return sf, err
|
||||
}
|
||||
return sf, sf.Valid(sess)
|
||||
|
||||
if err := ar.Valid(sess); err != nil {
|
||||
return ar, err
|
||||
}
|
||||
return ar, nil
|
||||
}
|
||||
|
||||
return ar, err // this must return the error
|
||||
// fallback to the old behavior for backwards compatibility
|
||||
for _, kind := range []oidcv1.FlowKind{oidcv1.FlowKind_FLOW_KIND_LOGIN, oidcv1.FlowKind_FLOW_KIND_REGISTRATION, oidcv1.FlowKind_FLOW_KIND_SETTINGS} {
|
||||
if f, err = s.validateFlow(ctx, r, rid, kind); f != nil {
|
||||
return f, err
|
||||
}
|
||||
}
|
||||
return f, err
|
||||
}
|
||||
|
||||
func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flow.Flow, *oidcv1.State, *AuthCodeContainer, error) {
|
||||
|
|
@ -362,7 +370,7 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flo
|
|||
return nil, nil, nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow: provider could not be retrieved from state nor URL.`))
|
||||
}
|
||||
|
||||
f, err := s.validateFlow(r.Context(), r, uuid.FromBytesOrNil(state.FlowId))
|
||||
f, err := s.validateFlow(r.Context(), r, uuid.FromBytesOrNil(state.FlowId), state.FlowKind)
|
||||
if err != nil {
|
||||
return nil, state, nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import (
|
|||
|
||||
"github.com/ory/herodot"
|
||||
"github.com/ory/kratos/continuity"
|
||||
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/selfservice/flow"
|
||||
"github.com/ory/kratos/selfservice/flow/login"
|
||||
|
|
@ -302,7 +303,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
|
|||
return nil, s.HandleError(ctx, w, r, f, pid, nil, err)
|
||||
}
|
||||
|
||||
req, err := s.validateFlow(ctx, r, f.ID)
|
||||
req, err := s.validateFlow(ctx, r, f.ID, oidcv1.FlowKind_FLOW_KIND_LOGIN)
|
||||
if err != nil {
|
||||
return nil, s.HandleError(ctx, w, r, f, pid, nil, err)
|
||||
}
|
||||
|
|
@ -330,7 +331,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
|
|||
return nil, errors.WithStack(flow.ErrCompletedByStrategy)
|
||||
}
|
||||
|
||||
state, pkce, err := s.GenerateState(ctx, provider, f.ID)
|
||||
state, pkce, err := s.GenerateState(ctx, provider, f)
|
||||
if err != nil {
|
||||
return nil, s.HandleError(ctx, w, r, f, pid, nil, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/ory/herodot"
|
||||
"github.com/ory/kratos/continuity"
|
||||
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/selfservice/flow"
|
||||
"github.com/ory/kratos/selfservice/flow/login"
|
||||
|
|
@ -201,7 +202,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
|
|||
return s.HandleError(ctx, w, r, f, pid, nil, err)
|
||||
}
|
||||
|
||||
req, err := s.validateFlow(ctx, r, f.ID)
|
||||
req, err := s.validateFlow(ctx, r, f.ID, oidcv1.FlowKind_FLOW_KIND_REGISTRATION)
|
||||
if err != nil {
|
||||
return s.HandleError(ctx, w, r, f, pid, nil, err)
|
||||
}
|
||||
|
|
@ -231,7 +232,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
|
|||
return errors.WithStack(flow.ErrCompletedByStrategy)
|
||||
}
|
||||
|
||||
state, pkce, err := s.GenerateState(ctx, provider, f.ID)
|
||||
state, pkce, err := s.GenerateState(ctx, provider, f)
|
||||
if err != nil {
|
||||
return s.HandleError(ctx, w, r, f, pid, nil, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/ory/x/stringsx"
|
||||
|
||||
"github.com/ory/kratos/continuity"
|
||||
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
|
||||
"github.com/ory/kratos/identity"
|
||||
"github.com/ory/kratos/selfservice/flow"
|
||||
"github.com/ory/kratos/selfservice/flow/settings"
|
||||
|
|
@ -365,12 +366,12 @@ func (s *Strategy) initLinkProvider(ctx context.Context, w http.ResponseWriter,
|
|||
return s.handleSettingsError(ctx, w, r, ctxUpdate, p, err)
|
||||
}
|
||||
|
||||
req, err := s.validateFlow(ctx, r, ctxUpdate.Flow.ID)
|
||||
req, err := s.validateFlow(ctx, r, ctxUpdate.Flow.ID, oidcv1.FlowKind_FLOW_KIND_SETTINGS)
|
||||
if err != nil {
|
||||
return s.handleSettingsError(ctx, w, r, ctxUpdate, p, err)
|
||||
}
|
||||
|
||||
state, pkce, err := s.GenerateState(ctx, provider, ctxUpdate.Flow.ID)
|
||||
state, pkce, err := s.GenerateState(ctx, provider, ctxUpdate.Flow)
|
||||
if err != nil {
|
||||
return s.handleSettingsError(ctx, w, r, ctxUpdate, p, err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue