mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-17 23:42:33 +01:00
Merge pull request #1157 from Rugvip/nolimit
Remove rate limiting from app service senders and fix get_or_create_user requester
This commit is contained in:
commit
a2f2516199
14 changed files with 65 additions and 78 deletions
|
@ -653,7 +653,7 @@ class Auth(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_appservice_user_id(self, request):
|
def _get_appservice_user_id(self, request):
|
||||||
app_service = yield self.store.get_app_service_by_token(
|
app_service = self.store.get_app_service_by_token(
|
||||||
get_access_token_from_request(
|
get_access_token_from_request(
|
||||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||||
)
|
)
|
||||||
|
@ -855,13 +855,12 @@ class Auth(object):
|
||||||
}
|
}
|
||||||
defer.returnValue(user_info)
|
defer.returnValue(user_info)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_appservice_by_req(self, request):
|
def get_appservice_by_req(self, request):
|
||||||
try:
|
try:
|
||||||
token = get_access_token_from_request(
|
token = get_access_token_from_request(
|
||||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||||
)
|
)
|
||||||
service = yield self.store.get_app_service_by_token(token)
|
service = self.store.get_app_service_by_token(token)
|
||||||
if not service:
|
if not service:
|
||||||
logger.warn("Unrecognised appservice access token: %s" % (token,))
|
logger.warn("Unrecognised appservice access token: %s" % (token,))
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
|
@ -870,7 +869,7 @@ class Auth(object):
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
request.authenticated_entity = service.sender
|
request.authenticated_entity = service.sender
|
||||||
defer.returnValue(service)
|
return defer.succeed(service)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
|
||||||
|
|
|
@ -55,8 +55,14 @@ class BaseHandler(object):
|
||||||
|
|
||||||
def ratelimit(self, requester):
|
def ratelimit(self, requester):
|
||||||
time_now = self.clock.time()
|
time_now = self.clock.time()
|
||||||
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
app_service = self.store.get_app_service_by_user_id(user_id)
|
||||||
|
if app_service is not None:
|
||||||
|
return # do not ratelimit app service senders
|
||||||
|
|
||||||
allowed, time_allowed = self.ratelimiter.send_message(
|
allowed, time_allowed = self.ratelimiter.send_message(
|
||||||
requester.user.to_string(), time_now,
|
user_id, time_now,
|
||||||
msg_rate_hz=self.hs.config.rc_messages_per_second,
|
msg_rate_hz=self.hs.config.rc_messages_per_second,
|
||||||
burst_count=self.hs.config.rc_message_burst_count,
|
burst_count=self.hs.config.rc_message_burst_count,
|
||||||
)
|
)
|
||||||
|
|
|
@ -59,7 +59,7 @@ class ApplicationServicesHandler(object):
|
||||||
Args:
|
Args:
|
||||||
current_id(int): The current maximum ID.
|
current_id(int): The current maximum ID.
|
||||||
"""
|
"""
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
if not services or not self.notify_appservices:
|
if not services or not self.notify_appservices:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ class ApplicationServicesHandler(object):
|
||||||
association can be found.
|
association can be found.
|
||||||
"""
|
"""
|
||||||
room_alias_str = room_alias.to_string()
|
room_alias_str = room_alias.to_string()
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
alias_query_services = [
|
alias_query_services = [
|
||||||
s for s in services if (
|
s for s in services if (
|
||||||
s.is_interested_in_alias(room_alias_str)
|
s.is_interested_in_alias(room_alias_str)
|
||||||
|
@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_3pe_protocols(self, only_protocol=None):
|
def get_3pe_protocols(self, only_protocol=None):
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
protocols = {}
|
protocols = {}
|
||||||
|
|
||||||
# Collect up all the individual protocol responses out of the ASes
|
# Collect up all the individual protocol responses out of the ASes
|
||||||
|
@ -224,7 +224,7 @@ class ApplicationServicesHandler(object):
|
||||||
list<ApplicationService>: A list of services interested in this
|
list<ApplicationService>: A list of services interested in this
|
||||||
event based on the service regex.
|
event based on the service regex.
|
||||||
"""
|
"""
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_list = [
|
interested_list = [
|
||||||
s for s in services if (
|
s for s in services if (
|
||||||
yield s.is_interested(event, self.store)
|
yield s.is_interested(event, self.store)
|
||||||
|
@ -232,23 +232,21 @@ class ApplicationServicesHandler(object):
|
||||||
]
|
]
|
||||||
defer.returnValue(interested_list)
|
defer.returnValue(interested_list)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_services_for_user(self, user_id):
|
def _get_services_for_user(self, user_id):
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_list = [
|
interested_list = [
|
||||||
s for s in services if (
|
s for s in services if (
|
||||||
s.is_interested_in_user(user_id)
|
s.is_interested_in_user(user_id)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
defer.returnValue(interested_list)
|
return defer.succeed(interested_list)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _get_services_for_3pn(self, protocol):
|
def _get_services_for_3pn(self, protocol):
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_list = [
|
interested_list = [
|
||||||
s for s in services if s.is_interested_in_protocol(protocol)
|
s for s in services if s.is_interested_in_protocol(protocol)
|
||||||
]
|
]
|
||||||
defer.returnValue(interested_list)
|
return defer.succeed(interested_list)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _is_unknown_user(self, user_id):
|
def _is_unknown_user(self, user_id):
|
||||||
|
@ -264,7 +262,7 @@ class ApplicationServicesHandler(object):
|
||||||
return
|
return
|
||||||
|
|
||||||
# user not found; could be the AS though, so check.
|
# user not found; could be the AS though, so check.
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
service_list = [s for s in services if s.sender == user_id]
|
service_list = [s for s in services if s.sender == user_id]
|
||||||
defer.returnValue(len(service_list) == 0)
|
defer.returnValue(len(service_list) == 0)
|
||||||
|
|
||||||
|
|
|
@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler):
|
||||||
result = yield as_handler.query_room_alias_exists(room_alias)
|
result = yield as_handler.query_room_alias_exists(room_alias)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def can_modify_alias(self, alias, user_id=None):
|
def can_modify_alias(self, alias, user_id=None):
|
||||||
# Any application service "interested" in an alias they are regexing on
|
# Any application service "interested" in an alias they are regexing on
|
||||||
# can modify the alias.
|
# can modify the alias.
|
||||||
# Users can only modify the alias if ALL the interested services have
|
# Users can only modify the alias if ALL the interested services have
|
||||||
# non-exclusive locks on the alias (or there are no interested services)
|
# non-exclusive locks on the alias (or there are no interested services)
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_services = [
|
interested_services = [
|
||||||
s for s in services if s.is_interested_in_alias(alias.to_string())
|
s for s in services if s.is_interested_in_alias(alias.to_string())
|
||||||
]
|
]
|
||||||
|
@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler):
|
||||||
for service in interested_services:
|
for service in interested_services:
|
||||||
if user_id == service.sender:
|
if user_id == service.sender:
|
||||||
# this user IS the app service so they can do whatever they like
|
# this user IS the app service so they can do whatever they like
|
||||||
defer.returnValue(True)
|
return defer.succeed(True)
|
||||||
return
|
|
||||||
elif service.is_exclusive_alias(alias.to_string()):
|
elif service.is_exclusive_alias(alias.to_string()):
|
||||||
# another service has an exclusive lock on this alias.
|
# another service has an exclusive lock on this alias.
|
||||||
defer.returnValue(False)
|
return defer.succeed(False)
|
||||||
return
|
|
||||||
# either no interested services, or no service with an exclusive lock
|
# either no interested services, or no service with an exclusive lock
|
||||||
defer.returnValue(True)
|
return defer.succeed(True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _user_can_delete_alias(self, alias, user_id):
|
def _user_can_delete_alias(self, alias, user_id):
|
||||||
|
|
|
@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler):
|
||||||
defer.returnValue(result["displayname"])
|
defer.returnValue(result["displayname"])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_displayname(self, target_user, requester, new_displayname):
|
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
|
||||||
"""target_user is the user whose displayname is to be changed;
|
"""target_user is the user whose displayname is to be changed;
|
||||||
auth_user is the user attempting to make this change."""
|
auth_user is the user attempting to make this change."""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||||
|
|
||||||
if target_user != requester.user:
|
if not by_admin and target_user != requester.user:
|
||||||
raise AuthError(400, "Cannot set another user's displayname")
|
raise AuthError(400, "Cannot set another user's displayname")
|
||||||
|
|
||||||
if new_displayname == '':
|
if new_displayname == '':
|
||||||
|
@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler):
|
||||||
defer.returnValue(result["avatar_url"])
|
defer.returnValue(result["avatar_url"])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_avatar_url(self, target_user, requester, new_avatar_url):
|
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
|
||||||
"""target_user is the user whose avatar_url is to be changed;
|
"""target_user is the user whose avatar_url is to be changed;
|
||||||
auth_user is the user attempting to make this change."""
|
auth_user is the user attempting to make this change."""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||||
|
|
||||||
if target_user != requester.user:
|
if not by_admin and target_user != requester.user:
|
||||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||||
|
|
||||||
yield self.store.set_profile_avatar_url(
|
yield self.store.set_profile_avatar_url(
|
||||||
|
|
|
@ -19,7 +19,6 @@ import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.types
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||||
)
|
)
|
||||||
|
@ -194,7 +193,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
def appservice_register(self, user_localpart, as_token):
|
def appservice_register(self, user_localpart, as_token):
|
||||||
user = UserID(user_localpart, self.hs.hostname)
|
user = UserID(user_localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
service = yield self.store.get_app_service_by_token(as_token)
|
service = self.store.get_app_service_by_token(as_token)
|
||||||
if not service:
|
if not service:
|
||||||
raise AuthError(403, "Invalid application service token.")
|
raise AuthError(403, "Invalid application service token.")
|
||||||
if not service.is_interested_in_user(user_id):
|
if not service.is_interested_in_user(user_id):
|
||||||
|
@ -305,11 +304,10 @@ class RegistrationHandler(BaseHandler):
|
||||||
# XXX: This should be a deferred list, shouldn't it?
|
# XXX: This should be a deferred list, shouldn't it?
|
||||||
yield identity_handler.bind_threepid(c, user_id)
|
yield identity_handler.bind_threepid(c, user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
|
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
|
||||||
# valid user IDs must not clash with any user ID namespaces claimed by
|
# valid user IDs must not clash with any user ID namespaces claimed by
|
||||||
# application services.
|
# application services.
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
interested_services = [
|
interested_services = [
|
||||||
s for s in services
|
s for s in services
|
||||||
if s.is_interested_in_user(user_id)
|
if s.is_interested_in_user(user_id)
|
||||||
|
@ -371,7 +369,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_or_create_user(self, localpart, displayname, duration_in_ms,
|
def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
|
||||||
password_hash=None):
|
password_hash=None):
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
|
@ -418,9 +416,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
profile_handler = self.hs.get_handlers().profile_handler
|
profile_handler = self.hs.get_handlers().profile_handler
|
||||||
requester = synapse.types.create_requester(user)
|
|
||||||
yield profile_handler.set_displayname(
|
yield profile_handler.set_displayname(
|
||||||
user, requester, displayname
|
user, requester, displayname, by_admin=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
|
@ -437,7 +437,7 @@ class RoomEventSource(object):
|
||||||
logger.warn("Stream has topological part!!!! %r", from_key)
|
logger.warn("Stream has topological part!!!! %r", from_key)
|
||||||
from_key = "s%s" % (from_token.stream,)
|
from_key = "s%s" % (from_token.stream,)
|
||||||
|
|
||||||
app_service = yield self.store.get_app_service_by_user_id(
|
app_service = self.store.get_app_service_by_user_id(
|
||||||
user.to_string()
|
user.to_string()
|
||||||
)
|
)
|
||||||
if app_service:
|
if app_service:
|
||||||
|
|
|
@ -788,7 +788,7 @@ class SyncHandler(object):
|
||||||
|
|
||||||
assert since_token
|
assert since_token
|
||||||
|
|
||||||
app_service = yield self.store.get_app_service_by_user_id(user_id)
|
app_service = self.store.get_app_service_by_user_id(user_id)
|
||||||
if app_service:
|
if app_service:
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||||
joined_room_ids = set(r.room_id for r in rooms)
|
joined_room_ids = set(r.room_id for r in rooms)
|
||||||
|
|
|
@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
import synapse.util.stringutils as stringutils
|
import synapse.util.stringutils as stringutils
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
from synapse.types import create_requester
|
||||||
|
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
|
||||||
|
@ -391,15 +392,16 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
user_json = parse_json_object_from_request(request)
|
user_json = parse_json_object_from_request(request)
|
||||||
|
|
||||||
access_token = get_access_token_from_request(request)
|
access_token = get_access_token_from_request(request)
|
||||||
app_service = yield self.store.get_app_service_by_token(
|
app_service = self.store.get_app_service_by_token(
|
||||||
access_token
|
access_token
|
||||||
)
|
)
|
||||||
if not app_service:
|
if not app_service:
|
||||||
raise SynapseError(403, "Invalid application service token.")
|
raise SynapseError(403, "Invalid application service token.")
|
||||||
|
|
||||||
logger.debug("creating user: %s", user_json)
|
requester = create_requester(app_service.sender)
|
||||||
|
|
||||||
response = yield self._do_create(user_json)
|
logger.debug("creating user: %s", user_json)
|
||||||
|
response = yield self._do_create(requester, user_json)
|
||||||
|
|
||||||
defer.returnValue((200, response))
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
return 403, {}
|
return 403, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_create(self, user_json):
|
def _do_create(self, requester, user_json):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
if "localpart" not in user_json:
|
if "localpart" not in user_json:
|
||||||
|
@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
handler = self.handlers.registration_handler
|
handler = self.handlers.registration_handler
|
||||||
user_id, token = yield handler.get_or_create_user(
|
user_id, token = yield handler.get_or_create_user(
|
||||||
|
requester=requester,
|
||||||
localpart=localpart,
|
localpart=localpart,
|
||||||
displayname=displayname,
|
displayname=displayname,
|
||||||
duration_in_ms=(duration_seconds * 1000),
|
duration_in_ms=(duration_seconds * 1000),
|
||||||
|
|
|
@ -37,7 +37,7 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_app_services(self):
|
def get_app_services(self):
|
||||||
return defer.succeed(self.services_cache)
|
return self.services_cache
|
||||||
|
|
||||||
def get_app_service_by_user_id(self, user_id):
|
def get_app_service_by_user_id(self, user_id):
|
||||||
"""Retrieve an application service from their user ID.
|
"""Retrieve an application service from their user ID.
|
||||||
|
@ -54,8 +54,8 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.sender == user_id:
|
if service.sender == user_id:
|
||||||
return defer.succeed(service)
|
return service
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
def get_app_service_by_token(self, token):
|
def get_app_service_by_token(self, token):
|
||||||
"""Get the application service with the given appservice token.
|
"""Get the application service with the given appservice token.
|
||||||
|
@ -67,8 +67,8 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
for service in self.services_cache:
|
for service in self.services_cache:
|
||||||
if service.token == token:
|
if service.token == token:
|
||||||
return defer.succeed(service)
|
return service
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
def get_app_service_rooms(self, service):
|
def get_app_service_rooms(self, service):
|
||||||
"""Get a list of RoomsForUser for this application service.
|
"""Get a list of RoomsForUser for this application service.
|
||||||
|
@ -163,7 +163,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
["as_id"]
|
["as_id"]
|
||||||
)
|
)
|
||||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||||
as_list = yield self.get_app_services()
|
as_list = self.get_app_services()
|
||||||
services = []
|
services = []
|
||||||
|
|
||||||
for res in results:
|
for res in results:
|
||||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
||||||
from .. import unittest
|
from .. import unittest
|
||||||
|
|
||||||
from synapse.handlers.register import RegistrationHandler
|
from synapse.handlers.register import RegistrationHandler
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
local_part = "someone"
|
local_part = "someone"
|
||||||
display_name = "someone"
|
display_name = "someone"
|
||||||
user_id = "@someone:test"
|
user_id = "@someone:test"
|
||||||
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
local_part, display_name, duration_ms)
|
requester, local_part, display_name, duration_ms)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
||||||
|
@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
local_part = "frank"
|
local_part = "frank"
|
||||||
display_name = "Frank"
|
display_name = "Frank"
|
||||||
user_id = "@frank:test"
|
user_id = "@frank:test"
|
||||||
|
requester = create_requester("@as:test")
|
||||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
local_part, display_name, duration_ms)
|
requester, local_part, display_name, duration_ms)
|
||||||
self.assertEquals(result_user_id, user_id)
|
self.assertEquals(result_user_id, user_id)
|
||||||
self.assertEquals(result_token, 'secret')
|
self.assertEquals(result_token, 'secret')
|
||||||
|
|
|
@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.request.args = {}
|
self.request.args = {}
|
||||||
|
|
||||||
self.appservice = None
|
|
||||||
self.auth = Mock(get_appservice_by_req=Mock(
|
|
||||||
side_effect=lambda x: defer.succeed(self.appservice))
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth_result = (False, None, None, None)
|
|
||||||
self.auth_handler = Mock(
|
|
||||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
|
||||||
get_session_data=Mock(return_value=None)
|
|
||||||
)
|
|
||||||
self.registration_handler = Mock()
|
self.registration_handler = Mock()
|
||||||
self.identity_handler = Mock()
|
|
||||||
self.login_handler = Mock()
|
|
||||||
|
|
||||||
# do the dance to hook it up to the hs global
|
self.appservice = Mock(sender="@as:test")
|
||||||
self.handlers = Mock(
|
self.datastore = Mock(
|
||||||
auth_handler=self.auth_handler,
|
get_app_service_by_token=Mock(return_value=self.appservice)
|
||||||
|
)
|
||||||
|
|
||||||
|
# do the dance to hook things up to the hs global
|
||||||
|
handlers = Mock(
|
||||||
registration_handler=self.registration_handler,
|
registration_handler=self.registration_handler,
|
||||||
identity_handler=self.identity_handler,
|
|
||||||
login_handler=self.login_handler
|
|
||||||
)
|
)
|
||||||
self.hs = Mock()
|
self.hs = Mock()
|
||||||
self.hs.hostname = "supergbig~testing~thing.com"
|
self.hs.hostname = "superbig~testing~thing.com"
|
||||||
self.hs.get_auth = Mock(return_value=self.auth)
|
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
self.hs.get_handlers = Mock(return_value=handlers)
|
||||||
self.hs.config.enable_registration = True
|
|
||||||
# init the thing we're testing
|
|
||||||
self.servlet = CreateUserRestServlet(self.hs)
|
self.servlet = CreateUserRestServlet(self.hs)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -19,7 +19,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.appservice = None
|
self.appservice = None
|
||||||
self.auth = Mock(get_appservice_by_req=Mock(
|
self.auth = Mock(get_appservice_by_req=Mock(
|
||||||
side_effect=lambda x: defer.succeed(self.appservice))
|
side_effect=lambda x: self.appservice)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth_result = (False, None, None, None)
|
self.auth_result = (False, None, None, None)
|
||||||
|
|
|
@ -71,14 +71,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
||||||
outfile.write(yaml.dump(as_yaml))
|
outfile.write(yaml.dump(as_yaml))
|
||||||
self.as_yaml_files.append(as_token)
|
self.as_yaml_files.append(as_token)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_retrieve_unknown_service_token(self):
|
def test_retrieve_unknown_service_token(self):
|
||||||
service = yield self.store.get_app_service_by_token("invalid_token")
|
service = self.store.get_app_service_by_token("invalid_token")
|
||||||
self.assertEquals(service, None)
|
self.assertEquals(service, None)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_retrieval_of_service(self):
|
def test_retrieval_of_service(self):
|
||||||
stored_service = yield self.store.get_app_service_by_token(
|
stored_service = self.store.get_app_service_by_token(
|
||||||
self.as_token
|
self.as_token
|
||||||
)
|
)
|
||||||
self.assertEquals(stored_service.token, self.as_token)
|
self.assertEquals(stored_service.token, self.as_token)
|
||||||
|
@ -97,9 +95,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
||||||
[]
|
[]
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_retrieval_of_all_services(self):
|
def test_retrieval_of_all_services(self):
|
||||||
services = yield self.store.get_app_services()
|
services = self.store.get_app_services()
|
||||||
self.assertEquals(len(services), 3)
|
self.assertEquals(len(services), 3)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue