forked from MirrorHub/synapse
Merge branch 'dbkr/split_out_auth_handler' into dbkr/email_unsubscribe
This commit is contained in:
commit
3a3fb2f6f9
14 changed files with 44 additions and 35 deletions
|
@ -24,7 +24,6 @@ from .federation import FederationHandler
|
||||||
from .profile import ProfileHandler
|
from .profile import ProfileHandler
|
||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
from .admin import AdminHandler
|
from .admin import AdminHandler
|
||||||
from .auth import AuthHandler
|
|
||||||
from .identity import IdentityHandler
|
from .identity import IdentityHandler
|
||||||
from .receipts import ReceiptsHandler
|
from .receipts import ReceiptsHandler
|
||||||
from .search import SearchHandler
|
from .search import SearchHandler
|
||||||
|
@ -50,7 +49,6 @@ class Handlers(object):
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
self.admin_handler = AdminHandler(hs)
|
self.admin_handler = AdminHandler(hs)
|
||||||
self.receipts_handler = ReceiptsHandler(hs)
|
self.receipts_handler = ReceiptsHandler(hs)
|
||||||
self.auth_handler = AuthHandler(hs)
|
|
||||||
self.identity_handler = IdentityHandler(hs)
|
self.identity_handler = IdentityHandler(hs)
|
||||||
self.search_handler = SearchHandler(hs)
|
self.search_handler = SearchHandler(hs)
|
||||||
self.room_context_handler = RoomContextHandler(hs)
|
self.room_context_handler = RoomContextHandler(hs)
|
||||||
|
|
|
@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
def auth_handler(self):
|
def auth_handler(self):
|
||||||
return self.hs.get_handlers().auth_handler
|
return self.hs.get_auth_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def guest_access_token_for(self, medium, address, inviter_user_id):
|
def guest_access_token_for(self, medium, address, inviter_user_id):
|
||||||
|
|
|
@ -441,7 +441,7 @@ class RoomListHandler(BaseHandler):
|
||||||
self.hs.config.secondary_directory_servers
|
self.hs.config.secondary_directory_servers
|
||||||
)
|
)
|
||||||
self.remote_list_request_cache.set((), deferred)
|
self.remote_list_request_cache.set((), deferred)
|
||||||
yield deferred
|
self.remote_list_cache = yield deferred
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_aggregated_public_room_list(self):
|
def get_aggregated_public_room_list(self):
|
||||||
|
|
|
@ -124,6 +124,8 @@ class Mailer(object):
|
||||||
user_display_name = yield self.store.get_profile_displayname(
|
user_display_name = yield self.store.get_profile_displayname(
|
||||||
UserID.from_string(user_id).localpart
|
UserID.from_string(user_id).localpart
|
||||||
)
|
)
|
||||||
|
if user_display_name is None:
|
||||||
|
user_display_name = user_id
|
||||||
except StoreError:
|
except StoreError:
|
||||||
user_display_name = user_id
|
user_display_name = user_id
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
self.servername = hs.config.server_name
|
self.servername = hs.config.server_name
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_simple_http_client()
|
||||||
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
flows = []
|
flows = []
|
||||||
|
@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
user_id, self.hs.hostname
|
user_id, self.hs.hostname
|
||||||
).to_string()
|
).to_string()
|
||||||
|
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password=login_submission["password"])
|
password=login_submission["password"])
|
||||||
|
@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_token_login(self, login_submission):
|
def do_token_login(self, login_submission):
|
||||||
token = login_submission['token']
|
token = login_submission['token']
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = (
|
user_id = (
|
||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
)
|
)
|
||||||
|
@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if user_exists:
|
if user_exists:
|
||||||
user_id, access_token, refresh_token = (
|
user_id, access_token, refresh_token = (
|
||||||
|
@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if user_exists:
|
if user_exists:
|
||||||
user_id, access_token, refresh_token = (
|
user_id, access_token, refresh_token = (
|
||||||
|
@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||||
auth_handler = self.handlers.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||||
if not user_exists:
|
if not user_exists:
|
||||||
user_id, _ = (
|
user_id, _ = (
|
||||||
|
|
|
@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
|
||||||
super(PasswordRestServlet, self).__init__()
|
super(PasswordRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
|
|
@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
|
||||||
super(AuthRestServlet, self).__init__()
|
super(AuthRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
try:
|
try:
|
||||||
old_refresh_token = body["refresh_token"]
|
old_refresh_token = body["refresh_token"]
|
||||||
auth_handler = self.hs.get_handlers().auth_handler
|
auth_handler = self.hs.get_auth_handler()
|
||||||
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
||||||
old_refresh_token, auth_handler.generate_refresh_token)
|
old_refresh_token, auth_handler.generate_refresh_token)
|
||||||
new_access_token = yield auth_handler.issue_access_token(user_id)
|
new_access_token = yield auth_handler.issue_access_token(user_id)
|
||||||
|
|
|
@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
from synapse.handlers.typing import TypingHandler
|
from synapse.handlers.typing import TypingHandler
|
||||||
from synapse.handlers.room import RoomListHandler
|
from synapse.handlers.room import RoomListHandler
|
||||||
|
from synapse.handlers.auth import AuthHandler
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.state import StateHandler
|
from synapse.state import StateHandler
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
|
@ -89,6 +90,7 @@ class HomeServer(object):
|
||||||
'sync_handler',
|
'sync_handler',
|
||||||
'typing_handler',
|
'typing_handler',
|
||||||
'room_list_handler',
|
'room_list_handler',
|
||||||
|
'auth_handler',
|
||||||
'application_service_api',
|
'application_service_api',
|
||||||
'application_service_scheduler',
|
'application_service_scheduler',
|
||||||
'application_service_handler',
|
'application_service_handler',
|
||||||
|
@ -190,6 +192,9 @@ class HomeServer(object):
|
||||||
def build_room_list_handler(self):
|
def build_room_list_handler(self):
|
||||||
return RoomListHandler(self)
|
return RoomListHandler(self)
|
||||||
|
|
||||||
|
def build_auth_handler(self):
|
||||||
|
return AuthHandler(self)
|
||||||
|
|
||||||
def build_application_service_api(self):
|
def build_application_service_api(self):
|
||||||
return ApplicationServiceApi(self)
|
return ApplicationServiceApi(self)
|
||||||
|
|
||||||
|
|
|
@ -106,7 +106,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
desc="bulk_get_push_rules_enabled",
|
desc="bulk_get_push_rules_enabled",
|
||||||
)
|
)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
|
enabled = bool(row['enabled'])
|
||||||
|
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -19,17 +19,24 @@ from ._base import SQLBaseStore
|
||||||
|
|
||||||
from unpaddedbase64 import encode_base64
|
from unpaddedbase64 import encode_base64
|
||||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||||
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
|
||||||
class SignatureStore(SQLBaseStore):
|
class SignatureStore(SQLBaseStore):
|
||||||
"""Persistence for event signatures and hashes"""
|
"""Persistence for event signatures and hashes"""
|
||||||
|
|
||||||
|
@cached(lru=True)
|
||||||
|
def get_event_reference_hash(self, event_id):
|
||||||
|
return self._get_event_reference_hashes_txn(event_id)
|
||||||
|
|
||||||
|
@cachedList(cached_method_name="get_event_reference_hash",
|
||||||
|
list_name="event_ids", num_args=1)
|
||||||
def get_event_reference_hashes(self, event_ids):
|
def get_event_reference_hashes(self, event_ids):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
return [
|
return {
|
||||||
self._get_event_reference_hashes_txn(txn, ev)
|
event_id: self._get_event_reference_hashes_txn(txn, event_id)
|
||||||
for ev in event_ids
|
for event_id in event_ids
|
||||||
]
|
}
|
||||||
|
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_event_reference_hashes",
|
"get_event_reference_hashes",
|
||||||
|
@ -41,15 +48,15 @@ class SignatureStore(SQLBaseStore):
|
||||||
hashes = yield self.get_event_reference_hashes(
|
hashes = yield self.get_event_reference_hashes(
|
||||||
event_ids
|
event_ids
|
||||||
)
|
)
|
||||||
hashes = [
|
hashes = {
|
||||||
{
|
e_id: {
|
||||||
k: encode_base64(v) for k, v in h.items()
|
k: encode_base64(v) for k, v in h.items()
|
||||||
if k == "sha256"
|
if k == "sha256"
|
||||||
}
|
}
|
||||||
for h in hashes
|
for e_id, h in hashes.items()
|
||||||
]
|
}
|
||||||
|
|
||||||
defer.returnValue(zip(event_ids, hashes))
|
defer.returnValue(hashes.items())
|
||||||
|
|
||||||
def _get_event_reference_hashes_txn(self, txn, event_id):
|
def _get_event_reference_hashes_txn(self, txn, event_id):
|
||||||
"""Get all the hashes for a given PDU.
|
"""Get all the hashes for a given PDU.
|
||||||
|
|
|
@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# do the dance to hook it up to the hs global
|
# do the dance to hook it up to the hs global
|
||||||
self.handlers = Mock(
|
self.handlers = Mock(
|
||||||
auth_handler=self.auth_handler,
|
|
||||||
registration_handler=self.registration_handler,
|
registration_handler=self.registration_handler,
|
||||||
identity_handler=self.identity_handler,
|
identity_handler=self.identity_handler,
|
||||||
login_handler=self.login_handler
|
login_handler=self.login_handler
|
||||||
|
@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.hs.hostname = "superbig~testing~thing.com"
|
self.hs.hostname = "superbig~testing~thing.com"
|
||||||
self.hs.get_auth = Mock(return_value=self.auth)
|
self.hs.get_auth = Mock(return_value=self.auth)
|
||||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||||
|
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||||
self.hs.config.enable_registration = True
|
self.hs.config.enable_registration = True
|
||||||
|
|
||||||
# init the thing we're testing
|
# init the thing we're testing
|
||||||
|
|
|
@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
)
|
)
|
||||||
|
|
||||||
# bcrypt is far too slow to be doing in unit tests
|
# bcrypt is far too slow to be doing in unit tests
|
||||||
def swap_out_hash_for_testing(old_build_handlers):
|
# Need to let the HS build an auth handler and then mess with it
|
||||||
def build_handlers():
|
# because AuthHandler's constructor requires the HS, so we can't make one
|
||||||
handlers = old_build_handlers()
|
# beforehand and pass it in to the HS's constructor (chicken / egg)
|
||||||
auth_handler = handlers.auth_handler
|
hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
|
||||||
auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
|
hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
|
||||||
auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
|
|
||||||
return handlers
|
|
||||||
return build_handlers
|
|
||||||
|
|
||||||
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
|
|
||||||
|
|
||||||
fed = kargs.get("resource_for_federation", None)
|
fed = kargs.get("resource_for_federation", None)
|
||||||
if fed:
|
if fed:
|
||||||
|
|
Loading…
Reference in a new issue