0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-03-23 02:06:09 +01:00

Add ability for access tokens to belong to one user but grant access to another user. ()

We do it this way round so that only the "owner" can delete the access token (i.e. `/logout/all` by the "owner" also deletes that token, but `/logout/all` by the "target user" doesn't).

A future PR will add an API for creating such a token.

When the target user and authenticated entity are different the `Processed request` log line will be logged with a: `{@admin:server as @bob:server} ...`. I'm not convinced by that format (especially since it adds spaces in there, making it harder to use `cut -d ' '` to chop off the start of log lines). Suggestions welcome.
This commit is contained in:
Erik Johnston 2020-10-29 15:58:44 +00:00 committed by GitHub
parent 22eeb6bc54
commit f21e24ffc2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 197 additions and 138 deletions

1
changelog.d/8616.misc Normal file
View file

@ -0,0 +1 @@
Change schema to support access tokens belonging to one user but granting access to another.

View file

@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@ -190,10 +191,6 @@ class Auth:
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
@ -203,31 +200,38 @@ class Auth:
device_id="dummy-device", # stubbed
)
return synapse.types.create_requester(user_id, app_service=app_service)
requester = synapse.types.create_requester(
user_id, app_service=app_service
)
request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("user_id", user_id)
opentracing.set_tag("appservice_id", app_service.id)
return requester
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
shadow_banned = user_info["shadow_banned"]
token_id = user_info.token_id
is_guest = user_info.is_guest
shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
if await self.store.is_account_expired(
user_info.user_id, self.clock.time_msec()
):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
# device_id may not be present if get_user_by_access_token has been
# stubbed out.
device_id = user_info.get("device_id")
device_id = user_info.device_id
if user and access_token and ip_addr:
if access_token and ip_addr:
await self.store.insert_client_ip(
user_id=user.to_string(),
user_id=user_info.token_owner,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
@ -241,19 +245,23 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
request.authenticated_entity = user.to_string()
opentracing.set_tag("authenticated_entity", user.to_string())
if device_id:
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester(
user,
requester = synapse.types.create_requester(
user_info.user_id,
token_id,
is_guest,
shadow_banned,
device_id,
app_service=app_service,
authenticated_entity=user_info.token_owner,
)
request.requester = requester
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("user_id", user_info.user_id)
if device_id:
opentracing.set_tag("device_id", device_id)
return requester
except KeyError:
raise MissingClientTokenError()
@ -284,7 +292,7 @@ class Auth:
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
) -> dict:
) -> TokenLookupResult:
""" Validate access token and get user_id from it
Args:
@ -293,13 +301,7 @@ class Auth:
allow this
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
InvalidClientTokenError if a user by that token exists, but the token is
expired
@ -309,9 +311,9 @@ class Auth:
if rights == "access":
# first look in the database
r = await self._look_up_user_by_access_token(token)
r = await self.store.get_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
valid_until_ms = r.valid_until_ms
if (
not allow_expired
and valid_until_ms is not None
@ -328,7 +330,6 @@ class Auth:
# otherwise it needs to be a valid macaroon
try:
user_id, guest = self._parse_and_validate_macaroon(token, rights)
user = UserID.from_string(user_id)
if rights == "access":
if not guest:
@ -354,23 +355,17 @@ class Auth:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
ret = {
"user": user,
"is_guest": True,
"shadow_banned": False,
"token_id": None,
ret = TokenLookupResult(
user_id=user_id,
is_guest=True,
# all guests get the same device id
"device_id": GUEST_DEVICE_ID,
}
device_id=GUEST_DEVICE_ID,
)
elif rights == "delete_pusher":
# We don't store these tokens in the database
ret = {
"user": user,
"is_guest": False,
"shadow_banned": False,
"token_id": None,
"device_id": None,
}
ret = TokenLookupResult(user_id=user_id, is_guest=False)
else:
raise RuntimeError("Unknown rights setting %s", rights)
return ret
@ -479,31 +474,15 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
return user_info
def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
request.requester = synapse.types.create_requester(
service.sender, app_service=service
)
return service
async def is_server_admin(self, user: UserID) -> bool:

View file

@ -52,11 +52,11 @@ class ApplicationService:
self,
token,
hostname,
id,
sender,
url=None,
namespaces=None,
hs_token=None,
sender=None,
id=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,

View file

@ -154,7 +154,7 @@ class Authenticator:
)
logger.debug("Request from %s", origin)
request.authenticated_entity = origin
request.requester = origin
# If we get a valid signed request from the other side, its probably
# alive

View file

@ -991,17 +991,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
user_id=user_info.user_id,
device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token
if user_info["token_id"] is not None:
if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],)
user_info.user_id, (user_info.token_id,)
)
async def delete_access_tokens_for_user(

View file

@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
if (
not user_data.is_guest
or UserID.from_string(user_data.user_id).localpart != localpart
):
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
user_id=user_id,

View file

@ -14,7 +14,7 @@
import contextlib
import logging
import time
from typing import Optional
from typing import Optional, Union
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.types import Requester
logger = logging.getLogger(__name__)
@ -54,9 +55,12 @@ class SynapseRequest(Request):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0.0
# The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object.
self.requester = None # type: Optional[Union[Requester, str]]
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext]
@ -271,11 +275,23 @@ class SynapseRequest(Request):
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# Convert the requester into a string that we can log
authenticated_entity = None
if isinstance(self.requester, str):
authenticated_entity = self.requester
elif isinstance(self.requester, Requester):
authenticated_entity = self.requester.authenticated_entity
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format(
authenticated_entity, self.requester.user.to_string(),
)
elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information
# and can see that we're doing something wrong.
authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically

View file

@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite(

View file

@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
if requester.user:
request.authenticated_entity = requester.user.to_string()
request.requester = requester
logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id

View file

@ -18,6 +18,8 @@ import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
@attr.s(frozen=True, slots=True)
class TokenLookupResult:
"""Result of looking up an access token.
Attributes:
user_id: The user that this token authenticates as
is_guest
shadow_banned
token_id: The ID of the access token looked up
device_id: The device associated with the token, if any.
valid_until_ms: The timestamp the token expires, if any.
token_owner: The "owner" of the token. This is either the same as the
user, or a server admin who is logged in as the user.
"""
user_id = attr.ib(type=str)
is_guest = attr.ib(type=bool, default=False)
shadow_banned = attr.ib(type=bool, default=False)
token_id = attr.ib(type=Optional[int], default=None)
device_id = attr.ib(type=Optional[str], default=None)
valid_until_ms = attr.ib(type=Optional[int], default=None)
token_owner = attr.ib(type=str)
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
def _default_token_owner(self):
return self.user_id
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return is_trial
@cached()
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token.
Args:
token: The access token of a user.
Returns:
None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
None, if the token did not match, otherwise a `TokenLookupResult`
"""
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name,
SELECT users.name as user_id,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
access_tokens.valid_until_ms
access_tokens.valid_until_ms,
access_tokens.user_id as token_owner
FROM users
INNER JOIN access_tokens on users.name = access_tokens.user_id
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
"""
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]
return TokenLookupResult(**rows[0])
return None

View file

@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Whether the access token is an admin token for controlling another user.
ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;

View file

@ -29,6 +29,7 @@ from typing import (
Tuple,
Type,
TypeVar,
Union,
)
import attr
@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore
# define a version of typing.Collection that works on python 3.5
@ -74,6 +76,7 @@ class Requester(
"shadow_banned",
"device_id",
"app_service",
"authenticated_entity",
],
)
):
@ -104,6 +107,7 @@ class Requester(
"shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
"authenticated_entity": self.authenticated_entity,
}
@staticmethod
@ -129,16 +133,18 @@ class Requester(
shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
authenticated_entity=input["authenticated_entity"],
)
def create_requester(
user_id,
access_token_id=None,
is_guest=False,
shadow_banned=False,
device_id=None,
app_service=None,
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
is_guest: Optional[bool] = False,
shadow_banned: Optional[bool] = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
):
"""
Create a new ``Requester`` object
@ -151,14 +157,27 @@ def create_requester(
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
Returns:
Requester
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
if authenticated_entity is None:
authenticated_entity = user_id.to_string()
return Requester(
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
user_id,
access_token_id,
is_guest,
shadow_banned,
device_id,
app_service,
authenticated_entity,
)

View file

@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError,
ResourceLimitError,
)
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(
{"name": "@baldrick:matrix.org", "device_id": "device"}
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
)
@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
self.assertEqual(user_id, user_info.user_id)
# TODO: device_id should come from the macaroon, but currently comes
# from the db.
self.assertEqual(user_info["device_id"], "device")
self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
self.assertEqual(user_id, user_info.user_id)
self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok:
return defer.succeed(None)
return defer.succeed(
{
"name": USER_ID,
"is_guest": False,
"token_id": 1234,
"device_id": "DEVICE",
}
TokenLookupResult(
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
)
)
self.store.get_user_by_access_token = get_user

View file

@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=True,
None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=False,
None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)

View file

@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
id="unique_identifier",
sender="@as:test",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user

View file

@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
self.assertEqual(user_info["device_id"], retrieved_device_id)
self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))

View file

@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
)
self.token_id = self.info["token_id"]
self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)

View file

@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(

View file

@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(

View file

@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_dict["token_id"]
token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(

View file

@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
)
self.hs.get_datastore().services_cache.append(appservice)

View file

@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1])
)
self.assertDictContainsSubset(
{"name": self.user_id, "device_id": self.device_id}, result
)
self.assertTrue("token_id" in result)
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
self.assertEqual(self.user_id, user["name"])
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))