forked from MirrorHub/synapse
Merge branch 'dbkr/split_out_auth_handler' into dbkr/email_unsubscribe
This commit is contained in:
commit
3a3fb2f6f9
14 changed files with 44 additions and 35 deletions
|
@ -24,7 +24,6 @@ from .federation import FederationHandler
|
|||
from .profile import ProfileHandler
|
||||
from .directory import DirectoryHandler
|
||||
from .admin import AdminHandler
|
||||
from .auth import AuthHandler
|
||||
from .identity import IdentityHandler
|
||||
from .receipts import ReceiptsHandler
|
||||
from .search import SearchHandler
|
||||
|
@ -50,7 +49,6 @@ class Handlers(object):
|
|||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.receipts_handler = ReceiptsHandler(hs)
|
||||
self.auth_handler = AuthHandler(hs)
|
||||
self.identity_handler = IdentityHandler(hs)
|
||||
self.search_handler = SearchHandler(hs)
|
||||
self.room_context_handler = RoomContextHandler(hs)
|
||||
|
|
|
@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
|
|||
defer.returnValue((user_id, token))
|
||||
|
||||
def auth_handler(self):
|
||||
return self.hs.get_handlers().auth_handler
|
||||
return self.hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def guest_access_token_for(self, medium, address, inviter_user_id):
|
||||
|
|
|
@ -441,7 +441,7 @@ class RoomListHandler(BaseHandler):
|
|||
self.hs.config.secondary_directory_servers
|
||||
)
|
||||
self.remote_list_request_cache.set((), deferred)
|
||||
yield deferred
|
||||
self.remote_list_cache = yield deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_aggregated_public_room_list(self):
|
||||
|
|
|
@ -124,6 +124,8 @@ class Mailer(object):
|
|||
user_display_name = yield self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
user_display_name = user_id
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.servername = hs.config.server_name
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
|
||||
def on_GET(self, request):
|
||||
flows = []
|
||||
|
@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
user_id, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
auth_handler = self.handlers.auth_handler
|
||||
auth_handler = self.auth_handler
|
||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"])
|
||||
|
@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def do_token_login(self, login_submission):
|
||||
token = login_submission['token']
|
||||
auth_handler = self.handlers.auth_handler
|
||||
auth_handler = self.auth_handler
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
)
|
||||
|
@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.handlers.auth_handler
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
|
@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.handlers.auth_handler
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
|
@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.handlers.auth_handler
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if not user_exists:
|
||||
user_id, _ = (
|
||||
|
|
|
@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
|
|||
super(PasswordRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
|
|||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
|
|
@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
|
|||
super(AuthRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
|
|||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
|
|||
body = parse_json_object_from_request(request)
|
||||
try:
|
||||
old_refresh_token = body["refresh_token"]
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
||||
old_refresh_token, auth_handler.generate_refresh_token)
|
||||
new_access_token = yield auth_handler.issue_access_token(user_id)
|
||||
|
|
|
@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
|
|||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.room import RoomListHandler
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage import DataStore
|
||||
|
@ -89,6 +90,7 @@ class HomeServer(object):
|
|||
'sync_handler',
|
||||
'typing_handler',
|
||||
'room_list_handler',
|
||||
'auth_handler',
|
||||
'application_service_api',
|
||||
'application_service_scheduler',
|
||||
'application_service_handler',
|
||||
|
@ -190,6 +192,9 @@ class HomeServer(object):
|
|||
def build_room_list_handler(self):
|
||||
return RoomListHandler(self)
|
||||
|
||||
def build_auth_handler(self):
|
||||
return AuthHandler(self)
|
||||
|
||||
def build_application_service_api(self):
|
||||
return ApplicationServiceApi(self)
|
||||
|
||||
|
|
|
@ -106,7 +106,8 @@ class PushRuleStore(SQLBaseStore):
|
|||
desc="bulk_get_push_rules_enabled",
|
||||
)
|
||||
for row in rows:
|
||||
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
|
||||
enabled = bool(row['enabled'])
|
||||
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -19,17 +19,24 @@ from ._base import SQLBaseStore
|
|||
|
||||
from unpaddedbase64 import encode_base64
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
|
||||
class SignatureStore(SQLBaseStore):
|
||||
"""Persistence for event signatures and hashes"""
|
||||
|
||||
@cached(lru=True)
|
||||
def get_event_reference_hash(self, event_id):
|
||||
return self._get_event_reference_hashes_txn(event_id)
|
||||
|
||||
@cachedList(cached_method_name="get_event_reference_hash",
|
||||
list_name="event_ids", num_args=1)
|
||||
def get_event_reference_hashes(self, event_ids):
|
||||
def f(txn):
|
||||
return [
|
||||
self._get_event_reference_hashes_txn(txn, ev)
|
||||
for ev in event_ids
|
||||
]
|
||||
return {
|
||||
event_id: self._get_event_reference_hashes_txn(txn, event_id)
|
||||
for event_id in event_ids
|
||||
}
|
||||
|
||||
return self.runInteraction(
|
||||
"get_event_reference_hashes",
|
||||
|
@ -41,15 +48,15 @@ class SignatureStore(SQLBaseStore):
|
|||
hashes = yield self.get_event_reference_hashes(
|
||||
event_ids
|
||||
)
|
||||
hashes = [
|
||||
{
|
||||
hashes = {
|
||||
e_id: {
|
||||
k: encode_base64(v) for k, v in h.items()
|
||||
if k == "sha256"
|
||||
}
|
||||
for h in hashes
|
||||
]
|
||||
for e_id, h in hashes.items()
|
||||
}
|
||||
|
||||
defer.returnValue(zip(event_ids, hashes))
|
||||
defer.returnValue(hashes.items())
|
||||
|
||||
def _get_event_reference_hashes_txn(self, txn, event_id):
|
||||
"""Get all the hashes for a given PDU.
|
||||
|
|
|
@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
|
||||
# do the dance to hook it up to the hs global
|
||||
self.handlers = Mock(
|
||||
auth_handler=self.auth_handler,
|
||||
registration_handler=self.registration_handler,
|
||||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler
|
||||
|
@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.hs.hostname = "superbig~testing~thing.com"
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
self.hs.config.enable_registration = True
|
||||
|
||||
# init the thing we're testing
|
||||
|
|
|
@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||
)
|
||||
|
||||
# bcrypt is far too slow to be doing in unit tests
|
||||
def swap_out_hash_for_testing(old_build_handlers):
|
||||
def build_handlers():
|
||||
handlers = old_build_handlers()
|
||||
auth_handler = handlers.auth_handler
|
||||
auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
|
||||
auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
|
||||
return handlers
|
||||
return build_handlers
|
||||
|
||||
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
|
||||
# Need to let the HS build an auth handler and then mess with it
|
||||
# because AuthHandler's constructor requires the HS, so we can't make one
|
||||
# beforehand and pass it in to the HS's constructor (chicken / egg)
|
||||
hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
|
||||
hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
|
||||
|
||||
fed = kargs.get("resource_for_federation", None)
|
||||
if fed:
|
||||
|
|
Loading…
Reference in a new issue