Merge branch 'dbkr/split_out_auth_handler' into dbkr/email_unsubscribe

This commit is contained in:
David Baker 2016-06-02 13:35:25 +01:00
commit 3a3fb2f6f9
14 changed files with 44 additions and 35 deletions

View file

@ -24,7 +24,6 @@ from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
from .admin import AdminHandler
from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
from .search import SearchHandler
@ -50,7 +49,6 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs)

View file

@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_handlers().auth_handler
return self.hs.get_auth_handler()
@defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id):

View file

@ -441,7 +441,7 @@ class RoomListHandler(BaseHandler):
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
yield deferred
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):

View file

@ -124,6 +124,8 @@ class Mailer(object):
user_display_name = yield self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
user_display_name = user_id
except StoreError:
user_display_name = user_id

View file

@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
def on_GET(self, request):
flows = []
@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_id, self.hs.hostname
).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id,
password=login_submission["password"])
@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (

View file

@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request):
@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet):
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_GET(self, request):

View file

@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
super(AuthRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks

View file

@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler

View file

@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
body = parse_json_object_from_request(request)
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler
auth_handler = self.hs.get_auth_handler()
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id)

View file

@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.state import StateHandler
from synapse.storage import DataStore
@ -89,6 +90,7 @@ class HomeServer(object):
'sync_handler',
'typing_handler',
'room_list_handler',
'auth_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@ -190,6 +192,9 @@ class HomeServer(object):
def build_room_list_handler(self):
return RoomListHandler(self)
def build_auth_handler(self):
return AuthHandler(self)
def build_application_service_api(self):
return ApplicationServiceApi(self)

View file

@ -106,7 +106,8 @@ class PushRuleStore(SQLBaseStore):
desc="bulk_get_push_rules_enabled",
)
for row in rows:
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
enabled = bool(row['enabled'])
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
defer.returnValue(results)
@defer.inlineCallbacks

View file

@ -19,17 +19,24 @@ from ._base import SQLBaseStore
from unpaddedbase64 import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
@cached(lru=True)
def get_event_reference_hash(self, event_id):
return self._get_event_reference_hashes_txn(event_id)
@cachedList(cached_method_name="get_event_reference_hash",
list_name="event_ids", num_args=1)
def get_event_reference_hashes(self, event_ids):
def f(txn):
return [
self._get_event_reference_hashes_txn(txn, ev)
for ev in event_ids
]
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
return self.runInteraction(
"get_event_reference_hashes",
@ -41,15 +48,15 @@ class SignatureStore(SQLBaseStore):
hashes = yield self.get_event_reference_hashes(
event_ids
)
hashes = [
{
hashes = {
e_id: {
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for h in hashes
]
for e_id, h in hashes.items()
}
defer.returnValue(zip(event_ids, hashes))
defer.returnValue(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.

View file

@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase):
# do the dance to hook it up to the hs global
self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.config.enable_registration = True
# init the thing we're testing

View file

@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
)
# bcrypt is far too slow to be doing in unit tests
def swap_out_hash_for_testing(old_build_handlers):
def build_handlers():
handlers = old_build_handlers()
auth_handler = handlers.auth_handler
auth_handler.hash = lambda p: hashlib.md5(p).hexdigest()
auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
return handlers
return build_handlers
hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
fed = kargs.get("resource_for_federation", None)
if fed: