feat: speed up OIDC login+registration handling

GitOrigin-RevId: bd516d8f488b385e2968d56509123cde78d0f642
This commit is contained in:
Arne Luenser 2025-11-26 17:45:51 +01:00 committed by ory-bot
parent e7045c55a3
commit 6bfbaf5791
9 changed files with 188 additions and 100 deletions

View File

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc-gen-go v1.36.10
// protoc (unknown)
// source: oidc/v1/state.proto
@ -9,6 +9,7 @@ package oidcv1
import (
reflect "reflect"
sync "sync"
unsafe "unsafe"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
@ -21,24 +22,74 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type State struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
type FlowKind int32
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
SessionTokenExchangeCodeSha512 []byte `protobuf:"bytes,2,opt,name=session_token_exchange_code_sha512,json=sessionTokenExchangeCodeSha512,proto3" json:"session_token_exchange_code_sha512,omitempty"`
ProviderId string `protobuf:"bytes,3,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"`
PkceVerifier string `protobuf:"bytes,4,opt,name=pkce_verifier,json=pkceVerifier,proto3" json:"pkce_verifier,omitempty"`
const (
FlowKind_FLOW_KIND_UNSPECIFIED FlowKind = 0
FlowKind_FLOW_KIND_LOGIN FlowKind = 1
FlowKind_FLOW_KIND_REGISTRATION FlowKind = 2
FlowKind_FLOW_KIND_SETTINGS FlowKind = 3
)
// Enum value maps for FlowKind.
var (
FlowKind_name = map[int32]string{
0: "FLOW_KIND_UNSPECIFIED",
1: "FLOW_KIND_LOGIN",
2: "FLOW_KIND_REGISTRATION",
3: "FLOW_KIND_SETTINGS",
}
FlowKind_value = map[string]int32{
"FLOW_KIND_UNSPECIFIED": 0,
"FLOW_KIND_LOGIN": 1,
"FLOW_KIND_REGISTRATION": 2,
"FLOW_KIND_SETTINGS": 3,
}
)
func (x FlowKind) Enum() *FlowKind {
p := new(FlowKind)
*p = x
return p
}
func (x FlowKind) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (FlowKind) Descriptor() protoreflect.EnumDescriptor {
return file_oidc_v1_state_proto_enumTypes[0].Descriptor()
}
func (FlowKind) Type() protoreflect.EnumType {
return &file_oidc_v1_state_proto_enumTypes[0]
}
func (x FlowKind) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use FlowKind.Descriptor instead.
func (FlowKind) EnumDescriptor() ([]byte, []int) {
return file_oidc_v1_state_proto_rawDescGZIP(), []int{0}
}
type State struct {
state protoimpl.MessageState `protogen:"open.v1"`
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
SessionTokenExchangeCodeSha512 []byte `protobuf:"bytes,2,opt,name=session_token_exchange_code_sha512,json=sessionTokenExchangeCodeSha512,proto3" json:"session_token_exchange_code_sha512,omitempty"`
ProviderId string `protobuf:"bytes,3,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"`
PkceVerifier string `protobuf:"bytes,4,opt,name=pkce_verifier,json=pkceVerifier,proto3" json:"pkce_verifier,omitempty"`
FlowKind FlowKind `protobuf:"varint,5,opt,name=flow_kind,json=flowKind,proto3,enum=oidc.v1.FlowKind" json:"flow_kind,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *State) Reset() {
*x = State{}
if protoimpl.UnsafeEnabled {
mi := &file_oidc_v1_state_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_oidc_v1_state_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *State) String() string {
@ -49,7 +100,7 @@ func (*State) ProtoMessage() {}
func (x *State) ProtoReflect() protoreflect.Message {
mi := &file_oidc_v1_state_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@ -92,55 +143,58 @@ func (x *State) GetPkceVerifier() string {
return ""
}
func (x *State) GetFlowKind() FlowKind {
if x != nil {
return x.FlowKind
}
return FlowKind_FLOW_KIND_UNSPECIFIED
}
var File_oidc_v1_state_proto protoreflect.FileDescriptor
var file_oidc_v1_state_proto_rawDesc = []byte{
0x0a, 0x13, 0x6f, 0x69, 0x64, 0x63, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2e,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, 0x69, 0x64, 0x63, 0x2e, 0x76, 0x31, 0x22, 0xb2,
0x01, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77,
0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49,
0x64, 0x12, 0x4a, 0x0a, 0x22, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x6f, 0x6b,
0x65, 0x6e, 0x5f, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x5f, 0x63, 0x6f, 0x64, 0x65,
0x5f, 0x73, 0x68, 0x61, 0x35, 0x31, 0x32, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x1e, 0x73,
0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x78, 0x63, 0x68, 0x61,
0x6e, 0x67, 0x65, 0x43, 0x6f, 0x64, 0x65, 0x53, 0x68, 0x61, 0x35, 0x31, 0x32, 0x12, 0x1f, 0x0a,
0x0b, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01,
0x28, 0x09, 0x52, 0x0a, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x49, 0x64, 0x12, 0x23,
0x0a, 0x0d, 0x70, 0x6b, 0x63, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18,
0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x6b, 0x63, 0x65, 0x56, 0x65, 0x72, 0x69, 0x66,
0x69, 0x65, 0x72, 0x42, 0x7c, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x69, 0x64, 0x63, 0x2e,
0x76, 0x31, 0x42, 0x0a, 0x53, 0x74, 0x61, 0x74, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01,
0x5a, 0x24, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6f, 0x72, 0x79,
0x2f, 0x6b, 0x72, 0x61, 0x74, 0x6f, 0x73, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0x2f, 0x76, 0x31, 0x3b,
0x6f, 0x69, 0x64, 0x63, 0x76, 0x31, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f,
0x69, 0x64, 0x63, 0x2e, 0x56, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x69, 0x64, 0x63, 0x5c, 0x56, 0x31,
0xe2, 0x02, 0x13, 0x4f, 0x69, 0x64, 0x63, 0x5c, 0x56, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65,
0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x08, 0x4f, 0x69, 0x64, 0x63, 0x3a, 0x3a, 0x56,
0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
const file_oidc_v1_state_proto_rawDesc = "" +
"\n" +
"\x13oidc/v1/state.proto\x12\aoidc.v1\"\xe2\x01\n" +
"\x05State\x12\x17\n" +
"\aflow_id\x18\x01 \x01(\fR\x06flowId\x12J\n" +
"\"session_token_exchange_code_sha512\x18\x02 \x01(\fR\x1esessionTokenExchangeCodeSha512\x12\x1f\n" +
"\vprovider_id\x18\x03 \x01(\tR\n" +
"providerId\x12#\n" +
"\rpkce_verifier\x18\x04 \x01(\tR\fpkceVerifier\x12.\n" +
"\tflow_kind\x18\x05 \x01(\x0e2\x11.oidc.v1.FlowKindR\bflowKind*n\n" +
"\bFlowKind\x12\x19\n" +
"\x15FLOW_KIND_UNSPECIFIED\x10\x00\x12\x13\n" +
"\x0fFLOW_KIND_LOGIN\x10\x01\x12\x1a\n" +
"\x16FLOW_KIND_REGISTRATION\x10\x02\x12\x16\n" +
"\x12FLOW_KIND_SETTINGS\x10\x03B|\n" +
"\vcom.oidc.v1B\n" +
"StateProtoP\x01Z$github.com/ory/kratos/oidc/v1;oidcv1\xa2\x02\x03OXX\xaa\x02\aOidc.V1\xca\x02\aOidc\\V1\xe2\x02\x13Oidc\\V1\\GPBMetadata\xea\x02\bOidc::V1b\x06proto3"
var (
file_oidc_v1_state_proto_rawDescOnce sync.Once
file_oidc_v1_state_proto_rawDescData = file_oidc_v1_state_proto_rawDesc
file_oidc_v1_state_proto_rawDescData []byte
)
func file_oidc_v1_state_proto_rawDescGZIP() []byte {
file_oidc_v1_state_proto_rawDescOnce.Do(func() {
file_oidc_v1_state_proto_rawDescData = protoimpl.X.CompressGZIP(file_oidc_v1_state_proto_rawDescData)
file_oidc_v1_state_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_oidc_v1_state_proto_rawDesc), len(file_oidc_v1_state_proto_rawDesc)))
})
return file_oidc_v1_state_proto_rawDescData
}
var file_oidc_v1_state_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_oidc_v1_state_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_oidc_v1_state_proto_goTypes = []any{
(*State)(nil), // 0: oidc.v1.State
(FlowKind)(0), // 0: oidc.v1.FlowKind
(*State)(nil), // 1: oidc.v1.State
}
var file_oidc_v1_state_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
0, // 0: oidc.v1.State.flow_kind:type_name -> oidc.v1.FlowKind
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_oidc_v1_state_proto_init() }
@ -148,36 +202,22 @@ func file_oidc_v1_state_proto_init() {
if File_oidc_v1_state_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_oidc_v1_state_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*State); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_oidc_v1_state_proto_rawDesc,
NumEnums: 0,
RawDescriptor: unsafe.Slice(unsafe.StringData(file_oidc_v1_state_proto_rawDesc), len(file_oidc_v1_state_proto_rawDesc)),
NumEnums: 1,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_oidc_v1_state_proto_goTypes,
DependencyIndexes: file_oidc_v1_state_proto_depIdxs,
EnumInfos: file_oidc_v1_state_proto_enumTypes,
MessageInfos: file_oidc_v1_state_proto_msgTypes,
}.Build()
File_oidc_v1_state_proto = out.File
file_oidc_v1_state_proto_rawDesc = nil
file_oidc_v1_state_proto_goTypes = nil
file_oidc_v1_state_proto_depIdxs = nil
}

View File

@ -2,9 +2,17 @@ syntax = "proto3";
package oidc.v1;
enum FlowKind {
FLOW_KIND_UNSPECIFIED = 0;
FLOW_KIND_LOGIN = 1;
FLOW_KIND_REGISTRATION = 2;
FLOW_KIND_SETTINGS = 3;
}
message State {
bytes flow_id = 1;
bytes session_token_exchange_code_sha512 = 2;
string provider_id = 3;
string pkce_verifier = 4;
FlowKind flow_kind = 5;
}

View File

@ -13,7 +13,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/strategy/oidc"
"github.com/ory/kratos/x"
)
@ -54,12 +57,17 @@ func TestPKCESupport(t *testing.T) {
} {
provider := oidc.NewProviderGenericOIDC(tc.c, reg)
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, x.NewUUID())
flow := &login.Flow{
ID: x.NewUUID(),
}
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, flow)
require.NoError(t, err)
require.NotEmpty(t, stateParam)
state, err := oidc.DecryptState(context.Background(), reg.Cipher(context.Background()), stateParam)
require.NoError(t, err)
assert.Equal(t, oidcv1.FlowKind_FLOW_KIND_LOGIN, state.FlowKind)
if tc.pkce {
require.NotEmpty(t, pkce)
@ -76,7 +84,7 @@ func TestPKCESupport(t *testing.T) {
oidc.NewProviderX(&oidc.Configuration{IssuerURL: supported.URL, PKCE: "never"}, reg),
oidc.NewProviderX(&oidc.Configuration{IssuerURL: supported.URL, PKCE: "auto"}, reg),
} {
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, x.NewUUID())
stateParam, pkce, err := strat.GenerateState(context.Background(), provider, &registration.Flow{ID: x.NewUUID()})
require.NoError(t, err)
require.NotEmpty(t, stateParam)
assert.Empty(t, pkce)
@ -84,6 +92,7 @@ func TestPKCESupport(t *testing.T) {
state, err := oidc.DecryptState(context.Background(), reg.Cipher(context.Background()), stateParam)
require.NoError(t, err)
assert.Empty(t, oidc.PKCEVerifier(state))
assert.Equal(t, oidcv1.FlowKind_FLOW_KIND_REGISTRATION, state.FlowKind)
}
})
}

View File

@ -8,13 +8,16 @@ import (
"crypto/sha512"
"crypto/subtle"
"github.com/gofrs/uuid"
"golang.org/x/oauth2"
"google.golang.org/protobuf/proto"
"github.com/ory/herodot"
"github.com/ory/kratos/cipher"
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/flow/settings"
"github.com/ory/kratos/x"
)
@ -38,14 +41,26 @@ func DecryptState(ctx context.Context, c cipher.Cipher, ciphertext string) (*oid
return &state, nil
}
func (s *Strategy) GenerateState(ctx context.Context, p Provider, flowID uuid.UUID) (stateParam string, pkce []oauth2.AuthCodeOption, err error) {
func (s *Strategy) GenerateState(ctx context.Context, p Provider, flow flow.Flow) (stateParam string, pkce []oauth2.AuthCodeOption, err error) {
state := oidcv1.State{
FlowId: flowID.Bytes(),
FlowId: flow.GetID().Bytes(),
SessionTokenExchangeCodeSha512: x.NewUUID().Bytes(),
ProviderId: p.Config().ID,
PkceVerifier: maybePKCE(ctx, s.d, p),
}
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flowID); hasCode {
switch flow.(type) {
case *login.Flow:
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_LOGIN
case *registration.Flow:
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_REGISTRATION
case *settings.Flow:
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_SETTINGS
default:
state.FlowKind = oidcv1.FlowKind_FLOW_KIND_UNSPECIFIED
}
if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flow.GetID()); hasCode {
sum := sha512.Sum512([]byte(code.InitCode))
state.SessionTokenExchangeCodeSha512 = sum[:]
}

View File

@ -11,7 +11,9 @@ import (
"github.com/stretchr/testify/require"
"github.com/ory/kratos/cipher"
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
"github.com/ory/kratos/internal"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/selfservice/strategy/oidc"
"github.com/ory/kratos/x"
)
@ -25,18 +27,21 @@ func TestGenerateState(t *testing.T) {
_, ok := ciph.(*cipher.Noop)
require.False(t, ok)
flowID := x.NewUUID()
flow := &registration.Flow{
ID: x.NewUUID(),
}
stateParam, pkce, err := strat.GenerateState(ctx, &testProvider{}, flowID)
stateParam, pkce, err := strat.GenerateState(ctx, &testProvider{}, flow)
require.NoError(t, err)
require.NotEmpty(t, stateParam)
assert.Empty(t, pkce)
state, err := oidc.DecryptState(ctx, ciph, stateParam)
require.NoError(t, err)
assert.Equal(t, flowID.Bytes(), state.FlowId)
assert.Equal(t, flow.GetID().Bytes(), state.FlowId)
assert.Empty(t, oidc.PKCEVerifier(state))
assert.Equal(t, "test-provider", state.ProviderId)
assert.Equal(t, oidcv1.FlowKind_FLOW_KIND_REGISTRATION, state.FlowKind)
}
type testProvider struct{}

View File

@ -297,39 +297,47 @@ func (s *Strategy) ID() identity.CredentialsType {
return s.credType
}
func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.UUID) (flow.Flow, error) {
func (s *Strategy) validateFlow(ctx context.Context, r *http.Request, rid uuid.UUID, kind oidcv1.FlowKind) (f flow.Flow, err error) {
if rid.IsNil() {
return nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The session cookie contains invalid values and the flow could not be executed. Please try again."))
}
if ar, err := s.d.RegistrationFlowPersister().GetRegistrationFlow(ctx, rid); err == nil {
if err := ar.Valid(); err != nil {
return ar, err
}
return ar, nil
}
switch kind {
if ar, err := s.d.LoginFlowPersister().GetLoginFlow(ctx, rid); err == nil {
if err := ar.Valid(); err != nil {
return ar, err
case oidcv1.FlowKind_FLOW_KIND_LOGIN:
lf, err := s.d.LoginFlowPersister().GetLoginFlow(ctx, rid)
if err != nil {
return nil, err
}
return ar, nil
}
return lf, lf.Valid()
ar, err := s.d.SettingsFlowPersister().GetSettingsFlow(ctx, rid)
if err == nil {
case oidcv1.FlowKind_FLOW_KIND_REGISTRATION:
rf, err := s.d.RegistrationFlowPersister().GetRegistrationFlow(ctx, rid)
if err != nil {
return nil, err
}
return rf, rf.Valid()
case oidcv1.FlowKind_FLOW_KIND_SETTINGS:
sf, err := s.d.SettingsFlowPersister().GetSettingsFlow(ctx, rid)
if err != nil {
return nil, err
}
sess, err := s.d.SessionManager().FetchFromRequest(ctx, r)
if err != nil {
return ar, err
return sf, err
}
return sf, sf.Valid(sess)
if err := ar.Valid(sess); err != nil {
return ar, err
}
return ar, nil
}
return ar, err // this must return the error
// fallback to the old behavior for backwards compatibility
for _, kind := range []oidcv1.FlowKind{oidcv1.FlowKind_FLOW_KIND_LOGIN, oidcv1.FlowKind_FLOW_KIND_REGISTRATION, oidcv1.FlowKind_FLOW_KIND_SETTINGS} {
if f, err = s.validateFlow(ctx, r, rid, kind); f != nil {
return f, err
}
}
return f, err
}
func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flow.Flow, *oidcv1.State, *AuthCodeContainer, error) {
@ -362,7 +370,7 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request) (flo
return nil, nil, nil, errors.WithStack(herodot.ErrBadRequest.WithReasonf(`Unable to complete OpenID Connect flow: provider could not be retrieved from state nor URL.`))
}
f, err := s.validateFlow(r.Context(), r, uuid.FromBytesOrNil(state.FlowId))
f, err := s.validateFlow(r.Context(), r, uuid.FromBytesOrNil(state.FlowId), state.FlowKind)
if err != nil {
return nil, state, nil, err
}

View File

@ -16,6 +16,7 @@ import (
"github.com/ory/herodot"
"github.com/ory/kratos/continuity"
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
@ -302,7 +303,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
return nil, s.HandleError(ctx, w, r, f, pid, nil, err)
}
req, err := s.validateFlow(ctx, r, f.ID)
req, err := s.validateFlow(ctx, r, f.ID, oidcv1.FlowKind_FLOW_KIND_LOGIN)
if err != nil {
return nil, s.HandleError(ctx, w, r, f, pid, nil, err)
}
@ -330,7 +331,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
return nil, errors.WithStack(flow.ErrCompletedByStrategy)
}
state, pkce, err := s.GenerateState(ctx, provider, f.ID)
state, pkce, err := s.GenerateState(ctx, provider, f)
if err != nil {
return nil, s.HandleError(ctx, w, r, f, pid, nil, err)
}

View File

@ -21,6 +21,7 @@ import (
"github.com/ory/herodot"
"github.com/ory/kratos/continuity"
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
@ -201,7 +202,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
return s.HandleError(ctx, w, r, f, pid, nil, err)
}
req, err := s.validateFlow(ctx, r, f.ID)
req, err := s.validateFlow(ctx, r, f.ID, oidcv1.FlowKind_FLOW_KIND_REGISTRATION)
if err != nil {
return s.HandleError(ctx, w, r, f, pid, nil, err)
}
@ -231,7 +232,7 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
return errors.WithStack(flow.ErrCompletedByStrategy)
}
state, pkce, err := s.GenerateState(ctx, provider, f.ID)
state, pkce, err := s.GenerateState(ctx, provider, f)
if err != nil {
return s.HandleError(ctx, w, r, f, pid, nil, err)
}

View File

@ -25,6 +25,7 @@ import (
"github.com/ory/x/stringsx"
"github.com/ory/kratos/continuity"
oidcv1 "github.com/ory/kratos/gen/oidc/v1"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/settings"
@ -365,12 +366,12 @@ func (s *Strategy) initLinkProvider(ctx context.Context, w http.ResponseWriter,
return s.handleSettingsError(ctx, w, r, ctxUpdate, p, err)
}
req, err := s.validateFlow(ctx, r, ctxUpdate.Flow.ID)
req, err := s.validateFlow(ctx, r, ctxUpdate.Flow.ID, oidcv1.FlowKind_FLOW_KIND_SETTINGS)
if err != nil {
return s.handleSettingsError(ctx, w, r, ctxUpdate, p, err)
}
state, pkce, err := s.GenerateState(ctx, provider, ctxUpdate.Flow.ID)
state, pkce, err := s.GenerateState(ctx, provider, ctxUpdate.Flow)
if err != nil {
return s.handleSettingsError(ctx, w, r, ctxUpdate, p, err)
}