mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 14:23:50 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.19.0
This commit is contained in:
commit
f8c407a13b
16 changed files with 135 additions and 107 deletions
|
@ -65,6 +65,7 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -529,37 +530,11 @@ class AuthHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def issue_access_token(self, user_id, device_id=None):
|
def issue_access_token(self, user_id, device_id=None):
|
||||||
access_token = self.generate_access_token(user_id)
|
access_token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
yield self.store.add_access_token_to_user(user_id, access_token,
|
yield self.store.add_access_token_to_user(user_id, access_token,
|
||||||
device_id)
|
device_id)
|
||||||
defer.returnValue(access_token)
|
defer.returnValue(access_token)
|
||||||
|
|
||||||
def generate_access_token(self, user_id, extra_caveats=None):
|
|
||||||
extra_caveats = extra_caveats or []
|
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
|
||||||
macaroon.add_first_party_caveat("type = access")
|
|
||||||
# Include a nonce, to make sure that each login gets a different
|
|
||||||
# access token.
|
|
||||||
macaroon.add_first_party_caveat("nonce = %s" % (
|
|
||||||
stringutils.random_string_with_symbols(16),
|
|
||||||
))
|
|
||||||
for caveat in extra_caveats:
|
|
||||||
macaroon.add_first_party_caveat(caveat)
|
|
||||||
return macaroon.serialize()
|
|
||||||
|
|
||||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
|
||||||
macaroon.add_first_party_caveat("type = login")
|
|
||||||
now = self.hs.get_clock().time_msec()
|
|
||||||
expiry = now + duration_in_ms
|
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
|
||||||
return macaroon.serialize()
|
|
||||||
|
|
||||||
def generate_delete_pusher_token(self, user_id):
|
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
|
||||||
macaroon.add_first_party_caveat("type = delete_pusher")
|
|
||||||
return macaroon.serialize()
|
|
||||||
|
|
||||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||||
auth_api = self.hs.get_auth()
|
auth_api = self.hs.get_auth()
|
||||||
try:
|
try:
|
||||||
|
@ -570,15 +545,6 @@ class AuthHandler(BaseHandler):
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
def _generate_base_macaroon(self, user_id):
|
|
||||||
macaroon = pymacaroons.Macaroon(
|
|
||||||
location=self.hs.config.server_name,
|
|
||||||
identifier="key",
|
|
||||||
key=self.hs.config.macaroon_secret_key)
|
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
|
||||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
|
||||||
return macaroon
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_password(self, user_id, newpassword, requester=None):
|
def set_password(self, user_id, newpassword, requester=None):
|
||||||
password_hash = self.hash(newpassword)
|
password_hash = self.hash(newpassword)
|
||||||
|
@ -673,6 +639,48 @@ class AuthHandler(BaseHandler):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class MacaroonGeneartor(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.server_name = hs.config.server_name
|
||||||
|
self.macaroon_secret_key = hs.config.macaroon_secret_key
|
||||||
|
|
||||||
|
def generate_access_token(self, user_id, extra_caveats=None):
|
||||||
|
extra_caveats = extra_caveats or []
|
||||||
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
# Include a nonce, to make sure that each login gets a different
|
||||||
|
# access token.
|
||||||
|
macaroon.add_first_party_caveat("nonce = %s" % (
|
||||||
|
stringutils.random_string_with_symbols(16),
|
||||||
|
))
|
||||||
|
for caveat in extra_caveats:
|
||||||
|
macaroon.add_first_party_caveat(caveat)
|
||||||
|
return macaroon.serialize()
|
||||||
|
|
||||||
|
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||||
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
|
macaroon.add_first_party_caveat("type = login")
|
||||||
|
now = self.clock.time_msec()
|
||||||
|
expiry = now + duration_in_ms
|
||||||
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
return macaroon.serialize()
|
||||||
|
|
||||||
|
def generate_delete_pusher_token(self, user_id):
|
||||||
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
|
macaroon.add_first_party_caveat("type = delete_pusher")
|
||||||
|
return macaroon.serialize()
|
||||||
|
|
||||||
|
def _generate_base_macaroon(self, user_id):
|
||||||
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
location=self.server_name,
|
||||||
|
identifier="key",
|
||||||
|
key=self.macaroon_secret_key)
|
||||||
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
|
return macaroon
|
||||||
|
|
||||||
|
|
||||||
class _AccountHandler(object):
|
class _AccountHandler(object):
|
||||||
"""A proxy object that gets passed to password auth providers so they
|
"""A proxy object that gets passed to password auth providers so they
|
||||||
can register new users etc if necessary.
|
can register new users etc if necessary.
|
||||||
|
|
|
@ -17,7 +17,7 @@ from synapse.api import errors
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id, RoomStreamToken
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -198,20 +198,22 @@ class DeviceHandler(BaseHandler):
|
||||||
"""Notify that a user's device(s) has changed. Pokes the notifier, and
|
"""Notify that a user's device(s) has changed. Pokes the notifier, and
|
||||||
remote servers if the user is local.
|
remote servers if the user is local.
|
||||||
"""
|
"""
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||||
room_ids = [r.room_id for r in rooms]
|
user_id
|
||||||
|
)
|
||||||
|
|
||||||
hosts = set()
|
hosts = set()
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
for room_id in room_ids:
|
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
|
||||||
hosts.update(get_domain_from_id(u) for u in users)
|
|
||||||
hosts.discard(self.server_name)
|
hosts.discard(self.server_name)
|
||||||
|
|
||||||
position = yield self.store.add_device_change_to_streams(
|
position = yield self.store.add_device_change_to_streams(
|
||||||
user_id, device_ids, list(hosts)
|
user_id, device_ids, list(hosts)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||||
|
room_ids = [r.room_id for r in rooms]
|
||||||
|
|
||||||
yield self.notifier.on_new_event(
|
yield self.notifier.on_new_event(
|
||||||
"device_list_key", position, rooms=room_ids,
|
"device_list_key", position, rooms=room_ids,
|
||||||
)
|
)
|
||||||
|
@ -243,15 +245,15 @@ class DeviceHandler(BaseHandler):
|
||||||
|
|
||||||
possibly_changed = set(changed)
|
possibly_changed = set(changed)
|
||||||
for room_id in rooms_changed:
|
for room_id in rooms_changed:
|
||||||
# Fetch (an approximation) of the current state at the time.
|
# Fetch the current state at the time.
|
||||||
event_rows, token = yield self.store.get_recent_event_ids_for_room(
|
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key)
|
||||||
room_id, end_token=from_token.room_key, limit=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if event_rows:
|
try:
|
||||||
last_event_id = event_rows[-1]["event_id"]
|
event_ids = yield self.store.get_forward_extremeties_for_room(
|
||||||
prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id)
|
room_id, stream_ordering=stream_ordering
|
||||||
else:
|
)
|
||||||
|
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
|
||||||
|
except:
|
||||||
prev_state_ids = {}
|
prev_state_ids = {}
|
||||||
|
|
||||||
current_state_ids = yield self.state.get_current_state_ids(room_id)
|
current_state_ids = yield self.state.get_current_state_ids(room_id)
|
||||||
|
@ -266,13 +268,13 @@ class DeviceHandler(BaseHandler):
|
||||||
if not prev_event_id or prev_event_id != event_id:
|
if not prev_event_id or prev_event_id != event_id:
|
||||||
possibly_changed.add(state_key)
|
possibly_changed.add(state_key)
|
||||||
|
|
||||||
user_ids_changed = set()
|
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||||
for other_user_id in possibly_changed:
|
user_id
|
||||||
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
|
)
|
||||||
if room_ids.intersection(e.room_id for e in other_rooms):
|
|
||||||
user_ids_changed.add(other_user_id)
|
|
||||||
|
|
||||||
defer.returnValue(user_ids_changed)
|
# Take the intersection of the users whose devices may have changed
|
||||||
|
# and those that actually still share a room with the user
|
||||||
|
defer.returnValue(users_who_share_room & possibly_changed)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _incoming_device_list_update(self, origin, edu_content):
|
def _incoming_device_list_update(self, origin, edu_content):
|
||||||
|
|
|
@ -1011,7 +1011,7 @@ class PresenceEventSource(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
|
def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
|
||||||
**kwargs):
|
explicit_room_id=None, **kwargs):
|
||||||
# The process for getting presence events are:
|
# The process for getting presence events are:
|
||||||
# 1. Get the rooms the user is in.
|
# 1. Get the rooms the user is in.
|
||||||
# 2. Get the list of user in the rooms.
|
# 2. Get the list of user in the rooms.
|
||||||
|
@ -1028,22 +1028,24 @@ class PresenceEventSource(object):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
if from_key is not None:
|
if from_key is not None:
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
room_ids = room_ids or []
|
|
||||||
|
|
||||||
presence = self.get_presence_handler()
|
presence = self.get_presence_handler()
|
||||||
stream_change_cache = self.store.presence_stream_cache
|
stream_change_cache = self.store.presence_stream_cache
|
||||||
|
|
||||||
if not room_ids:
|
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
|
||||||
room_ids = set(e.room_id for e in rooms)
|
|
||||||
else:
|
|
||||||
room_ids = set(room_ids)
|
|
||||||
|
|
||||||
max_token = self.store.get_current_presence_token()
|
max_token = self.store.get_current_presence_token()
|
||||||
|
|
||||||
plist = yield self.store.get_presence_list_accepted(user.localpart)
|
plist = yield self.store.get_presence_list_accepted(user.localpart)
|
||||||
friends = set(row["observed_user_id"] for row in plist)
|
users_interested_in = set(row["observed_user_id"] for row in plist)
|
||||||
friends.add(user_id) # So that we receive our own presence
|
users_interested_in.add(user_id) # So that we receive our own presence
|
||||||
|
|
||||||
|
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
users_interested_in.update(users_who_share_room)
|
||||||
|
|
||||||
|
if explicit_room_id:
|
||||||
|
user_ids = yield self.store.get_users_in_room(explicit_room_id)
|
||||||
|
users_interested_in.update(user_ids)
|
||||||
|
|
||||||
user_ids_changed = set()
|
user_ids_changed = set()
|
||||||
changed = None
|
changed = None
|
||||||
|
@ -1055,35 +1057,19 @@ class PresenceEventSource(object):
|
||||||
# work out if we share a room or they're in our presence list
|
# work out if we share a room or they're in our presence list
|
||||||
get_updates_counter.inc("stream")
|
get_updates_counter.inc("stream")
|
||||||
for other_user_id in changed:
|
for other_user_id in changed:
|
||||||
if other_user_id in friends:
|
if other_user_id in users_interested_in:
|
||||||
user_ids_changed.add(other_user_id)
|
user_ids_changed.add(other_user_id)
|
||||||
continue
|
|
||||||
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
|
|
||||||
if room_ids.intersection(e.room_id for e in other_rooms):
|
|
||||||
user_ids_changed.add(other_user_id)
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
# Too many possible updates. Find all users we can see and check
|
# Too many possible updates. Find all users we can see and check
|
||||||
# if any of them have changed.
|
# if any of them have changed.
|
||||||
get_updates_counter.inc("full")
|
get_updates_counter.inc("full")
|
||||||
|
|
||||||
user_ids_to_check = set()
|
|
||||||
for room_id in room_ids:
|
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
|
||||||
user_ids_to_check.update(users)
|
|
||||||
|
|
||||||
user_ids_to_check.update(friends)
|
|
||||||
|
|
||||||
# Always include yourself. Only really matters for when the user is
|
|
||||||
# not in any rooms, but still.
|
|
||||||
user_ids_to_check.add(user_id)
|
|
||||||
|
|
||||||
if from_key:
|
if from_key:
|
||||||
user_ids_changed = stream_change_cache.get_entities_changed(
|
user_ids_changed = stream_change_cache.get_entities_changed(
|
||||||
user_ids_to_check, from_key,
|
users_interested_in, from_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
user_ids_changed = user_ids_to_check
|
user_ids_changed = users_interested_in
|
||||||
|
|
||||||
updates = yield presence.current_state_for_users(user_ids_changed)
|
updates = yield presence.current_state_for_users(user_ids_changed)
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
self._next_generated_user_id = None
|
self._next_generated_user_id = None
|
||||||
|
|
||||||
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_username(self, localpart, guest_access_token=None,
|
def check_username(self, localpart, guest_access_token=None,
|
||||||
assigned_user_id=None):
|
assigned_user_id=None):
|
||||||
|
@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
token = None
|
token = None
|
||||||
if generate_token:
|
if generate_token:
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
|
@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_id_not_appservice_exclusive(user_id)
|
yield self.check_user_id_not_appservice_exclusive(user_id)
|
||||||
if generate_token:
|
if generate_token:
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
try:
|
try:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
yield self.check_user_id_not_appservice_exclusive(user_id)
|
yield self.check_user_id_not_appservice_exclusive(user_id)
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
try:
|
try:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
|
|
||||||
if need_register:
|
if need_register:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
|
|
|
@ -437,6 +437,7 @@ class RoomEventSource(object):
|
||||||
limit,
|
limit,
|
||||||
room_ids,
|
room_ids,
|
||||||
is_guest,
|
is_guest,
|
||||||
|
explicit_room_id=None,
|
||||||
):
|
):
|
||||||
# We just ignore the key for now.
|
# We just ignore the key for now.
|
||||||
|
|
||||||
|
|
|
@ -378,6 +378,7 @@ class Notifier(object):
|
||||||
limit=limit,
|
limit=limit,
|
||||||
is_guest=is_peeking,
|
is_guest=is_peeking,
|
||||||
room_ids=room_ids,
|
room_ids=room_ids,
|
||||||
|
explicit_room_id=explicit_room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if name == "room":
|
if name == "room":
|
||||||
|
|
|
@ -81,7 +81,7 @@ class Mailer(object):
|
||||||
def __init__(self, hs, app_name):
|
def __init__(self, hs, app_name):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.auth_handler = self.hs.get_auth_handler()
|
self.macaroon_gen = self.hs.get_macaroon_generator()
|
||||||
self.state_handler = self.hs.get_state_handler()
|
self.state_handler = self.hs.get_state_handler()
|
||||||
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
|
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
|
||||||
self.app_name = app_name
|
self.app_name = app_name
|
||||||
|
@ -466,7 +466,7 @@ class Mailer(object):
|
||||||
|
|
||||||
def make_unsubscribe_link(self, user_id, app_id, email_address):
|
def make_unsubscribe_link(self, user_id, app_id, email_address):
|
||||||
params = {
|
params = {
|
||||||
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
|
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"pushkey": email_address,
|
"pushkey": email_address,
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,6 +73,9 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
# to reach inside the __dict__ to extract them.
|
# to reach inside the __dict__ to extract them.
|
||||||
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
||||||
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
||||||
|
get_users_who_share_room_with_user = (
|
||||||
|
RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
|
||||||
|
)
|
||||||
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
||||||
"get_latest_event_ids_in_room"
|
"get_latest_event_ids_in_room"
|
||||||
]
|
]
|
||||||
|
|
|
@ -330,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
|
@ -368,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
yield self.handlers.registration_handler.register(localpart=user)
|
||||||
)
|
)
|
||||||
|
|
||||||
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
|
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||||
|
registered_user_id
|
||||||
|
)
|
||||||
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
||||||
login_token)
|
login_token)
|
||||||
request.redirect(redirect_url)
|
request.redirect(redirect_url)
|
||||||
|
|
|
@ -193,7 +193,7 @@ class KeyChangesServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"changed": changed
|
"changed": list(changed),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = self.auth_handler.generate_access_token(
|
access_token = self.macaroon_gen.generate_access_token(
|
||||||
user_id, ["guest = true"]
|
user_id, ["guest = true"]
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
|
|
|
@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient
|
||||||
from synapse.federation.transaction_queue import TransactionQueue
|
from synapse.federation.transaction_queue import TransactionQueue
|
||||||
from synapse.handlers import Handlers
|
from synapse.handlers import Handlers
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler
|
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
|
||||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||||
from synapse.handlers.device import DeviceHandler
|
from synapse.handlers.device import DeviceHandler
|
||||||
from synapse.handlers.e2e_keys import E2eKeysHandler
|
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||||
|
@ -131,6 +131,7 @@ class HomeServer(object):
|
||||||
'federation_transport_client',
|
'federation_transport_client',
|
||||||
'federation_sender',
|
'federation_sender',
|
||||||
'receipts_handler',
|
'receipts_handler',
|
||||||
|
'macaroon_generator',
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, hostname, **kwargs):
|
def __init__(self, hostname, **kwargs):
|
||||||
|
@ -213,6 +214,9 @@ class HomeServer(object):
|
||||||
def build_auth_handler(self):
|
def build_auth_handler(self):
|
||||||
return AuthHandler(self)
|
return AuthHandler(self)
|
||||||
|
|
||||||
|
def build_macaroon_generator(self):
|
||||||
|
return MacaroonGeneartor(self)
|
||||||
|
|
||||||
def build_device_handler(self):
|
def build_device_handler(self):
|
||||||
return DeviceHandler(self)
|
return DeviceHandler(self)
|
||||||
|
|
||||||
|
|
|
@ -280,6 +280,23 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
user_id, membership_list=[Membership.JOIN],
|
user_id, membership_list=[Membership.JOIN],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True)
|
||||||
|
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
||||||
|
"""Returns the set of users who share a room with `user_id`
|
||||||
|
"""
|
||||||
|
rooms = yield self.get_rooms_for_user(
|
||||||
|
user_id, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_who_share_room = set()
|
||||||
|
for room in rooms:
|
||||||
|
user_ids = yield self.get_users_in_room(
|
||||||
|
room.room_id, on_invalidate=cache_context.invalidate,
|
||||||
|
)
|
||||||
|
user_who_share_room.update(user_ids)
|
||||||
|
|
||||||
|
defer.returnValue(user_who_share_room)
|
||||||
|
|
||||||
def forget(self, user_id, room_id):
|
def forget(self, user_id, room_id):
|
||||||
"""Indicate that user_id wishes to discard history for room_id."""
|
"""Indicate that user_id wishes to discard history for room_id."""
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
|
|
@ -478,6 +478,11 @@ class CacheListDescriptor(object):
|
||||||
|
|
||||||
|
|
||||||
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
||||||
|
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
|
||||||
|
# which namedtuple does for us (i.e. two _CacheContext are the same if
|
||||||
|
# their caches and keys match). This is important in particular to
|
||||||
|
# dedupe when we add callbacks to lru cache nodes, otherwise the number
|
||||||
|
# of callbacks would grow.
|
||||||
def invalidate(self):
|
def invalidate(self):
|
||||||
self.cache.invalidate(self.key)
|
self.cache.invalidate(self.key)
|
||||||
|
|
||||||
|
|
|
@ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.hs = yield setup_test_homeserver(handlers=None)
|
self.hs = yield setup_test_homeserver(handlers=None)
|
||||||
self.hs.handlers = AuthHandlers(self.hs)
|
self.hs.handlers = AuthHandlers(self.hs)
|
||||||
self.auth_handler = self.hs.handlers.auth_handler
|
self.auth_handler = self.hs.handlers.auth_handler
|
||||||
|
self.macaroon_generator = self.hs.get_macaroon_generator()
|
||||||
|
|
||||||
def test_token_is_a_macaroon(self):
|
def test_token_is_a_macaroon(self):
|
||||||
self.hs.config.macaroon_secret_key = "this key is a huge secret"
|
token = self.macaroon_generator.generate_access_token("some_user")
|
||||||
|
|
||||||
token = self.auth_handler.generate_access_token("some_user")
|
|
||||||
# Check that we can parse the thing with pymacaroons
|
# Check that we can parse the thing with pymacaroons
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
# The most basic of sanity checks
|
# The most basic of sanity checks
|
||||||
|
@ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase):
|
||||||
self.fail("some_user was not in %s" % macaroon.inspect())
|
self.fail("some_user was not in %s" % macaroon.inspect())
|
||||||
|
|
||||||
def test_macaroon_caveats(self):
|
def test_macaroon_caveats(self):
|
||||||
self.hs.config.macaroon_secret_key = "this key is a massive secret"
|
|
||||||
self.hs.clock.now = 5000
|
self.hs.clock.now = 5000
|
||||||
|
|
||||||
token = self.auth_handler.generate_access_token("a_user")
|
token = self.macaroon_generator.generate_access_token("a_user")
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
||||||
def verify_gen(caveat):
|
def verify_gen(caveat):
|
||||||
|
@ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
def test_short_term_login_token_gives_user_id(self):
|
def test_short_term_login_token_gives_user_id(self):
|
||||||
self.hs.clock.now = 1000
|
self.hs.clock.now = 1000
|
||||||
|
|
||||||
token = self.auth_handler.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
"a_user", 5000
|
"a_user", 5000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||||
token = self.auth_handler.generate_short_term_login_token(
|
token = self.macaroon_generator.generate_short_term_login_token(
|
||||||
"a_user", 5000
|
"a_user", 5000
|
||||||
)
|
)
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||||
|
|
|
@ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase):
|
||||||
handlers=None,
|
handlers=None,
|
||||||
http_client=None,
|
http_client=None,
|
||||||
expire_access_token=True)
|
expire_access_token=True)
|
||||||
self.auth_handler = Mock(
|
self.macaroon_generator = Mock(
|
||||||
generate_access_token=Mock(return_value='secret'))
|
generate_access_token=Mock(return_value='secret'))
|
||||||
|
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||||
self.handler = self.hs.get_handlers().registration_handler
|
self.handler = self.hs.get_handlers().registration_handler
|
||||||
self.hs.get_handlers().profile_handler = Mock()
|
self.hs.get_handlers().profile_handler = Mock()
|
||||||
self.mock_handler = Mock(spec=[
|
|
||||||
"generate_access_token",
|
|
||||||
])
|
|
||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
|
|
Loading…
Reference in a new issue