0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 13:13:50 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into develop

This commit is contained in:
Erik Johnston 2015-11-05 15:35:17 +00:00
commit 347146be29
55 changed files with 571 additions and 489 deletions

View file

@ -49,6 +49,7 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ", "gen = ",
"guest = ",
"type = ", "type = ",
"time < ", "time < ",
"user_id = ", "user_id = ",
@ -183,15 +184,11 @@ class Auth(object):
defer.returnValue(member) defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id, current_state=None): def check_user_was_in_room(self, room_id, user_id):
"""Check if the user was in the room at some point. """Check if the user was in the room at some point.
Args: Args:
room_id(str): The room to check. room_id(str): The room to check.
user_id(str): The user to check. user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises: Raises:
AuthError if the user was never in the room. AuthError if the user was never in the room.
Returns: Returns:
@ -199,12 +196,6 @@ class Auth(object):
room. This will be the join event if they are currently joined to room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room. the room. This will be the leave event if they have left the room.
""" """
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state( member = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
event_type=EventTypes.Member, event_type=EventTypes.Member,
@ -497,7 +488,7 @@ class Auth(object):
return default return default
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request): def get_user_by_req(self, request, allow_guest=False):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -535,7 +526,7 @@ class Auth(object):
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), "")) defer.returnValue((UserID.from_string(user_id), "", False))
return return
except KeyError: except KeyError:
pass # normal users won't have the user_id query parameter set. pass # normal users won't have the user_id query parameter set.
@ -543,6 +534,7 @@ class Auth(object):
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
@ -557,9 +549,14 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
if is_guest and not allow_guest:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id,)) defer.returnValue((user, token_id, is_guest,))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@ -592,9 +589,27 @@ class Auth(object):
self._validate_macaroon(macaroon) self._validate_macaroon(macaroon)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None
guest = False
for caveat in macaroon.caveats: for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix): if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
elif caveat.caveat_id == "guest = true":
guest = True
if user is None:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
if guest:
ret = {
"user": user,
"is_guest": True,
"token_id": None,
}
else:
# This codepath exists so that we can actually return a # This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device # token ID, because we use token IDs in place of device
# identifiers throughout the codebase. # identifiers throughout the codebase.
@ -613,10 +628,6 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
defer.returnValue(ret) defer.returnValue(ret)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
@ -629,6 +640,7 @@ class Auth(object):
v.satisfy_exact("type = access") v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry) v.satisfy_general(self._verify_expiry)
v.satisfy_exact("guest = true")
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
@ -666,6 +678,7 @@ class Auth(object):
user_info = { user_info = {
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None), "token_id": ret.get("token_id", None),
"is_guest": False,
} }
defer.returnValue(user_info) defer.returnValue(user_info)

View file

@ -33,6 +33,7 @@ class Codes(object):
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN" MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"

View file

@ -50,11 +50,11 @@ class Filtering(object):
# many definitions. # many definitions.
top_level_definitions = [ top_level_definitions = [
"public_user_data", "private_user_data", "server_data" "presence"
] ]
room_level_definitions = [ room_level_definitions = [
"state", "timeline", "ephemeral" "state", "timeline", "ephemeral", "private_user_data"
] ]
for key in top_level_definitions: for key in top_level_definitions:
@ -114,22 +114,6 @@ class Filtering(object):
if not isinstance(event_type, basestring): if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string") raise SynapseError(400, "Event type should be a string")
if "format" in definition:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
if "select" in definition:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
class FilterCollection(object): class FilterCollection(object):
def __init__(self, filter_json): def __init__(self, filter_json):

View file

@ -34,6 +34,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key") self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.
bcrypt_rounds: 12 bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):

View file

@ -47,37 +47,24 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events): def _filter_events_for_client(self, user_id, events, is_guest=False,
event_id_to_state = yield self.store.get_state_for_events( require_all_visible_for_guests=True):
frozenset(e.event_id for e in events), # Assumes that user has at some point joined the room if not is_guest.
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state): def allowed(event, membership, visibility):
if event.type == EventTypes.RoomHistoryVisibility: if visibility == "world_readable":
return True return True
membership_ev = state.get((EventTypes.Member, user_id), None) if is_guest:
if membership_ev: return False
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None) if event.type == EventTypes.RoomHistoryVisibility:
if history: return not is_guest
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public": if visibility == "shared":
return True
elif visibility == "shared":
return True return True
elif visibility == "joined": elif visibility == "joined":
return membership == Membership.JOIN return membership == Membership.JOIN
@ -86,11 +73,46 @@ class BaseHandler(object):
return True return True
defer.returnValue([ event_id_to_state = yield self.store.get_state_for_events(
event frozenset(e.event_id for e in events),
for event in events types=(
if allowed(event, event_id_to_state[event.event_id]) (EventTypes.RoomHistoryVisibility, ""),
]) (EventTypes.Member, user_id),
)
)
events_to_return = []
for event in events:
state = event_id_to_state[event.event_id]
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
if (require_all_visible_for_guests
and is_guest
and len(events_to_return) < len(events)):
# This indicates that some events in the requested range were not
# visible to guest users. To be safe, we reject the entire request,
# so that we don't have to worry about interpreting visibility
# boundaries.
raise AuthError(403, "User %s does not have permission" % (
user_id
))
defer.returnValue(events_to_return)
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()

View file

@ -372,12 +372,15 @@ class AuthHandler(BaseHandler):
yield self.store.add_refresh_token_to_user(user_id, refresh_token) yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token) defer.returnValue(refresh_token)
def generate_access_token(self, user_id): def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000) expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize() return macaroon.serialize()
def generate_refresh_token(self, user_id): def generate_refresh_token(self, user_id):

View file

@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler):
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True, as_client_event=True, affect_presence=True,
only_room_events=False): only_room_events=False, room_id=None, is_guest=False):
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned. If `only_room_events` is `True` only room events will be returned.
@ -111,17 +111,6 @@ class EventStreamHandler(BaseHandler):
if affect_presence: if affect_presence:
yield self.started_stream(auth_user) yield self.started_stream(auth_user)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
auth_user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout: if timeout:
# If they've set a timeout set a minimum limit. # If they've set a timeout set a minimum limit.
timeout = max(timeout, 500) timeout = max(timeout, 500)
@ -130,9 +119,15 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
if is_guest:
yield self.distributor.fire(
"user_joined_room", user=auth_user, room_id=room_id
)
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout, auth_user, pagin_config, timeout,
only_room_events=only_room_events only_room_events=only_room_events,
is_guest=is_guest, guest_room_id=room_id
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -72,8 +72,6 @@ class FederationHandler(BaseHandler):
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.lock_manager = hs.get_room_lock_manager()
self.replication_layer.set_handler(self) self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, AuthError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -71,7 +71,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
as_client_event=True): as_client_event=True, is_guest=False):
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -80,11 +80,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
is_guest (bool): Whether the requesting user is a guest (as opposed
to a fully registered user).
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"] data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token: if pagin_config.from_token:
@ -107,9 +107,13 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
if not is_guest:
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.LEAVE: if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room # they left the room.
# If they're a guest, we'll just 403 them if they're asking for
# events they can't see.
leave_token = yield self.store.get_topological_token_for_event( leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id member_event.event_id
) )
@ -146,7 +150,7 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
events = yield self._filter_events_for_client(user_id, events) events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -225,7 +229,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key=""): event_type=None, state_key="", is_guest=False):
""" Get data from a room. """ Get data from a room.
Args: Args:
@ -235,23 +239,42 @@ class MessageHandler(BaseHandler):
Raises: Raises:
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
data = yield self.state_handler.get_current_state( data = yield self.state_handler.get_current_state(
room_id, event_type, state_key room_id, event_type, state_key
) )
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], [key] [membership_event_id], [key]
) )
data = room_state[member_event.event_id].get(key) data = room_state[membership_event_id].get(key)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
if is_guest:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if visibility.content["history_visibility"] == "world_readable":
defer.returnValue((Membership.JOIN, None))
return
else:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
else:
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False):
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has joined to the room then return the current state. If the user has
left the room return the state events from when they left. left the room return the state events from when they left.
@ -262,15 +285,17 @@ class MessageHandler(BaseHandler):
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
if member_event.membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[member_event.event_id], None
) )
room_state = room_state[member_event.event_id]
if membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[membership_event_id], None
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
defer.returnValue( defer.returnValue(

View file

@ -1142,8 +1142,9 @@ class PresenceEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, user, from_key, room_ids=None, **kwargs):
from_key = int(from_key) from_key = int(from_key)
room_ids = room_ids or []
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
@ -1161,7 +1162,6 @@ class PresenceEventSource(object):
user_ids_to_check |= set( user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list UserID.from_string(p["observed_user_id"]) for p in presence_list
) )
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials): for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key: if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id) joined = yield presence.get_joined_users_for_room_id(room_id)

View file

@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object):
return self.store.get_max_private_user_data_stream_id() return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key

View file

@ -164,17 +164,15 @@ class ReceiptEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key) from_key = int(from_key)
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
if from_key == to_key: if from_key == to_key:
defer.returnValue(([], to_key)) defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms( events = yield self.store.get_linearized_receipts_for_rooms(
rooms, room_ids,
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
) )

View file

@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart=None, password=None): def register(self, localpart=None, password=None, generate_token=True):
"""Registers a new client on the server. """Registers a new client on the server.
Args: Args:
@ -89,6 +89,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = None
if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -102,13 +104,13 @@ class RegistrationHandler(BaseHandler):
attempts = 0 attempts = 0
user_id = None user_id = None
token = None token = None
while not user_id and not token: while not user_id:
try: try:
localpart = self._generate_user_id() localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,

View file

@ -807,7 +807,14 @@ class RoomEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
limit,
room_ids,
is_guest,
):
# We just ignore the key for now. # We just ignore the key for now.
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
@ -827,8 +834,9 @@ class RoomEventSource(object):
user_id=user.to_string(), user_id=user.to_string(),
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
room_id=None,
limit=limit, limit=limit,
room_ids=room_ids,
is_guest=is_guest,
) )
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))

View file

@ -143,21 +143,8 @@ class SyncHandler(BaseHandler):
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user, room_ids, timeout, current_sync_callback, sync_config.user, timeout, current_sync_callback,
from_token=since_token from_token=since_token
) )
defer.returnValue(result) defer.returnValue(result)
@ -308,11 +295,16 @@ class SyncHandler(BaseHandler):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
is_guest=False,
) )
now_token = now_token.copy_and_replace("typing_key", typing_key) now_token = now_token.copy_and_replace("typing_key", typing_key)
@ -325,10 +317,13 @@ class SyncHandler(BaseHandler):
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"] receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events_for_user( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
) )
now_token = now_token.copy_and_replace("receipt_key", receipt_key) now_token = now_token.copy_and_replace("receipt_key", receipt_key)
@ -373,11 +368,17 @@ class SyncHandler(BaseHandler):
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events_for_user( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter.presence_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)
@ -403,7 +404,6 @@ class SyncHandler(BaseHandler):
sync_config.user.to_string(), sync_config.user.to_string(),
from_key=since_token.room_key, from_key=since_token.room_key,
to_key=now_token.room_key, to_key=now_token.room_key,
room_id=None,
limit=timeline_limit + 1, limit=timeline_limit + 1,
) )

View file

@ -246,17 +246,12 @@ class TypingNotificationEventSource(object):
}, },
} }
@defer.inlineCallbacks def get_new_events(self, from_key, room_ids, **kwargs):
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.handler()
joined_room_ids = (
yield self.room_member_handler().get_joined_rooms_for_user(user)
)
events = [] events = []
for room_id in joined_room_ids: for room_id in room_ids:
if room_id not in handler._room_serials: if room_id not in handler._room_serials:
continue continue
if handler._room_serials[room_id] <= from_key: if handler._room_serials[room_id] <= from_key:
@ -264,7 +259,7 @@ class TypingNotificationEventSource(object):
events.append(self._make_event_for(room_id)) events.append(self._make_event_for(room_id))
defer.returnValue((events, handler._latest_room_serial)) return events, handler._latest_room_serial
def get_current_key(self): def get_current_key(self):
return self.handler()._latest_room_serial return self.handler()._latest_room_serial

View file

@ -269,7 +269,7 @@ class Notifier(object):
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback, def wait_for_events(self, user, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
@ -279,11 +279,12 @@ class Notifier(object):
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user) appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user) rooms = yield self.store.get_rooms_for_user(user)
rooms = [room.room_id for room in rooms] room_ids = [room.room_id for room in rooms]
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user=user, user=user,
rooms=rooms, rooms=room_ids,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
@ -328,8 +329,9 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_for(self, user, rooms, pagination_config, timeout, def get_events_for(self, user, pagination_config, timeout,
only_room_events=False): only_room_events=False,
is_guest=False, guest_room_id=None):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
@ -342,6 +344,16 @@ class Notifier(object):
limit = pagination_config.limit limit = pagination_config.limit
room_ids = []
if is_guest:
# TODO(daniel): Deal with non-room events too
only_room_events = True
if guest_room_id:
room_ids = [guest_room_id]
else:
rooms = yield self.store.get_rooms_for_user(user.to_string())
room_ids = [room.room_id for room in rooms]
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
@ -357,9 +369,23 @@ class Notifier(object):
continue continue
if only_room_events and name != "room": if only_room_events and name != "room":
continue continue
new_events, new_key = yield source.get_new_events_for_user( new_events, new_key = yield source.get_new_events(
user, getattr(from_token, keyname), limit, user=user,
from_key=getattr(from_token, keyname),
limit=limit,
is_guest=is_guest,
room_ids=room_ids,
) )
if is_guest:
room_member_handler = self.hs.get_handlers().room_member_handler
new_events = yield room_member_handler._filter_events_for_client(
user.to_string(),
new_events,
is_guest=is_guest,
require_all_visible_for_guests=False
)
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
@ -369,7 +395,7 @@ class Notifier(object):
defer.returnValue(None) defer.returnValue(None)
result = yield self.wait_for_events( result = yield self.wait_for_events(
user, rooms, timeout, check_for_updates, from_token=from_token user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
) )
if result is None: if result is None:

View file

@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user) is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:

View file

@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
try: try:
# try to auth as a user # try to auth as a user
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
try: try:
user_id = user.to_string() user_id = user.to_string()
yield dir_handler.create_association( yield dir_handler.create_association(
@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS # fallback to default user behaviour if they aren't an AS
pass pass
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:

View file

@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
room_id = None
if is_guest:
if "room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
room_id = request.args["room_id"][0]
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet):
chunk = yield handler.get_stream( chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout, auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event as_client_event=as_client_event, affect_presence=(not is_guest),
room_id=room_id, is_guest=is_guest
) )
except: except:
logger.exception("Event stream failed") logger.exception("Event stream failed")
@ -71,7 +80,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id) event = yield handler.get_event(auth_user, event_id)

View file

@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler

View file

@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} state = {}
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):

View file

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:

View file

@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']: if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request): def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is

View file

@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)

View file

@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request) room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None) info = yield self.make_room(room_config, auth_user, None)
@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
event_type=event_type, event_type=event_type,
state_key=state_key, state_key=state_key,
is_guest=is_guest,
) )
if not data: if not data:
@ -143,7 +144,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -175,7 +176,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -220,7 +221,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the # the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid # other if it fails to parse, without swallowing other valid
@ -289,7 +290,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
@ -325,7 +326,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
@ -334,6 +335,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
is_guest=is_guest,
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event
) )
@ -347,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
is_guest=is_guest,
) )
defer.returnValue((200, events)) defer.returnValue((200, events))
@ -363,7 +366,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
@ -443,7 +446,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None): def on_POST(self, request, room_id, membership_action, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -524,7 +527,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -564,7 +567,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))
@ -597,7 +600,7 @@ class SearchRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)

View file

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret turnSecret = self.hs.config.turn_shared_secret

View file

@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet):
if LoginType.PASSWORD in result: if LoginType.PASSWORD in result:
# if using password, they should also be logged in # if using password, they should also be logged in
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]: if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN) raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string() user_id = auth_user.to_string()
@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
yield run_on_reactor() yield run_on_reactor()
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids( threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string() auth_user.to_string()
@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds'] threePidCreds = body['threePidCreds']
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)

View file

@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id): def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != auth_user:
raise AuthError(403, "Cannot get filters for other users") raise AuthError(403, "Cannot get filters for other users")
@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != auth_user:
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")

View file

@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = auth_user.to_string()
# TODO: Check that the device_id matches that in the authentication # TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead. # or derive the device_id from the authentication instead.
@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = auth_user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string() auth_user_id = auth_user.to_string()
user_id = user_id if user_id else auth_user_id user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else [] device_ids = [device_id] if device_id else []

View file

@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_pattern, parse_json_dict_from_request
@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
kind = "user"
if "kind" in request.args:
kind = request.args["kind"][0]
if kind == "guest":
ret = yield self._do_guest_registration()
defer.returnValue(ret)
return
elif kind != "user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
)
if '/register/email/requestToken' in request.path: if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request) ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret) defer.returnValue(ret)
@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet):
ret = yield self.identity_handler.requestEmailToken(**body) ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret)) defer.returnValue((200, ret))
@defer.inlineCallbacks
def _do_guest_registration(self):
if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled"))
user_id, _ = yield self.registration_handler.register(generate_token=False)
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
defer.returnValue((200, {
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View file

@ -81,7 +81,7 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", default=0) timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since") since = parse_string(request, "since")

View file

@ -42,7 +42,7 @@ class TagListServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, room_id): def on_GET(self, request, user_id, room_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != auth_user.to_string():
raise AuthError(403, "Cannot get tags for other users.") raise AuthError(403, "Cannot get tags for other users.")
@ -68,7 +68,7 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag): def on_PUT(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
@ -88,7 +88,7 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag): def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")

View file

@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def map_request_to_name(self, request): def map_request_to_name(self, request):
# auth the user # auth the user
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user # namespace all file uploads on the user
prefix = base64.urlsafe_b64encode( prefix = base64.urlsafe_b64encode(

View file

@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader("Content-Length")

View file

@ -29,7 +29,6 @@ from synapse.state import StateHandler
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util import Clock from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
@ -70,7 +69,6 @@ class BaseHomeServer(object):
'auth', 'auth',
'rest_servlet_factory', 'rest_servlet_factory',
'state_handler', 'state_handler',
'room_lock_manager',
'notifier', 'notifier',
'distributor', 'distributor',
'resource_for_client', 'resource_for_client',
@ -201,9 +199,6 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)
def build_room_lock_manager(self):
return LockManager()
def build_distributor(self): def build_distributor(self):
return Distributor() return Distributor()

View file

@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore):
self._store_room_message_txn(txn, event) self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
self._store_room_members_txn( self._store_room_members_txn(
txn, txn,

View file

@ -102,6 +102,7 @@ class RegistrationStore(SQLBaseStore):
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
if token:
# it's possible for this to get a conflict, but only for a single user # it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID # since tokens are namespaced based on their user ID
txn.execute( txn.execute(

View file

@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore):
txn, event, "content.body", event.content["body"] txn, event, "content.body", event.content["body"]
) )
def _store_history_visibility_txn(self, txn, event):
if hasattr(event, "content") and "history_visibility" in event.content:
sql = (
"INSERT INTO history_visibility"
" (event_id, room_id, history_visibility)"
" VALUES (?, ?, ?)"
)
txn.execute(sql, (
event.event_id,
event.room_id,
event.content["history_visibility"]
))
def _store_event_search_txn(self, txn, event, key, value): def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (

View file

@ -0,0 +1,26 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of history_visibility content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS history_visibility(
id INTEGER PRIMARY KEY,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
history_visibility TEXT NOT NULL,
UNIQUE (event_id)
);

View file

@ -158,14 +158,40 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id, def get_room_events_stream(
limit=0): self,
user_id,
from_key,
to_key,
limit=0,
is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = ( current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m " "SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c" " INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id" " ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'" " WHERE m.user_id = ? AND m.membership = 'join'"
) )
current_room_membership_args = [user_id]
if room_ids:
current_room_membership_sql += " AND m.room_id in (%s)" % (
",".join(map(lambda _: "?", room_ids))
)
current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g. # We also want to get any membership events about that user, e.g.
# invites or leave notifications. # invites or leave notifications.
@ -174,6 +200,7 @@ class StreamStore(SQLBaseStore):
"INNER JOIN current_state_events as c ON m.event_id = c.event_id " "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? " "WHERE m.user_id = ? "
) )
membership_args = [user_id]
if limit: if limit:
limit = max(limit, MAX_STREAM_SIZE) limit = max(limit, MAX_STREAM_SIZE)
@ -200,7 +227,9 @@ class StreamStore(SQLBaseStore):
} }
def f(txn): def f(txn):
txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)

View file

@ -1,74 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class Lock(object):
def __init__(self, deferred, key):
self._deferred = deferred
self.released = False
self.key = key
def release(self):
self.released = True
self._deferred.callback(None)
def __del__(self):
if not self.released:
logger.critical("Lock was destructed but never released!")
self.release()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
logger.debug("Releasing lock for key=%r", self.key)
self.release()
class LockManager(object):
""" Utility class that allows us to lock based on a `key` """
def __init__(self):
self._lock_deferreds = {}
@defer.inlineCallbacks
def lock(self, key):
""" Allows us to block until it is our turn.
Args:
key (str)
Returns:
Lock
"""
new_deferred = defer.Deferred()
old_deferred = self._lock_deferreds.get(key)
self._lock_deferreds[key] = new_deferred
if old_deferred:
logger.debug("Queueing on lock for key=%r", key)
yield old_deferred
logger.debug("Obtained lock for key=%r", key)
else:
logger.debug("Entering uncontended lock for key=%r", key)
defer.returnValue(Lock(new_deferred, key))

View file

@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id) self.assertEquals(user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase):
user = user_info["user"] user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
user_info = yield self.auth._get_user_from_macaroon(serialized)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self): def test_get_user_from_macaroon_user_db_mismatch(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(

View file

@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
{"presence": ONLINE} {"presence": ONLINE}
) )
# Apple sees self-reflection even without room_id
(events, _) = yield self.event_source.get_new_events(
user=self.u_apple,
from_key=0,
)
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(events,
[
{"type": "m.presence",
"content": {
"user_id": "@apple:test",
"presence": ONLINE,
"last_active_ago": 0,
}},
],
msg="Presence event should be visible to self-reflection"
)
# Apple sees self-reflection # Apple sees self-reflection
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Banana sees it because of presence subscription # Banana sees it because of presence subscription
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_banana, 0, None user=self.u_banana,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Elderberry sees it because of same room # Elderberry sees it because of same room
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_elderberry, 0, None user=self.u_elderberry,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Durian is not in the room, should not see this event # Durian is not in the room, should not see this event
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_durian, 0, None user=self.u_durian,
from_key=0,
room_ids=[],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
"accepted": True}, "accepted": True},
], presence) ], presence)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 1, None user=self.u_apple,
from_key=1,
) )
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
) )
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.room_members.append(self.u_clementine) self.room_members.append(self.u_clementine)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)

View file

@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
yield put_json.await_calls() yield put_json.await_calls()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=1,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [

View file

@ -47,7 +47,14 @@ class NullSource(object):
def __init__(self, hs): def __init__(self, hs):
pass pass
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
room_ids=None,
limit=None,
is_guest=None
):
return defer.succeed(([], from_key)) return defer.succeed(([], from_key))
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):
@ -86,10 +93,11 @@ class PresenceStateTestCase(unittest.TestCase):
return defer.succeed([]) return defer.succeed([])
self.datastore.get_presence_list = get_presence_list self.datastore.get_presence_list = get_presence_list
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -173,10 +181,11 @@ class PresenceListTestCase(unittest.TestCase):
) )
self.datastore.has_presence_state = has_presence_state self.datastore.has_presence_state = has_presence_state
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.handlers.room_member_handler = Mock( hs.handlers.room_member_handler = Mock(
@ -291,8 +300,8 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000 hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None): def _get_user_by_req(req=None, allow_guest=False):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req

View file

@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase):
replication_layer=Mock(), replication_layer=Mock(),
) )
def _get_user_by_req(request=None): def _get_user_by_req(request=None, allow_guest=False):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req

View file

@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase):
self.auth_user_id = self.user_id self.auth_user_id = self.user_id
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token

View file

@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -115,7 +116,10 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.user, 0, None) events = yield self.event_source.get_new_events(
from_key=0,
room_ids=[self.room_id],
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [

View file

@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase):
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
) )
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token hs.get_auth()._get_user_by_access_token = _get_user_by_access_token

View file

@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))

View file

@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(), self.u_bob.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(), self.u_bob.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
# We should not get the message, as it happened *after* bob left. # We should not get the message, as it happened *after* bob left.
@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(), self.u_bob.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
# We should not get the message, as it happened *after* bob left. # We should not get the message, as it happened *after* bob left.

View file

@ -1,108 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.util.lockutils import LockManager
class LockManagerTestCase(unittest.TestCase):
def setUp(self):
self.lock_manager = LockManager()
@defer.inlineCallbacks
def test_one_lock(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
lock1.release()
self.assertTrue(lock1.released)
@defer.inlineCallbacks
def test_concurrent_locks(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
deferred_lock2 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
self.assertFalse(deferred_lock2.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
self.assertFalse(deferred_lock2.called)
lock1.release()
self.assertTrue(lock1.released)
self.assertTrue(deferred_lock2.called)
lock2 = yield deferred_lock2
lock2.release()
@defer.inlineCallbacks
def test_sequential_locks(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
lock1.release()
self.assertTrue(lock1.released)
deferred_lock2 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock2.called)
lock2 = yield deferred_lock2
self.assertFalse(lock2.released)
lock2.release()
self.assertTrue(lock2.released)
@defer.inlineCallbacks
def test_with_statement(self):
key = "test"
with (yield self.lock_manager.lock(key)) as lock:
self.assertFalse(lock.released)
self.assertTrue(lock.released)
@defer.inlineCallbacks
def test_two_with_statement(self):
key = "test"
with (yield self.lock_manager.lock(key)):
pass
with (yield self.lock_manager.lock(key)):
pass

View file

@ -335,7 +335,7 @@ class MemoryDataStore(object):
] ]
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
room_id=None, limit=0, with_feedback=False): limit=0, with_feedback=False):
return ([], from_key) # TODO return ([], from_key) # TODO
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):