Merge branch 'dbkr/split_out_auth_handler' into dbkr/email_unsubscribe

This commit is contained in:
David Baker 2016-06-02 13:35:25 +01:00
commit 3a3fb2f6f9
14 changed files with 44 additions and 35 deletions

View file

@ -24,7 +24,6 @@ from .federation import FederationHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .admin import AdminHandler from .admin import AdminHandler
from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler from .receipts import ReceiptsHandler
from .search import SearchHandler from .search import SearchHandler
@ -50,7 +49,6 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs) self.receipts_handler = ReceiptsHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs) self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs) self.room_context_handler = RoomContextHandler(hs)

View file

@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
def auth_handler(self): def auth_handler(self):
return self.hs.get_handlers().auth_handler return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def guest_access_token_for(self, medium, address, inviter_user_id):

View file

@ -441,7 +441,7 @@ class RoomListHandler(BaseHandler):
self.hs.config.secondary_directory_servers self.hs.config.secondary_directory_servers
) )
self.remote_list_request_cache.set((), deferred) self.remote_list_request_cache.set((), deferred)
yield deferred self.remote_list_cache = yield deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_aggregated_public_room_list(self): def get_aggregated_public_room_list(self):

View file

@ -124,6 +124,8 @@ class Mailer(object):
user_display_name = yield self.store.get_profile_displayname( user_display_name = yield self.store.get_profile_displayname(
UserID.from_string(user_id).localpart UserID.from_string(user_id).localpart
) )
if user_display_name is None:
user_display_name = user_id
except StoreError: except StoreError:
user_display_name = user_id user_display_name = user_id

View file

@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_id, self.hs.hostname user_id, self.hs.hostname
).to_string() ).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id, user_id=user_id,
password=login_submission["password"]) password=login_submission["password"])
@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_token_login(self, login_submission): def do_token_login(self, login_submission):
token = login_submission['token'] token = login_submission['token']
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_id = ( user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = ( user_id, access_token, refresh_token = (
@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = ( user_id, access_token, refresh_token = (
@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists: if not user_exists:
user_id, _ = ( user_id, _ = (

View file

@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
super(PasswordRestServlet, self).__init__() super(PasswordRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View file

@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler

View file

@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
try: try:
old_refresh_token = body["refresh_token"] old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler auth_handler = self.hs.get_auth_handler()
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token( (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token) old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id) new_access_token = yield auth_handler.issue_access_token(user_id)

View file

@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.room import RoomListHandler from synapse.handlers.room import RoomListHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage import DataStore from synapse.storage import DataStore
@ -89,6 +90,7 @@ class HomeServer(object):
'sync_handler', 'sync_handler',
'typing_handler', 'typing_handler',
'room_list_handler', 'room_list_handler',
'auth_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -190,6 +192,9 @@ class HomeServer(object):
def build_room_list_handler(self): def build_room_list_handler(self):
return RoomListHandler(self) return RoomListHandler(self)
def build_auth_handler(self):
return AuthHandler(self)
def build_application_service_api(self): def build_application_service_api(self):
return ApplicationServiceApi(self) return ApplicationServiceApi(self)

View file

@ -106,7 +106,8 @@ class PushRuleStore(SQLBaseStore):
desc="bulk_get_push_rules_enabled", desc="bulk_get_push_rules_enabled",
) )
for row in rows: for row in rows:
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] enabled = bool(row['enabled'])
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -19,17 +19,24 @@ from ._base import SQLBaseStore
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
@cached(lru=True)
def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id)
@cachedList(cached_method_name="get_event_reference_hash",
list_name="event_ids", num_args=1)
def get_event_reference_hashes(self, event_ids): def get_event_reference_hashes(self, event_ids):
def f(txn): def f(txn):
return [ return {
self._get_event_reference_hashes_txn(txn, ev) event_id: self._get_event_reference_hashes_txn(txn, event_id)
for ev in event_ids for event_id in event_ids
] }
return self.runInteraction( return self.runInteraction(
"get_event_reference_hashes", "get_event_reference_hashes",
@ -41,15 +48,15 @@ class SignatureStore(SQLBaseStore):
hashes = yield self.get_event_reference_hashes( hashes = yield self.get_event_reference_hashes(
event_ids event_ids
) )
hashes = [ hashes = {
{ e_id: {
k: encode_base64(v) for k, v in h.items() k: encode_base64(v) for k, v in h.items()
if k == "sha256" if k == "sha256"
} }
for h in hashes for e_id, h in hashes.items()
] }
defer.returnValue(zip(event_ids, hashes)) defer.returnValue(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id): def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU. """Get all the hashes for a given PDU.

View file

@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
# do the dance to hook it up to the hs global # do the dance to hook it up to the hs global
self.handlers = Mock( self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler, registration_handler=self.registration_handler,
identity_handler=self.identity_handler, identity_handler=self.identity_handler,
login_handler=self.login_handler login_handler=self.login_handler
@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com" self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
# init the thing we're testing # init the thing we're testing

View file

@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
) )
# bcrypt is far too slow to be doing in unit tests # bcrypt is far too slow to be doing in unit tests
def swap_out_hash_for_testing(old_build_handlers): # Need to let the HS build an auth handler and then mess with it
def build_handlers(): # because AuthHandler's constructor requires the HS, so we can't make one
handlers = old_build_handlers() # beforehand and pass it in to the HS's constructor (chicken / egg)
auth_handler = handlers.auth_handler hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
auth_handler.hash = lambda p: hashlib.md5(p).hexdigest() hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
return handlers
return build_handlers
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
fed = kargs.get("resource_for_federation", None) fed = kargs.get("resource_for_federation", None)
if fed: if fed: