mirror of
https://github.com/matrix-org/dendrite
synced 2024-11-18 07:40:53 +01:00
Return empty list instead of null for new UI-auth sessions (#406)
fixes #399 Signed-off-by: Anant Prakash <anantprakashjsr@gmail.com>
This commit is contained in:
parent
66af311b6a
commit
8a1f3195ca
2 changed files with 46 additions and 6 deletions
|
@ -50,9 +50,35 @@ const (
|
||||||
sessionIDLength = 24
|
sessionIDLength = 24
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// sessionsDict keeps track of completed auth stages for each session.
|
||||||
|
type sessionsDict struct {
|
||||||
|
sessions map[string][]authtypes.LoginType
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCompletedStages returns the completed stages for a session.
|
||||||
|
func (d sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType {
|
||||||
|
if completedStages, ok := d.sessions[sessionID]; ok {
|
||||||
|
return completedStages
|
||||||
|
}
|
||||||
|
// Ensure that a empty slice is returned and not nil. See #399.
|
||||||
|
return make([]authtypes.LoginType, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AAddCompletedStage records that a session has completed an auth stage.
|
||||||
|
func (d *sessionsDict) AddCompletedStage(sessionID string, stage authtypes.LoginType) {
|
||||||
|
d.sessions[sessionID] = append(d.GetCompletedStages(sessionID), stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSessionsDict() *sessionsDict {
|
||||||
|
return &sessionsDict{
|
||||||
|
sessions: make(map[string][]authtypes.LoginType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
|
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
|
||||||
sessions = make(map[string][]authtypes.LoginType) // Sessions and completed flow stages
|
// sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
|
||||||
|
sessions = newSessionsDict()
|
||||||
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-./]+$`)
|
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-./]+$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -112,7 +138,7 @@ func newUserInteractiveResponse(
|
||||||
params map[string]interface{},
|
params map[string]interface{},
|
||||||
) userInteractiveResponse {
|
) userInteractiveResponse {
|
||||||
return userInteractiveResponse{
|
return userInteractiveResponse{
|
||||||
fs, sessions[sessionID], params, sessionID,
|
fs, sessions.GetCompletedStages(sessionID), params, sessionID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -433,7 +459,7 @@ func handleRegistrationFlow(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Recaptcha to the list of completed registration stages
|
// Add Recaptcha to the list of completed registration stages
|
||||||
sessions[sessionID] = append(sessions[sessionID], authtypes.LoginTypeRecaptcha)
|
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeRecaptcha)
|
||||||
|
|
||||||
case authtypes.LoginTypeSharedSecret:
|
case authtypes.LoginTypeSharedSecret:
|
||||||
// Check shared secret against config
|
// Check shared secret against config
|
||||||
|
@ -446,7 +472,7 @@ func handleRegistrationFlow(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add SharedSecret to the list of completed registration stages
|
// Add SharedSecret to the list of completed registration stages
|
||||||
sessions[sessionID] = append(sessions[sessionID], authtypes.LoginTypeSharedSecret)
|
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeSharedSecret)
|
||||||
|
|
||||||
case authtypes.LoginTypeApplicationService:
|
case authtypes.LoginTypeApplicationService:
|
||||||
// Check Application Service register user request is valid.
|
// Check Application Service register user request is valid.
|
||||||
|
@ -466,7 +492,7 @@ func handleRegistrationFlow(
|
||||||
case authtypes.LoginTypeDummy:
|
case authtypes.LoginTypeDummy:
|
||||||
// there is nothing to do
|
// there is nothing to do
|
||||||
// Add Dummy to the list of completed registration stages
|
// Add Dummy to the list of completed registration stages
|
||||||
sessions[sessionID] = append(sessions[sessionID], authtypes.LoginTypeDummy)
|
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeDummy)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -478,7 +504,8 @@ func handleRegistrationFlow(
|
||||||
// Check if the user's registration flow has been completed successfully
|
// Check if the user's registration flow has been completed successfully
|
||||||
// A response with current registration flow and remaining available methods
|
// A response with current registration flow and remaining available methods
|
||||||
// will be returned if a flow has not been successfully completed yet
|
// will be returned if a flow has not been successfully completed yet
|
||||||
return checkAndCompleteFlow(sessions[sessionID], req, r, sessionID, cfg, accountDB, deviceDB)
|
return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID),
|
||||||
|
req, r, sessionID, cfg, accountDB, deviceDB)
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkAndCompleteFlow checks if a given registration flow is completed given
|
// checkAndCompleteFlow checks if a given registration flow is completed given
|
||||||
|
|
|
@ -132,3 +132,16 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
|
||||||
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.")
|
t.Error("Incorrect registration flow verification: ", testFlow, ", from allowed flows: ", allowedFlows, ". Should be false.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Completed flows stages should always be a valid slice header.
|
||||||
|
// TestEmptyCompletedFlows checks that sessionsDict returns a slice & not nil.
|
||||||
|
func TestEmptyCompletedFlows(t *testing.T) {
|
||||||
|
fakeEmptySessions := newSessionsDict()
|
||||||
|
fakeSessionID := "aRandomSessionIDWhichDoesNotExist"
|
||||||
|
ret := fakeEmptySessions.GetCompletedStages(fakeSessionID)
|
||||||
|
|
||||||
|
// check for []
|
||||||
|
if ret == nil || len(ret) != 0 {
|
||||||
|
t.Error("Empty Completed Flow Stages should be a empty slice: returned ", ret, ". Should be []")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue