mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-15 14:32:30 +01:00
400 lines
14 KiB
Python
400 lines
14 KiB
Python
# Copyright 2020 Matrix.org Foundation C.I.C.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
|
|
|
import attr
|
|
|
|
from synapse.api.constants import LoginType
|
|
from synapse.api.errors import StoreError
|
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
|
from synapse.storage.database import LoggingTransaction
|
|
from synapse.types import JsonDict
|
|
from synapse.util import json_encoder, stringutils
|
|
|
|
|
|
@attr.s(slots=True, auto_attribs=True)
class UIAuthSessionData:
    """In-memory view of a user-interactive authentication session.

    The generated constructor takes the fields positionally in declaration
    order: session_id, clientdict, uri, method, description.
    """

    # Unique identifier of the session.
    session_id: str

    # The dictionary from the client's request root level (not its 'auth' key).
    clientdict: JsonDict

    # The URI and HTTP method the session was initiated with. These are
    # re-checked at every authentication stage so that the requested
    # operation cannot be swapped out part-way through.
    uri: str
    method: str

    # Human-readable description of the operation this authentication is
    # authorising.
    description: str
|
|
|
|
|
|
class UIAuthWorkerStore(SQLBaseStore):
    """
    Manage user interactive authentication sessions.
    """

    async def create_ui_auth_session(
        self,
        clientdict: JsonDict,
        uri: str,
        method: str,
        description: str,
    ) -> UIAuthSessionData:
        """
        Creates a new user interactive authentication session.

        The session can be used to track the stages necessary to authenticate a
        user across multiple HTTP requests.

        Args:
            clientdict:
                The dictionary from the client root level, not the 'auth' key.
            uri:
                The URI this session was initiated with, this is checked at each
                stage of the authentication to ensure that the asked for
                operation has not changed.
            method:
                The method this session was initiated with, this is checked at each
                stage of the authentication to ensure that the asked for
                operation has not changed.
            description:
                A string description of the operation that the current
                authentication is authorising.
        Returns:
            The newly created session.
        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # The clientdict gets stored as JSON.
        clientdict_json = json_encoder.encode(clientdict)

        # Autogenerate a session ID and try to create it. A randomly generated
        # ID may clash with an existing one (surfacing as an IntegrityError),
        # so retry a few times before giving up.
        for _ in range(5):
            session_id = stringutils.random_string(24)

            try:
                await self.db_pool.simple_insert(
                    table="ui_auth_sessions",
                    values={
                        "session_id": session_id,
                        "clientdict": clientdict_json,
                        "uri": uri,
                        "method": method,
                        "description": description,
                        "serverdict": "{}",
                        "creation_time": self.hs.get_clock().time_msec(),
                    },
                    desc="create_ui_auth_session",
                )
            except self.db_pool.engine.module.IntegrityError:
                # The generated ID is already taken; try another one.
                continue
            else:
                return UIAuthSessionData(
                    session_id, clientdict, uri, method, description
                )

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
        """Retrieve a UI auth session.

        Args:
            session_id: The ID of the session.
        Returns:
            The session's data: its ID, the decoded clientdict, and the URI,
            method and description it was created with.
        Raises:
            StoreError if the session is not found.
        """
        result = await self.db_pool.simple_select_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            retcols=("clientdict", "uri", "method", "description"),
            desc="get_ui_auth_session",
        )

        # The clientdict is stored as JSON; decode it before building the result.
        result["clientdict"] = db_to_json(result["clientdict"])

        return UIAuthSessionData(session_id, **result)

    async def mark_ui_auth_stage_complete(
        self,
        session_id: str,
        stage_type: str,
        result: Union[str, bool, JsonDict],
    ) -> None:
        """
        Mark a session stage as completed.

        Args:
            session_id: The ID of the corresponding session.
            stage_type: The completed stage type.
            result: The result of the stage verification.
        Raises:
            StoreError (400) if the session cannot be found. This is detected
            via a database IntegrityError (presumably a foreign key on
            session_id -- NOTE(review): confirm against the schema).
        """
        # Add (or update) the results of the current stage to the database.
        #
        # Note that we need to allow for the same stage to complete multiple
        # times here so that registration is idempotent.
        try:
            await self.db_pool.simple_upsert(
                table="ui_auth_sessions_credentials",
                keyvalues={"session_id": session_id, "stage_type": stage_type},
                values={"result": json_encoder.encode(result)},
                desc="mark_ui_auth_stage_complete",
            )
        except self.db_pool.engine.module.IntegrityError:
            raise StoreError(400, "Unknown session ID: %s" % (session_id,))

    async def get_completed_ui_auth_stages(
        self, session_id: str
    ) -> Dict[str, Union[str, bool, JsonDict]]:
        """
        Retrieve the completed stages of a UI authentication session.

        Args:
            session_id: The ID of the session.
        Returns:
            The completed stages mapped to the result of the verification of
            that auth-type (decoded from JSON).
        """
        rows = await self.db_pool.simple_select_list(
            table="ui_auth_sessions_credentials",
            keyvalues={"session_id": session_id},
            retcols=("stage_type", "result"),
            desc="get_completed_ui_auth_stages",
        )

        return {row["stage_type"]: db_to_json(row["result"]) for row in rows}

    async def set_ui_auth_clientdict(
        self, session_id: str, clientdict: JsonDict
    ) -> None:
        """
        Store an updated clientdict for a given session ID.

        Args:
            session_id: The ID of this session as returned from check_auth
            clientdict:
                The dictionary from the client root level, not the 'auth' key.
        """
        # The clientdict gets stored as JSON.
        clientdict_json = json_encoder.encode(clientdict)

        await self.db_pool.simple_update_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            updatevalues={"clientdict": clientdict_json},
            desc="set_ui_auth_client_dict",
        )

    async def set_ui_auth_session_data(
        self, session_id: str, key: str, value: Any
    ) -> None:
        """
        Store a key-value pair into the sessions data associated with this
        request. This data is stored server-side and cannot be modified by
        the client.

        Args:
            session_id: The ID of this session as returned from check_auth
            key: The key to store the data under
            value: The data to store
        Raises:
            StoreError if the session cannot be found.
        """
        # The read-modify-write of the serverdict runs inside one transaction.
        await self.db_pool.runInteraction(
            "set_ui_auth_session_data",
            self._set_ui_auth_session_data_txn,
            session_id,
            key,
            value,
        )

    def _set_ui_auth_session_data_txn(
        self, txn: LoggingTransaction, session_id: str, key: str, value: Any
    ) -> None:
        """Set `key` to `value` in the session's JSON-encoded serverdict.

        Loads the current serverdict, updates the single key and writes the
        whole dict back.
        """
        # Get the current value.
        result = cast(
            Dict[str, Any],
            self.db_pool.simple_select_one_txn(
                txn,
                table="ui_auth_sessions",
                keyvalues={"session_id": session_id},
                retcols=("serverdict",),
            ),
        )

        # Update it and add it back to the database.
        serverdict = db_to_json(result["serverdict"])
        serverdict[key] = value

        self.db_pool.simple_update_one_txn(
            txn,
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            updatevalues={"serverdict": json_encoder.encode(serverdict)},
        )

    async def get_ui_auth_session_data(
        self, session_id: str, key: str, default: Optional[Any] = None
    ) -> Any:
        """
        Retrieve data stored with set_ui_auth_session_data.

        Args:
            session_id: The ID of this session as returned from check_auth
            key: The key the data was stored under
            default: Value to return if the key has not been set
        Returns:
            The stored value for `key`, or `default` if it was never set.
        Raises:
            StoreError if the session cannot be found.
        """
        result = await self.db_pool.simple_select_one(
            table="ui_auth_sessions",
            keyvalues={"session_id": session_id},
            retcols=("serverdict",),
            desc="get_ui_auth_session_data",
        )

        serverdict = db_to_json(result["serverdict"])

        return serverdict.get(key, default)

    async def add_user_agent_ip_to_ui_auth_session(
        self,
        session_id: str,
        user_agent: str,
        ip: str,
    ) -> None:
        """Add the given user agent / IP to the tracking table.

        The (session_id, user_agent, ip) triple is upserted with no extra
        values, so recording the same triple again is a no-op insert.
        """
        await self.db_pool.simple_upsert(
            table="ui_auth_sessions_ips",
            keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
            values={},
            desc="add_user_agent_ip_to_ui_auth_session",
        )

    async def get_user_agents_ips_to_ui_auth_session(
        self,
        session_id: str,
    ) -> List[Tuple[str, str]]:
        """Get the given user agents / IPs used during the ui auth process

        Args:
            session_id: The ID of the session.
        Returns:
            List of user_agent/ip pairs
        """
        rows = await self.db_pool.simple_select_list(
            table="ui_auth_sessions_ips",
            keyvalues={"session_id": session_id},
            retcols=("user_agent", "ip"),
            desc="get_user_agents_ips_to_ui_auth_session",
        )
        return [(row["user_agent"], row["ip"]) for row in rows]

    async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
        """
        Remove sessions which were created at or before the expiration time,
        together with their recorded IPs / user agents and completed stages.

        Any registration tokens used by those sessions (without registration
        completing) have their `pending` counters decremented.

        Args:
            expiration_time: The latest creation time that is still considered
                valid. Sessions created at or before this time are removed.
                This is an epoch time in milliseconds.
        """
        await self.db_pool.runInteraction(
            "delete_old_ui_auth_sessions",
            self._delete_old_ui_auth_sessions_txn,
            expiration_time,
        )

    def _delete_old_ui_auth_sessions_txn(
        self, txn: LoggingTransaction, expiration_time: int
    ) -> None:
        """Transaction body for delete_old_ui_auth_sessions.

        Deletes dependent rows (IPs, credentials) before the sessions
        themselves, adjusting registration token `pending` counts first.
        """
        # Get the expired sessions.
        sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
        txn.execute(sql, [expiration_time])
        session_ids = [r[0] for r in txn.fetchall()]

        # Delete the corresponding IP/user agents.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_ips",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )

        # If a registration token was used, decrement the pending counter
        # before deleting the session.
        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="ui_auth_sessions_credentials",
            column="session_id",
            iterable=session_ids,
            keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
            retcols=["result"],
        )

        # Get the tokens used and how much pending needs to be decremented by.
        token_counts: Dict[str, int] = {}
        for r in rows:
            # If registration was successfully completed, the result of the
            # registration token stage for that session will be True.
            # If a token was used to authenticate, but registration was
            # never completed, the result will be the token used.
            token = db_to_json(r["result"])
            if isinstance(token, str):
                token_counts[token] = token_counts.get(token, 0) + 1

        # Update the `pending` counters.
        if token_counts:
            token_rows = self.db_pool.simple_select_many_txn(
                txn,
                table="registration_tokens",
                column="token",
                iterable=list(token_counts.keys()),
                keyvalues={},
                retcols=["token", "pending"],
            )
            # Tokens no longer present in registration_tokens are simply
            # skipped, since only rows returned by the select are updated.
            for token_row in token_rows:
                token = token_row["token"]
                new_pending = token_row["pending"] - token_counts[token]
                self.db_pool.simple_update_one_txn(
                    txn,
                    table="registration_tokens",
                    keyvalues={"token": token},
                    updatevalues={"pending": new_pending},
                )

        # Delete the corresponding completed credentials.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_credentials",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )

        # Finally, delete the sessions.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions",
            column="session_id",
            values=session_ids,
            keyvalues={},
        )
|
|
|
|
|
|
class UIAuthStore(UIAuthWorkerStore):
    """UI auth store; adds nothing beyond what UIAuthWorkerStore provides."""
|