0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-17 07:21:37 +01:00

Merge branch 'develop' into paul/tiny-fixes

This commit is contained in:
Paul "LeoNerd" Evans 2015-11-13 17:26:59 +00:00
commit 4fbe6ca401
100 changed files with 3170 additions and 1036 deletions

View file

@ -1,3 +1,19 @@
Changes in synapse v0.11.0-rc1 (2015-11-11)
===========================================
* Add Search API (PR #307, #324, #327, #336, #350, #359)
* Add 'archived' state to v2 /sync API (PR #316)
* Add ability to reject invites (PR #317)
* Add config option to disable password login (PR #322)
* Add the login fallback API (PR #330)
* Add room context API (PR #334)
* Add room tagging support (PR #335)
* Update v2 /sync API to match spec (PR #305, #316, #321, #332, #337, #341)
* Change retry schedule for application services (PR #320)
* Change retry schedule for remote servers (PR #340)
* Fix bug where we hosted static content in the incorrect place (PR #329)
* Fix bug where we didn't increment retry interval for remote servers (PR #343)
Changes in synapse v0.10.1-rc1 (2015-10-15) Changes in synapse v0.10.1-rc1 (2015-10-15)
=========================================== ===========================================

View file

@ -15,8 +15,9 @@ recursive-include scripts *
recursive-include scripts-dev * recursive-include scripts-dev *
recursive-include tests *.py recursive-include tests *.py
recursive-include static *.css recursive-include synapse/static *.css
recursive-include static *.html recursive-include synapse/static *.gif
recursive-include static *.js recursive-include synapse/static *.html
recursive-include synapse/static *.js
prune demo/etc prune demo/etc

View file

@ -20,8 +20,8 @@ The overall architecture is::
https://somewhere.org/_matrix https://elsewhere.net/_matrix https://somewhere.org/_matrix https://elsewhere.net/_matrix
``#matrix:matrix.org`` is the official support room for Matrix, and can be ``#matrix:matrix.org`` is the official support room for Matrix, and can be
accessed by the web client at http://matrix.org/beta or via an IRC bridge at accessed by any client from https://matrix.org/blog/try-matrix-now or via IRC
irc://irc.freenode.net/matrix. bridge at irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it Synapse is currently in rapid development, but as of version 0.5 we believe it
is sufficiently stable to be run as an internet-facing service for real usage! is sufficiently stable to be run as an internet-facing service for real usage!
@ -77,14 +77,14 @@ Meanwhile, iOS and Android SDKs and clients are available from:
- https://github.com/matrix-org/matrix-android-sdk - https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via We'd like to invite you to join #matrix:matrix.org (via
https://matrix.org/beta), run a homeserver, take a look at the Matrix spec at https://matrix.org/blog/try-matrix-now), run a homeserver, take a look at the
https://matrix.org/docs/spec and API docs at https://matrix.org/docs/api, Matrix spec at https://matrix.org/docs/spec and API docs at
experiment with the APIs and the demo clients, and report any bugs via https://matrix.org/docs/api, experiment with the APIs and the demo clients, and
https://matrix.org/jira. report any bugs via https://matrix.org/jira.
Thanks for using Matrix! Thanks for using Matrix!
[1] End-to-end encryption is currently in development [1] End-to-end encryption is currently in development - see https://matrix.org/git/olm
Synapse Installation Synapse Installation
==================== ====================

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.10.1-rc1" __version__ = "0.11.0-rc1"

View file

@ -24,7 +24,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID from synapse.types import RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util import third_party_invites
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -49,6 +48,7 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ", "gen = ",
"guest = ",
"type = ", "type = ",
"time < ", "time < ",
"user_id = ", "user_id = ",
@ -183,15 +183,11 @@ class Auth(object):
defer.returnValue(member) defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id, current_state=None): def check_user_was_in_room(self, room_id, user_id):
"""Check if the user was in the room at some point. """Check if the user was in the room at some point.
Args: Args:
room_id(str): The room to check. room_id(str): The room to check.
user_id(str): The user to check. user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises: Raises:
AuthError if the user was never in the room. AuthError if the user was never in the room.
Returns: Returns:
@ -199,12 +195,6 @@ class Auth(object):
room. This will be the join event if they are currently joined to room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room. the room. This will be the leave event if they have left the room.
""" """
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state( member = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
event_type=EventTypes.Member, event_type=EventTypes.Member,
@ -327,6 +317,11 @@ class Auth(object):
} }
) )
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
return True
if Membership.JOIN != membership: if Membership.JOIN != membership:
if (caller_invited if (caller_invited
and Membership.LEAVE == membership and Membership.LEAVE == membership
@ -370,7 +365,6 @@ class Auth(object):
pass pass
elif join_rule == JoinRules.INVITE: elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited: if not caller_in_room and not caller_invited:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.") raise AuthError(403, "You are not invited to this room.")
else: else:
# TODO (erikj): may_join list # TODO (erikj): may_join list
@ -399,10 +393,10 @@ class Auth(object):
def _verify_third_party_invite(self, event, auth_events): def _verify_third_party_invite(self, event, auth_events):
""" """
Validates that the join event is authorized by a previous third-party invite. Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the invite, Checks that the public key, and keyserver, match those in the third party invite,
and that the join event has a signature issued using that public key. and that the invite event has a signature issued using that public key.
Args: Args:
event: The m.room.member join event being validated. event: The m.room.member join event being validated.
@ -413,35 +407,28 @@ class Auth(object):
True if the event fulfills the expectations of a previous third party True if the event fulfills the expectations of a previous third party
invite event. invite event.
""" """
if not third_party_invites.join_has_third_party_invite(event.content): if "third_party_invite" not in event.content:
return False return False
join_third_party_invite = event.content["third_party_invite"] if "signed" not in event.content["third_party_invite"]:
token = join_third_party_invite["token"] return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get( invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )
if not invite_event: if not invite_event:
logger.info("Failing 3pid invite because no invite found for token %s", token) return False
if event.user_id != invite_event.user_id:
return False return False
try: try:
public_key = join_third_party_invite["public_key"] public_key = invite_event.content["public_key"]
key_validity_url = join_third_party_invite["key_validity_url"] if signed["mxid"] != event.state_key:
if invite_event.content["public_key"] != public_key:
logger.info(
"Failing 3pid invite because public key invite: %s != join: %s",
invite_event.content["public_key"],
public_key
)
return False
if invite_event.content["key_validity_url"] != key_validity_url:
logger.info(
"Failing 3pid invite because key_validity_url invite: %s != join: %s",
invite_event.content["key_validity_url"],
key_validity_url
)
return False
signed = join_third_party_invite["signed"]
if signed["mxid"] != event.user_id:
return False return False
if signed["token"] != token: if signed["token"] != token:
return False return False
@ -454,6 +441,11 @@ class Auth(object):
decode_base64(public_key) decode_base64(public_key)
) )
verify_signed_json(signed, server, verify_key) verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True return True
return False return False
except (KeyError, SignatureVerifyException,): except (KeyError, SignatureVerifyException,):
@ -497,7 +489,7 @@ class Auth(object):
return default return default
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request): def get_user_by_req(self, request, allow_guest=False):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -535,7 +527,7 @@ class Auth(object):
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), "")) defer.returnValue((UserID.from_string(user_id), "", False))
return return
except KeyError: except KeyError:
pass # normal users won't have the user_id query parameter set. pass # normal users won't have the user_id query parameter set.
@ -543,6 +535,7 @@ class Auth(object):
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
@ -557,9 +550,14 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
if is_guest and not allow_guest:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id,)) defer.returnValue((user, token_id, is_guest,))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@ -592,9 +590,27 @@ class Auth(object):
self._validate_macaroon(macaroon) self._validate_macaroon(macaroon)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None
guest = False
for caveat in macaroon.caveats: for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix): if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
elif caveat.caveat_id == "guest = true":
guest = True
if user is None:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
if guest:
ret = {
"user": user,
"is_guest": True,
"token_id": None,
}
else:
# This codepath exists so that we can actually return a # This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device # token ID, because we use token IDs in place of device
# identifiers throughout the codebase. # identifiers throughout the codebase.
@ -613,10 +629,6 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
defer.returnValue(ret) defer.returnValue(ret)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
@ -629,6 +641,7 @@ class Auth(object):
v.satisfy_exact("type = access") v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry) v.satisfy_general(self._verify_expiry)
v.satisfy_exact("guest = true")
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
@ -666,6 +679,7 @@ class Auth(object):
user_info = { user_info = {
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None), "token_id": ret.get("token_id", None),
"is_guest": False,
} }
defer.returnValue(user_info) defer.returnValue(user_info)
@ -738,17 +752,19 @@ class Auth(object):
if e_type == Membership.JOIN: if e_type == Membership.JOIN:
if member_event and not is_public: if member_event and not is_public:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)
if third_party_invites.join_has_third_party_invite(event.content): else:
if member_event:
auth_ids.append(member_event.event_id)
if e_type == Membership.INVITE:
if "third_party_invite" in event.content:
key = ( key = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["token"] event.content["third_party_invite"]["token"]
) )
invite = current_state.get(key) third_party_invite = current_state.get(key)
if invite: if third_party_invite:
auth_ids.append(invite.event_id) auth_ids.append(third_party_invite.event_id)
else:
if member_event:
auth_ids.append(member_event.event_id)
elif member_event: elif member_event:
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)

View file

@ -68,6 +68,7 @@ class EventTypes(object):
RoomHistoryVisibility = "m.room.history_visibility" RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias" CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar" RoomAvatar = "m.room.avatar"
GuestAccess = "m.room.guest_access"
# These are used for validation # These are used for validation
Message = "m.room.message" Message = "m.room.message"

View file

@ -33,6 +33,7 @@ class Codes(object):
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN" MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"

View file

@ -50,11 +50,11 @@ class Filtering(object):
# many definitions. # many definitions.
top_level_definitions = [ top_level_definitions = [
"public_user_data", "private_user_data", "server_data" "presence"
] ]
room_level_definitions = [ room_level_definitions = [
"state", "timeline", "ephemeral" "state", "timeline", "ephemeral", "private_user_data"
] ]
for key in top_level_definitions: for key in top_level_definitions:
@ -114,22 +114,6 @@ class Filtering(object):
if not isinstance(event_type, basestring): if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string") raise SynapseError(400, "Event type should be a string")
if "format" in definition:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
if "select" in definition:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
class FilterCollection(object): class FilterCollection(object):
def __init__(self, filter_json): def __init__(self, filter_json):
@ -147,6 +131,10 @@ class FilterCollection(object):
self.filter_json.get("room", {}).get("ephemeral", {}) self.filter_json.get("room", {}).get("ephemeral", {})
) )
self.room_private_user_data = Filter(
self.filter_json.get("room", {}).get("private_user_data", {})
)
self.presence_filter = Filter( self.presence_filter = Filter(
self.filter_json.get("presence", {}) self.filter_json.get("presence", {})
) )
@ -172,6 +160,9 @@ class FilterCollection(object):
def filter_room_ephemeral(self, events): def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events) return self.room_ephemeral_filter.filter(events)
def filter_room_private_user_data(self, events):
return self.room_private_user_data.filter(events)
class Filter(object): class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):

View file

@ -132,7 +132,9 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_static_content(self): def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip # This is old and should go away: not going to bother adding gzip
return File("static") return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):
return ContentRepoResource( return ContentRepoResource(
@ -437,6 +439,7 @@ def setup(config_options):
hs.get_pusherpool().start() hs.get_pusherpool().start()
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
return hs return hs

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import errno
import os import os
import yaml import yaml
import sys import sys
@ -91,8 +92,11 @@ class Config(object):
@classmethod @classmethod
def ensure_directory(cls, dir_path): def ensure_directory(cls, dir_path):
dir_path = cls.abspath(dir_path) dir_path = cls.abspath(dir_path)
if not os.path.exists(dir_path): try:
os.makedirs(dir_path) os.makedirs(dir_path)
except OSError, e:
if e.errno != errno.EEXIST:
raise
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
raise ConfigError( raise ConfigError(
"%s is not a directory" % (dir_path,) "%s is not a directory" % (dir_path,)

View file

@ -34,6 +34,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key") self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.
bcrypt_rounds: 12 bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):

View file

@ -26,7 +26,6 @@ from synapse.api.errors import (
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util import third_party_invites
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -358,7 +357,8 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_membership_event(self, destinations, room_id, user_id, membership, content): def make_membership_event(self, destinations, room_id, user_id, membership,
content={},):
""" """
Creates an m.room.member event, with context, without participating in the room. Creates an m.room.member event, with context, without participating in the room.
@ -390,20 +390,17 @@ class FederationClient(FederationBase):
if destination == self.server_name: if destination == self.server_name:
continue continue
args = {}
if third_party_invites.join_has_third_party_invite(content):
args = third_party_invites.extract_join_keys(
content["third_party_invite"]
)
try: try:
ret = yield self.transport_layer.make_membership_event( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership, args destination, room_id, user_id, membership
) )
pdu_dict = ret["event"] pdu_dict = ret["event"]
logger.debug("Got response to make_%s: %s", membership, pdu_dict) logger.debug("Got response to make_%s: %s", membership, pdu_dict)
pdu_dict["content"].update(content)
defer.returnValue( defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict)) (destination, self.event_from_pdu_json(pdu_dict))
) )
@ -704,3 +701,26 @@ class FederationClient(FederationBase):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
yield self.transport_layer.exchange_third_party_invite(
destination=destination,
room_id=room_id,
event_dict=event_dict,
)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_third_party_invite via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")

View file

@ -23,12 +23,10 @@ from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
from synapse.api.errors import FederationError, SynapseError, Codes from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.util import third_party_invites
import simplejson as json import simplejson as json
import logging import logging
@ -230,19 +228,8 @@ class FederationServer(FederationBase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id):
threepid_details = {} pdu = yield self.handler.on_make_join_request(room_id, user_id)
if third_party_invites.has_join_keys(query):
for k in third_party_invites.JOIN_KEYS:
if not isinstance(query[k], list) or len(query[k]) != 1:
raise FederationError(
"FATAL",
Codes.MISSING_PARAM,
"key %s value %s" % (k, query[k],),
None
)
threepid_details[k] = query[k][0]
pdu = yield self.handler.on_make_join_request(room_id, user_id, threepid_details)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)}) defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@ -556,3 +543,15 @@ class FederationServer(FederationBase):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def exchange_third_party_invite(self, invite):
ret = yield self.handler.exchange_third_party_invite(invite)
defer.returnValue(ret)
@defer.inlineCallbacks
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
ret = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, event_dict
)
defer.returnValue(ret)

View file

@ -202,6 +202,7 @@ class TransactionQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing. # request at which point pending_pdus_by_dest just keeps growing.
@ -213,9 +214,6 @@ class TransactionQueue(object):
) )
return return
logger.debug("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, [])
@ -228,6 +226,11 @@ class TransactionQueue(object):
logger.debug("TX [%s] Nothing to send", destination) logger.debug("TX [%s] Nothing to send", destination)
return return
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] _attempt_new_transaction", destination)
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[2]) pending_pdus.sort(key=lambda t: t[2])
@ -239,9 +242,6 @@ class TransactionQueue(object):
for x in pending_pdus + pending_edus + pending_failures for x in pending_pdus + pending_edus + pending_failures
] ]
try:
self.pending_transactions[destination] = 1
txn_id = str(self._next_txn_id) txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(

View file

@ -161,7 +161,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_membership_event(self, destination, room_id, user_id, membership, args={}): def make_membership_event(self, destination, room_id, user_id, membership):
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships: if membership not in valid_memberships:
raise RuntimeError( raise RuntimeError(
@ -173,7 +173,6 @@ class TransportLayerClient(object):
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=args,
retry_on_dns_fail=True, retry_on_dns_fail=True,
) )
@ -218,6 +217,19 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
response = yield self.client.put_json(
destination=destination,
path=path,
data=event_dict,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_event_auth(self, destination, room_id, event_id): def get_event_auth(self, destination, room_id, event_id):

View file

@ -292,7 +292,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_join_request(context, user_id, query) content = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -343,6 +343,17 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id):
content = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, content
)
defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query" PATH = "/user/keys/query"
@ -396,6 +407,30 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
@defer.inlineCallbacks
def on_POST(self, request):
content_bytes = request.content.read()
content = json.loads(content_bytes)
if "invites" in content:
last_exception = None
for invite in content["invites"]:
try:
yield self.handler.exchange_third_party_invite(invite)
except Exception as e:
last_exception = e
if last_exception:
raise last_exception
defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationPullServlet, FederationPullServlet,
FederationEventServlet, FederationEventServlet,
@ -413,4 +448,6 @@ SERVLET_CLASSES = (
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet, FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
) )

View file

@ -21,7 +21,6 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import third_party_invites
import logging import logging
@ -30,6 +29,12 @@ logger = logging.getLogger(__name__)
class BaseHandler(object): class BaseHandler(object):
"""
Common base class for the event handlers.
:type store: synapse.storage.events.StateStore
:type state_handler: synapse.state.StateHandler
"""
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -47,37 +52,24 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events): def _filter_events_for_client(self, user_id, events, is_guest=False,
event_id_to_state = yield self.store.get_state_for_events( require_all_visible_for_guests=True):
frozenset(e.event_id for e in events), # Assumes that user has at some point joined the room if not is_guest.
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state): def allowed(event, membership, visibility):
if event.type == EventTypes.RoomHistoryVisibility: if visibility == "world_readable":
return True return True
membership_ev = state.get((EventTypes.Member, user_id), None) if is_guest:
if membership_ev: return False
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None) if event.type == EventTypes.RoomHistoryVisibility:
if history: return not is_guest
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public": if visibility == "shared":
return True
elif visibility == "shared":
return True return True
elif visibility == "joined": elif visibility == "joined":
return membership == Membership.JOIN return membership == Membership.JOIN
@ -86,11 +78,46 @@ class BaseHandler(object):
return True return True
defer.returnValue([ event_id_to_state = yield self.store.get_state_for_events(
event frozenset(e.event_id for e in events),
for event in events types=(
if allowed(event, event_id_to_state[event.event_id]) (EventTypes.RoomHistoryVisibility, ""),
]) (EventTypes.Member, user_id),
)
)
events_to_return = []
for event in events:
state = event_id_to_state[event.event_id]
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
if (require_all_visible_for_guests
and is_guest
and len(events_to_return) < len(events)):
# This indicates that some events in the requested range were not
# visible to guest users. To be safe, we reject the entire request,
# so that we don't have to worry about interpreting visibility
# boundaries.
raise AuthError(403, "User %s does not have permission" % (
user_id
))
defer.returnValue(events_to_return)
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()
@ -154,6 +181,8 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None) room_alias_str = event.content.get("alias", None)
@ -170,16 +199,6 @@ class BaseHandler(object):
) )
) )
if (
event.type == EventTypes.Member and
event.content["membership"] == Membership.JOIN and
third_party_invites.join_has_third_party_invite(event.content)
):
yield third_party_invites.check_key_valid(
self.hs.get_simple_http_client(),
event
)
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
@ -271,3 +290,58 @@ class BaseHandler(object):
federation_handler.handle_new_event( federation_handler.handle_new_event(
event, destinations=destinations, event, destinations=destinations,
) )
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, current_state):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
yield self.kick_guest_users(current_state)
@defer.inlineCallbacks
def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
continue
if not self.hs.is_mine(UserID.from_string(member_event.state_key)):
continue
if member_event.content["membership"] not in {
Membership.JOIN,
Membership.INVITE
}:
continue
if (
"kind" not in member_event.content
or member_event.content["kind"] != "guest"
):
continue
# We make the user choose to leave, rather than have the
# event-sender kick them. This is partially because we don't
# need to worry about power levels, and partially because guest
# users are a concept which doesn't hugely work over federation,
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
message_handler = self.hs.get_handlers().message_handler
yield message_handler.create_and_send_event(
{
"type": EventTypes.Member,
"state_key": member_event.state_key,
"content": {
"membership": Membership.LEAVE,
"kind": "guest"
},
"room_id": member_event.room_id,
"sender": member_event.state_key
},
ratelimit=False,
)
except Exception as e:
logger.warn("Error kicking guest user: %s" % (e,))

View file

@ -372,12 +372,15 @@ class AuthHandler(BaseHandler):
yield self.store.add_refresh_token_to_user(user_id, refresh_token) yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token) defer.returnValue(refresh_token)
def generate_access_token(self, user_id): def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000) expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize() return macaroon.serialize()
def generate_refresh_token(self, user_id): def generate_refresh_token(self, user_id):

View file

@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler):
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True, as_client_event=True, affect_presence=True,
only_room_events=False): only_room_events=False, room_id=None, is_guest=False):
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned. If `only_room_events` is `True` only room events will be returned.
@ -111,17 +111,6 @@ class EventStreamHandler(BaseHandler):
if affect_presence: if affect_presence:
yield self.started_stream(auth_user) yield self.started_stream(auth_user)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
auth_user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout: if timeout:
# If they've set a timeout set a minimum limit. # If they've set a timeout set a minimum limit.
timeout = max(timeout, 500) timeout = max(timeout, 500)
@ -130,9 +119,15 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
if is_guest:
yield self.distributor.fire(
"user_joined_room", user=auth_user, room_id=room_id
)
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout, auth_user, pagin_config, timeout,
only_room_events=only_room_events only_room_events=only_room_events,
is_guest=is_guest, guest_room_id=room_id
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -21,6 +21,7 @@ from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError, AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -39,7 +40,6 @@ from twisted.internet import defer
import itertools import itertools
import logging import logging
from synapse.util import third_party_invites
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,6 +58,8 @@ class FederationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(FederationHandler, self).__init__(hs) super(FederationHandler, self).__init__(hs)
self.hs = hs
self.distributor.observe( self.distributor.observe(
"user_joined_room", "user_joined_room",
self._on_user_joined self._on_user_joined
@ -68,12 +70,9 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
# self.auth_handler = gs.get_auth_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.lock_manager = hs.get_room_lock_manager()
self.replication_layer.set_handler(self) self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
@ -586,7 +585,7 @@ class FederationHandler(BaseHandler):
room_id, room_id,
joinee, joinee,
"join", "join",
content content,
) )
self.room_queues[room_id] = [] self.room_queues[room_id] = []
@ -663,16 +662,12 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
if third_party_invites.has_join_keys(query):
event_content["third_party_invite"] = (
third_party_invites.extract_join_keys(query)
)
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new({
"type": EventTypes.Member, "type": EventTypes.Member,
@ -688,9 +683,6 @@ class FederationHandler(BaseHandler):
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
if third_party_invites.join_has_third_party_invite(event.content):
third_party_invites.check_key_valid(self.hs.get_simple_http_client(), event)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -830,8 +822,7 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
"leave", "leave"
{}
) )
signed_event = self._sign_event(event) signed_event = self._sign_event(event)
@ -850,13 +841,14 @@ class FederationHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content): def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={},):
origin, pdu = yield self.replication_layer.make_membership_event( origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
membership, membership,
content content,
) )
logger.debug("Got response to make_%s: %s", membership, pdu) logger.debug("Got response to make_%s: %s", membership, pdu)
@ -1108,8 +1100,6 @@ class FederationHandler(BaseHandler):
context = yield self._prep_event( context = yield self._prep_event(
origin, event, origin, event,
state=state, state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events, auth_events=auth_events,
) )
@ -1132,7 +1122,6 @@ class FederationHandler(BaseHandler):
origin, origin,
ev_info["event"], ev_info["event"],
state=ev_info.get("state"), state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"), auth_events=ev_info.get("auth_events"),
) )
for ev_info in event_infos for ev_info in event_infos
@ -1219,8 +1208,7 @@ class FederationHandler(BaseHandler):
defer.returnValue((event_stream_id, max_stream_id)) defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False, def _prep_event(self, origin, event, state=None, auth_events=None):
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
@ -1253,6 +1241,10 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
if event.type == EventTypes.GuestAccess:
full_context = yield self.store.get_current_state(room_id=event.room_id)
yield self.maybe_kick_guest_users(event, full_context)
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1649,3 +1641,75 @@ class FederationHandler(BaseHandler):
}, },
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
}) })
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, invite):
sender = invite["sender"]
room_id = invite["room_id"]
event_dict = {
"type": EventTypes.Member,
"content": {
"membership": Membership.INVITE,
"third_party_invite": invite,
},
"room_id": room_id,
"sender": sender,
"state_key": invite["mxid"],
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
event, context = yield self._create_new_client_event(builder=builder)
self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite(
destinations,
room_id,
event_dict,
)
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
builder = self.event_builder_factory.new(event_dict)
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
@defer.inlineCallbacks
def _validate_keyserver(self, event, auth_events):
token = event.content["third_party_invite"]["signed"]["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
try:
response = yield self.hs.get_simple_http_client().get_json(
invite_event.content["key_validity_url"],
{"public_key": invite_event.content["public_key"]}
)
except Exception:
raise SynapseError(
502,
"Third party certificate could not be checked"
)
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, AuthError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -71,7 +71,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
as_client_event=True): as_client_event=True, is_guest=False):
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -80,11 +80,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
is_guest (bool): Whether the requesting user is a guest (as opposed
to a fully registered user).
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"] data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token: if pagin_config.from_token:
@ -107,9 +107,13 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
if not is_guest:
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.LEAVE: if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before # If they have left the room then clamp the token to be before
# they left the room # they left the room.
# If they're a guest, we'll just 403 them if they're asking for
# events they can't see.
leave_token = yield self.store.get_topological_token_for_event( leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id member_event.event_id
) )
@ -146,7 +150,7 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
events = yield self._filter_events_for_client(user_id, events) events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -163,7 +167,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None): token_id=None, txn_id=None, is_guest=False):
""" Given a dict from a client, create and handle a new event. """ Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events, Creates an FrozenEvent object, filling out auth_events, prev_events,
@ -209,7 +213,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context) yield member_handler.change_membership(event, context, is_guest=is_guest)
else: else:
yield self.handle_new_client_event( yield self.handle_new_client_event(
event=event, event=event,
@ -225,7 +229,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key=""): event_type=None, state_key="", is_guest=False):
""" Get data from a room. """ Get data from a room.
Args: Args:
@ -235,23 +239,52 @@ class MessageHandler(BaseHandler):
Raises: Raises:
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
data = yield self.state_handler.get_current_state( data = yield self.state_handler.get_current_state(
room_id, event_type, state_key room_id, event_type, state_key
) )
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], [key] [membership_event_id], [key]
) )
data = room_state[member_event.event_id].get(key) data = room_state[membership_event_id].get(key)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
return
except AuthError, auth_error:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
if not is_guest:
raise auth_error
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False):
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has joined to the room then return the current state. If the user has
left the room return the state events from when they left. left the room return the state events from when they left.
@ -262,15 +295,17 @@ class MessageHandler(BaseHandler):
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
if member_event.membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[member_event.event_id], None
) )
room_state = room_state[member_event.event_id]
if membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[membership_event_id], None
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
defer.returnValue( defer.returnValue(
@ -322,6 +357,8 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("receipt"), None user, pagination_config.get_source_config("receipt"), None
) )
tags_by_room = yield self.store.get_tags_for_user(user_id)
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
@ -398,6 +435,15 @@ class MessageHandler(BaseHandler):
serialize_event(c, time_now, as_client_event) serialize_event(c, time_now, as_client_event)
for c in current_state.values() for c in current_state.values()
] ]
private_user_data = []
tags = tags_by_room.get(event.room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
d["private_user_data"] = private_user_data
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
@ -420,7 +466,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None): def room_initial_sync(self, user_id, room_id, pagin_config=None, is_guest=False):
"""Capture the a snapshot of a room. If user is currently a member of """Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left. the room this will be what was in the room when they left.
@ -437,33 +483,47 @@ class MessageHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room. A JSON serialisable dict with the snapshot of the room.
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id,
user_id,
is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined( result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, member_event user_id, room_id, pagin_config, membership, is_guest
) )
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted( result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event user_id, room_id, pagin_config, membership, member_event_id, is_guest
) )
private_user_data = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
result["private_user_data"] = private_user_data
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config, def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
member_event): membership, member_event_id, is_guest):
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], None [member_event_id], None
) )
room_state = room_state[member_event.event_id] room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None
if limit is None: if limit is None:
limit = 10 limit = 10
stream_token = yield self.store.get_stream_token_for_event( stream_token = yield self.store.get_stream_token_for_event(
member_event.event_id member_event_id
) )
messages, token = yield self.store.get_recent_events_for_room( messages, token = yield self.store.get_recent_events_for_room(
@ -473,16 +533,16 @@ class MessageHandler(BaseHandler):
) )
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages user_id, messages, is_guest=is_guest
) )
start_token = StreamToken(token[0], 0, 0, 0) start_token = StreamToken(token[0], 0, 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0) end_token = StreamToken(token[1], 0, 0, 0, 0)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
defer.returnValue({ defer.returnValue({
"membership": member_event.membership, "membership": membership,
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": [serialize_event(m, time_now) for m in messages], "chunk": [serialize_event(m, time_now) for m in messages],
@ -496,7 +556,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config, def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
member_event): membership, is_guest):
current_state = yield self.state.get_current_state( current_state = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
) )
@ -528,6 +588,8 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
states = {}
if not is_guest:
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members], target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user, auth_user=auth_user,
@ -553,7 +615,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client( messages = yield self._filter_events_for_client(
user_id, messages user_id, messages, is_guest=is_guest, require_all_visible_for_guests=False
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -561,8 +623,7 @@ class MessageHandler(BaseHandler):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
defer.returnValue({ ret = {
"membership": member_event.membership,
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": [serialize_event(m, time_now) for m in messages], "chunk": [serialize_event(m, time_now) for m in messages],
@ -572,4 +633,8 @@ class MessageHandler(BaseHandler):
"state": state, "state": state,
"presence": presence, "presence": presence,
"receipts": receipts, "receipts": receipts,
}) }
if not is_guest:
ret["membership"] = membership
defer.returnValue(ret)

View file

@ -950,6 +950,7 @@ class PresenceHandler(BaseHandler):
) )
while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS: while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS:
self._remote_offline_serials.pop() # remove the oldest self._remote_offline_serials.pop() # remove the oldest
if user in self._user_cachemap:
del self._user_cachemap[user] del self._user_cachemap[user]
else: else:
# Remove the user from remote_offline_serials now that they're # Remove the user from remote_offline_serials now that they're
@ -1142,8 +1143,9 @@ class PresenceEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, user, from_key, room_ids=None, **kwargs):
from_key = int(from_key) from_key = int(from_key)
room_ids = room_ids or []
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
@ -1161,7 +1163,6 @@ class PresenceEventSource(object):
user_ids_to_check |= set( user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list UserID.from_string(p["observed_user_id"]) for p in presence_list
) )
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials): for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key: if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id) joined = yield presence.get_joined_users_for_room_id(room_id)

View file

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
class PrivateUserDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
def get_current_key(self, direction='f'):
return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
current_stream_id = yield self.store.get_max_private_user_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = []
for room_id, room_tags in tags.items():
results.append({
"type": "m.tag",
"content": {"tags": room_tags},
"room_id": room_id,
})
defer.returnValue((results, current_stream_id))
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
defer.returnValue(([], config.to_id))

View file

@ -164,17 +164,15 @@ class ReceiptEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key) from_key = int(from_key)
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
if from_key == to_key: if from_key == to_key:
defer.returnValue(([], to_key)) defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms( events = yield self.store.get_linearized_receipts_for_rooms(
rooms, room_ids,
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
) )

View file

@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart=None, password=None): def register(self, localpart=None, password=None, generate_token=True):
"""Registers a new client on the server. """Registers a new client on the server.
Args: Args:
@ -89,6 +89,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = None
if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -102,13 +104,13 @@ class RegistrationHandler(BaseHandler):
attempts = 0 attempts = 0
user_id = None user_id = None
token = None token = None
while not user_id and not token: while not user_id:
try: try:
localpart = self._generate_user_id() localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,

View file

@ -38,6 +38,8 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
@ -367,7 +369,7 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain) remotedomains.add(member.domain)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True): def change_membership(self, event, context, do_auth=True, is_guest=False):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
Args: Args:
@ -388,6 +390,20 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the # if this HS is not currently in the room, i.e. we have to do the
# invite/join dance. # invite/join dance.
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if is_guest:
guest_access = context.current_state.get(
(EventTypes.GuestAccess, ""),
None
)
is_guest_access_allowed = (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed")
yield self._do_join(event, context, do_auth=do_auth) yield self._do_join(event, context, do_auth=do_auth)
else: else:
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
@ -489,7 +505,7 @@ class RoomMemberHandler(BaseHandler):
room_hosts, room_hosts,
room_id, room_id,
event.user_id, event.user_id,
event.content # FIXME To get a non-frozen dict event.content,
) )
else: else:
logger.debug("Doing normal join") logger.debug("Doing normal join")
@ -581,7 +597,6 @@ class RoomMemberHandler(BaseHandler):
medium, medium,
address, address,
id_server, id_server,
display_name,
token_id, token_id,
txn_id txn_id
): ):
@ -608,7 +623,6 @@ class RoomMemberHandler(BaseHandler):
else: else:
yield self._make_and_store_3pid_invite( yield self._make_and_store_3pid_invite(
id_server, id_server,
display_name,
medium, medium,
address, address,
room_id, room_id,
@ -632,7 +646,7 @@ class RoomMemberHandler(BaseHandler):
""" """
try: try:
data = yield self.hs.get_simple_http_client().get_json( data = yield self.hs.get_simple_http_client().get_json(
"https://%s/_matrix/identity/api/v1/lookup" % (id_server,), "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{ {
"medium": medium, "medium": medium,
"address": address, "address": address,
@ -655,8 +669,8 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json( key_data = yield self.hs.get_simple_http_client().get_json(
"https://%s/_matrix/identity/api/v1/pubkey/%s" % "%s%s/_matrix/identity/api/v1/pubkey/%s" %
(server_hostname, key_name,), (id_server_scheme, server_hostname, key_name,),
) )
if "public_key" not in key_data: if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" % raise AuthError(401, "No public key named %s from %s" %
@ -672,7 +686,6 @@ class RoomMemberHandler(BaseHandler):
def _make_and_store_3pid_invite( def _make_and_store_3pid_invite(
self, self,
id_server, id_server,
display_name,
medium, medium,
address, address,
room_id, room_id,
@ -680,7 +693,7 @@ class RoomMemberHandler(BaseHandler):
token_id, token_id,
txn_id txn_id
): ):
token, public_key, key_validity_url = ( token, public_key, key_validity_url, display_name = (
yield self._ask_id_server_for_third_party_invite( yield self._ask_id_server_for_third_party_invite(
id_server, id_server,
medium, medium,
@ -709,7 +722,9 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _ask_id_server_for_third_party_invite( def _ask_id_server_for_third_party_invite(
self, id_server, medium, address, room_id, sender): self, id_server, medium, address, room_id, sender):
is_url = "https://%s/_matrix/identity/api/v1/store-invite" % (id_server,) is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url, is_url,
{ {
@ -722,10 +737,11 @@ class RoomMemberHandler(BaseHandler):
# TODO: Check for success # TODO: Check for success
token = data["token"] token = data["token"]
public_key = data["public_key"] public_key = data["public_key"]
key_validity_url = "https://%s/_matrix/identity/api/v1/pubkey/isvalid" % ( display_name = data["display_name"]
id_server, key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
) )
defer.returnValue((token, public_key, key_validity_url)) defer.returnValue((token, public_key, key_validity_url, display_name))
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
@ -750,7 +766,7 @@ class RoomListHandler(BaseHandler):
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit): def get_event_context(self, user, room_id, event_id, limit, is_guest):
"""Retrieves events, pagination tokens and state around a given event """Retrieves events, pagination tokens and state around a given event
in a room. in a room.
@ -774,11 +790,17 @@ class RoomContextHandler(BaseHandler):
) )
results["events_before"] = yield self._filter_events_for_client( results["events_before"] = yield self._filter_events_for_client(
user.to_string(), results["events_before"] user.to_string(),
results["events_before"],
is_guest=is_guest,
require_all_visible_for_guests=False
) )
results["events_after"] = yield self._filter_events_for_client( results["events_after"] = yield self._filter_events_for_client(
user.to_string(), results["events_after"] user.to_string(),
results["events_after"],
is_guest=is_guest,
require_all_visible_for_guests=False
) )
if results["events_after"]: if results["events_after"]:
@ -807,7 +829,14 @@ class RoomEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
limit,
room_ids,
is_guest,
):
# We just ignore the key for now. # We just ignore the key for now.
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
@ -827,8 +856,9 @@ class RoomEventSource(object):
user_id=user.to_string(), user_id=user.to_string(),
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
room_id=None,
limit=limit, limit=limit,
room_ids=room_ids,
is_guest=is_guest,
) )
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))

View file

@ -22,6 +22,8 @@ from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from unpaddedbase64 import decode_base64, encode_base64
import logging import logging
@ -34,27 +36,59 @@ class SearchHandler(BaseHandler):
super(SearchHandler, self).__init__(hs) super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
def search(self, user, content): def search(self, user, content, batch=None):
"""Performs a full text search for a user. """Performs a full text search for a user.
Args: Args:
user (UserID) user (UserID)
content (dict): Search parameters content (dict): Search parameters
batch (str): The next_batch parameter. Used for pagination.
Returns: Returns:
dict to be returned to the client with results of search dict to be returned to the client with results of search
""" """
batch_group = None
batch_group_key = None
batch_token = None
if batch:
try: try:
search_term = content["search_categories"]["room_events"]["search_term"] b = decode_base64(batch)
keys = content["search_categories"]["room_events"].get("keys", [ batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
assert batch_group_key is not None
assert batch_token is not None
except:
raise SynapseError(400, "Invalid batch")
try:
room_cat = content["search_categories"]["room_events"]
# The actual thing to query in FTS
search_term = room_cat["search_term"]
# Which "keys" to search over in FTS query
keys = room_cat.get("keys", [
"content.body", "content.name", "content.topic", "content.body", "content.name", "content.topic",
]) ])
filter_dict = content["search_categories"]["room_events"].get("filter", {})
event_context = content["search_categories"]["room_events"].get( # Filter to apply to results
filter_dict = room_cat.get("filter", {})
# What to order results by (impacts whether pagination can be doen)
order_by = room_cat.get("order_by", "rank")
# Include context around each event?
event_context = room_cat.get(
"event_context", None "event_context", None
) )
# Group results together? May allow clients to paginate within a
# group
group_by = room_cat.get("groupings", {}).get("group_by", {})
group_keys = [g["key"] for g in group_by]
if event_context is not None: if event_context is not None:
before_limit = int(event_context.get( before_limit = int(event_context.get(
"before_limit", 5 "before_limit", 5
@ -65,6 +99,15 @@ class SearchHandler(BaseHandler):
except KeyError: except KeyError:
raise SynapseError(400, "Invalid search query") raise SynapseError(400, "Invalid search query")
if order_by not in ("rank", "recent"):
raise SynapseError(400, "Invalid order by: %r" % (order_by,))
if set(group_keys) - {"room_id", "sender"}:
raise SynapseError(
400,
"Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
)
search_filter = Filter(filter_dict) search_filter = Filter(filter_dict)
# TODO: Search through left rooms too # TODO: Search through left rooms too
@ -77,19 +120,130 @@ class SearchHandler(BaseHandler):
room_ids = search_filter.filter_rooms(room_ids) room_ids = search_filter.filter_rooms(room_ids)
rank_map, event_map, _ = yield self.store.search_msgs( if batch_group == "room_id":
room_ids.intersection_update({batch_group_key})
rank_map = {} # event_id -> rank of event
allowed_events = []
room_groups = {} # Holds result of grouping by room, if applicable
sender_group = {} # Holds result of grouping by sender, if applicable
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
if order_by == "rank":
results = yield self.store.search_msgs(
room_ids, search_term, keys room_ids, search_term, keys
) )
filtered_events = search_filter.filter(event_map.values()) results_map = {r["event"].event_id: r for r in results}
allowed_events = yield self._filter_events_for_client( rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events user.to_string(), filtered_events
) )
allowed_events.sort(key=lambda e: -rank_map[e.event_id]) events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = allowed_events[:search_filter.limit()] allowed_events = events[:search_filter.limit()]
for e in allowed_events:
rm = room_groups.setdefault(e.room_id, {
"results": [],
"order": rank_map[e.event_id],
})
rm["results"].append(e.event_id)
s = sender_group.setdefault(e.sender, {
"results": [],
"order": rank_map[e.event_id],
})
s["results"].append(e.event_id)
elif order_by == "recent":
# In this case we specifically loop through each room as the given
# limit applies to each room, rather than a global list.
# This is not necessarilly a good idea.
for room_id in room_ids:
room_events = []
if batch_group == "room_id" and batch_group_key == room_id:
pagination_token = batch_token
else:
pagination_token = None
i = 0
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
results = yield self.store.search_room(
room_id, search_term, keys, search_filter.limit() * 2,
pagination_token=pagination_token,
)
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([
r["event"] for r in results
])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
room_events.extend(events)
room_events = room_events[:search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None
break
else:
pagination_token = results[-1]["pagination_token"]
if room_events:
res = results_map[room_events[-1].event_id]
pagination_token = res["pagination_token"]
group = room_groups.setdefault(room_id, {})
if pagination_token:
next_batch = encode_base64("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
))
group["next_batch"] = next_batch
if batch_token:
global_next_batch = next_batch
group["results"] = [e.event_id for e in room_events]
group["order"] = max(
e.origin_server_ts/1000 for e in room_events
if hasattr(e, "origin_server_ts")
)
allowed_events.extend(room_events)
# Normalize the group orders
if room_groups:
if len(room_groups) > 1:
mx = max(g["order"] for g in room_groups.values())
mn = min(g["order"] for g in room_groups.values())
for g in room_groups.values():
g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
else:
room_groups.values()[0]["order"] = 1
else:
# We should never get here due to the guard earlier.
raise NotImplementedError()
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None: if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
@ -144,11 +298,22 @@ class SearchHandler(BaseHandler):
logger.info("Found %d results", len(results)) logger.info("Found %d results", len(results))
defer.returnValue({ rooms_cat_res = {
"search_categories": {
"room_events": {
"results": results, "results": results,
"count": len(results) "count": len(results)
} }
if room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
if sender_group and "sender" in group_keys:
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
defer.returnValue({
"search_categories": {
"room_events": rooms_cat_res
} }
}) })

View file

@ -47,10 +47,11 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", "room_id", # str
"timeline", "timeline", # TimelineBatch
"state", "state", # dict[(str, str), FrozenEvent]
"ephemeral", "ephemeral",
"private_user_data",
])): ])):
__slots__ = [] __slots__ = []
@ -58,13 +59,19 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result. to tell if room needs to be part of the sync result.
""" """
return bool(self.timeline or self.state or self.ephemeral) return bool(
self.timeline
or self.state
or self.ephemeral
or self.private_user_data
)
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", "room_id", # str
"timeline", "timeline", # TimelineBatch
"state", "state", # dict[(str, str), FrozenEvent]
"private_user_data",
])): ])):
__slots__ = [] __slots__ = []
@ -72,12 +79,16 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result. to tell if room needs to be part of the sync result.
""" """
return bool(self.timeline or self.state) return bool(
self.timeline
or self.state
or self.private_user_data
)
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
"room_id", "room_id", # str
"invite", "invite", # FrozenEvent: the invite event
])): ])):
__slots__ = [] __slots__ = []
@ -132,21 +143,8 @@ class SyncHandler(BaseHandler):
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user, room_ids, timeout, current_sync_callback, sync_config.user, timeout, current_sync_callback,
from_token=since_token from_token=since_token
) )
defer.returnValue(result) defer.returnValue(result)
@ -174,7 +172,7 @@ class SyncHandler(BaseHandler):
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
now_token, typing_by_room = yield self.typing_by_room( now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token sync_config, now_token
) )
@ -197,6 +195,10 @@ class SyncHandler(BaseHandler):
) )
) )
tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string()
)
joined = [] joined = []
invited = [] invited = []
archived = [] archived = []
@ -207,7 +209,8 @@ class SyncHandler(BaseHandler):
sync_config=sync_config, sync_config=sync_config,
now_token=now_token, now_token=now_token,
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
typing_by_room=typing_by_room ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
) )
joined.append(room_sync) joined.append(room_sync)
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
@ -226,6 +229,7 @@ class SyncHandler(BaseHandler):
leave_event_id=event.event_id, leave_event_id=event.event_id,
leave_token=leave_token, leave_token=leave_token,
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
) )
archived.append(room_sync) archived.append(room_sync)
@ -240,7 +244,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_joined_room(self, room_id, sync_config, def full_state_sync_for_joined_room(self, room_id, sync_config,
now_token, timeline_since_token, now_token, timeline_since_token,
typing_by_room): ephemeral_by_room, tags_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -250,21 +254,31 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token=timeline_since_token room_id, sync_config, now_token, since_token=timeline_since_token
) )
current_state = yield self.state_handler.get_current_state( current_state = yield self.get_state_at(room_id, now_token)
room_id
)
current_state_events = current_state.values()
defer.returnValue(JoinedSyncResult( defer.returnValue(JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=current_state_events, state=current_state,
ephemeral=typing_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
)) ))
def private_user_data_for_room(self, room_id, tags_by_room):
private_user_data = []
tags = tags_by_room.get(room_id)
if tags is not None:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
return private_user_data
@defer.inlineCallbacks @defer.inlineCallbacks
def typing_by_room(self, sync_config, now_token, since_token=None): def ephemeral_by_room(self, sync_config, now_token, since_token=None):
"""Get the typing events for each room the user is in """Get the ephemeral events for each room the user is in
Args: Args:
sync_config (SyncConfig): The flags, filters and user for the sync. sync_config (SyncConfig): The flags, filters and user for the sync.
now_token (StreamToken): Where the server is currently up to. now_token (StreamToken): Where the server is currently up to.
@ -278,25 +292,56 @@ class SyncHandler(BaseHandler):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
is_guest=False,
) )
now_token = now_token.copy_and_replace("typing_key", typing_key) now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing} ephemeral_by_room = {}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
defer.returnValue((now_token, typing_by_room)) for event in typing:
# we want to exclude the room_id from the event, but modifying the
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
)
now_token = now_token.copy_and_replace("receipt_key", receipt_key)
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token): timeline_since_token, tags_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -306,14 +351,15 @@ class SyncHandler(BaseHandler):
room_id, sync_config, leave_token, since_token=timeline_since_token room_id, sync_config, leave_token, since_token=timeline_since_token
) )
leave_state = yield self.store.get_state_for_events( leave_state = yield self.store.get_state_for_event(leave_event_id)
[leave_event_id], None
)
defer.returnValue(ArchivedSyncResult( defer.returnValue(ArchivedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state[leave_event_id].values(), state=leave_state,
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -325,15 +371,21 @@ class SyncHandler(BaseHandler):
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events_for_user( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter.presence_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)
now_token, typing_by_room = yield self.typing_by_room( now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token sync_config, now_token, since_token
) )
@ -355,15 +407,22 @@ class SyncHandler(BaseHandler):
sync_config.user.to_string(), sync_config.user.to_string(),
from_key=since_token.room_key, from_key=since_token.room_key,
to_key=now_token.room_key, to_key=now_token.room_key,
room_id=None,
limit=timeline_limit + 1, limit=timeline_limit + 1,
) )
tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(),
since_token.private_user_data_key,
)
joined = [] joined = []
archived = [] archived = []
if len(room_events) <= timeline_limit: if len(room_events) <= timeline_limit:
# There is no gap in any of the rooms. Therefore we can just # There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them. # partition the new events by room and return them.
logger.debug("Got %i events for incremental sync - not limited",
len(room_events))
invite_events = [] invite_events = []
leave_events = [] leave_events = []
events_by_room_id = {} events_by_room_id = {}
@ -379,7 +438,12 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) recents = events_by_room_id.get(room_id, [])
state = [event for event in recents if event.is_state()] logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
if recents: if recents:
prev_batch = now_token.copy_and_replace( prev_batch = now_token.copy_and_replace(
"room_key", recents[0].internal_metadata.before "room_key", recents[0].internal_metadata.before
@ -387,9 +451,13 @@ class SyncHandler(BaseHandler):
else: else:
prev_batch = now_token prev_batch = now_token
state, limited = yield self.check_joined_room( just_joined = yield self.check_joined_room(sync_config, state)
sync_config, room_id, state if just_joined:
) logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
@ -399,12 +467,20 @@ class SyncHandler(BaseHandler):
limited=limited, limited=limited,
), ),
state=state, state=state,
ephemeral=typing_by_room.get(room_id, []) ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
) )
logger.debug("Result for room %s: %r", room_id, room_sync)
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
else: else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user( invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -416,14 +492,14 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id, sync_config, since_token, now_token,
typing_by_room ephemeral_by_room, tags_by_room
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
for leave_event in leave_events: for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room( room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token sync_config, leave_event, since_token, tags_by_room
) )
archived.append(room_sync) archived.append(room_sync)
@ -443,6 +519,9 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None):
"""
:returns a Deferred TimelineBatch
"""
limited = True limited = True
recents = [] recents = []
filtering_factor = 2 filtering_factor = 2
@ -487,13 +566,15 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config, def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token, since_token, now_token,
typing_by_room): ephemeral_by_room, tags_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to the room. Gives the client the most recent events and the changes to
state. state.
Returns: Returns:
A Deferred JoinedSyncResult A Deferred JoinedSyncResult
""" """
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed. # TODO(mjark): Check for redactions we might have missed.
@ -503,32 +584,30 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch) logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a current_state = yield self.get_state_at(room_id, now_token)
# token to indicate what point in the stream this is
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
state_at_previous_sync = yield self.get_state_at_previous_sync( state_at_previous_sync = yield self.get_state_at(
room_id, since_token=since_token room_id, stream_position=since_token
) )
state_events_delta = yield self.compute_state_delta( state = yield self.compute_state_delta(
since_token=since_token, since_token=since_token,
previous_state=state_at_previous_sync, previous_state=state_at_previous_sync,
current_state=current_state_events, current_state=current_state,
) )
state_events_delta, _ = yield self.check_joined_room( just_joined = yield self.check_joined_room(sync_config, state)
sync_config, room_id, state_events_delta if just_joined:
) state = yield self.get_state_at(room_id, now_token)
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state,
ephemeral=typing_by_room.get(room_id, []) ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
) )
logging.debug("Room sync: %r", room_sync) logging.debug("Room sync: %r", room_sync)
@ -537,7 +616,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token): since_token, tags_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
@ -556,16 +635,12 @@ class SyncHandler(BaseHandler):
logging.debug("Recents %r", batch) logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a state_events_at_leave = yield self.store.get_state_for_event(
# token to indicate what point in the stream this is leave_event.event_id
leave_state = yield self.store.get_state_for_events(
[leave_event.event_id], None
) )
state_events_at_leave = leave_state[leave_event.event_id].values() state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
state_at_previous_sync = yield self.get_state_at_previous_sync(
leave_event.room_id, since_token=since_token
) )
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
@ -578,6 +653,9 @@ class SyncHandler(BaseHandler):
room_id=leave_event.room_id, room_id=leave_event.room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
private_user_data=self.private_user_data_for_room(
leave_event.room_id, tags_by_room
),
) )
logging.debug("Room sync: %r", room_sync) logging.debug("Room sync: %r", room_sync)
@ -585,60 +663,77 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token): def get_state_after_event(self, event):
""" Get the room state at the previous sync the client made. """
Returns: Get the room state after the given event
A Deferred list of Events.
:param synapse.events.EventBase event: event of interest
:return: A Deferred map from ((type, state_key)->Event)
"""
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
state = state.copy()
state[(event.type, event.state_key)] = event
defer.returnValue(state)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state
:returns: A Deferred map from ((type, state_key)->Event)
""" """
last_events, token = yield self.store.get_recent_events_for_room( last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=since_token.room_key, limit=1, room_id, end_token=stream_position.room_key, limit=1,
) )
if last_events: if last_events:
last_event = last_events[0] last_event = last_events[-1]
last_context = yield self.state_handler.compute_event_context( state = yield self.get_state_after_event(last_event)
last_event
)
if last_event.is_state():
state = [last_event] + last_context.current_state.values()
else: else:
state = last_context.current_state.values() # no events in this room - so presumably no state
else: state = {}
state = ()
defer.returnValue(state) defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state): def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the """ Works out the differnce in state between the current state and the
state the client got when it last performed a sync. state the client got when it last performed a sync.
Returns:
A list of events. :param str since_token: the point we are comparing against
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
state to compare to
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the
new state
:returns A new event dictionary
""" """
# TODO(mjark) Check if the state events were received by the server # TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state # after the previous sync, since we need to include those state
# updates even if they occured logically before the previous event. # updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
previous_dict = {event.event_id: event for event in previous_state}
state_delta = [] state_delta = {}
for event in current_state: for key, event in current_state.iteritems():
if event.event_id not in previous_dict: if (key not in previous_state or
state_delta.append(event) previous_state[key].event_id != event.event_id):
state_delta[key] = event
return state_delta return state_delta
@defer.inlineCallbacks def check_joined_room(self, sync_config, state_delta):
def check_joined_room(self, sync_config, room_id, state_delta): """
joined = False Check if the user has just joined the given room (so should
limited = False be given the full state)
for event in state_delta:
if (
event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()
):
if event.content["membership"] == Membership.JOIN:
joined = True
if joined: :param sync_config:
res = yield self.state_handler.get_current_state(room_id) :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
state_delta = res.values() difference in state since the last sync
limited = True
defer.returnValue((state_delta, limited)) :returns A deferred Tuple (state_delta, limited)
"""
join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None)
if join_event is not None:
if join_event.content["membership"] == Membership.JOIN:
return True
return False

View file

@ -246,17 +246,12 @@ class TypingNotificationEventSource(object):
}, },
} }
@defer.inlineCallbacks def get_new_events(self, from_key, room_ids, **kwargs):
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.handler()
joined_room_ids = (
yield self.room_member_handler().get_joined_rooms_for_user(user)
)
events = [] events = []
for room_id in joined_room_ids: for room_id in room_ids:
if room_id not in handler._room_serials: if room_id not in handler._room_serials:
continue continue
if handler._room_serials[room_id] <= from_key: if handler._room_serials[room_id] <= from_key:
@ -264,7 +259,7 @@ class TypingNotificationEventSource(object):
events.append(self._make_event_for(room_id)) events.append(self._make_event_for(room_id))
defer.returnValue((events, handler._latest_room_serial)) return events, handler._latest_room_serial
def get_current_key(self): def get_current_key(self):
return self.handler()._latest_room_serial return self.handler()._latest_room_serial

View file

@ -35,6 +35,7 @@ from signedjson.sign import sign_json
import simplejson as json import simplejson as json
import logging import logging
import random
import sys import sys
import urllib import urllib
import urlparse import urlparse
@ -55,6 +56,9 @@ incoming_responses_counter = metrics.register_counter(
) )
MAX_RETRIES = 4
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
def __init__(self, hs): def __init__(self, hs):
self.tls_server_context_factory = hs.tls_server_context_factory self.tls_server_context_factory = hs.tls_server_context_factory
@ -119,7 +123,7 @@ class MatrixFederationHttpClient(object):
# XXX: Would be much nicer to retry only at the transaction-layer # XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place) # (once we have reliable transactions in place)
retries_left = 5 retries_left = MAX_RETRIES
http_url_bytes = urlparse.urlunparse( http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "") ("", "", path_bytes, param_bytes, query_bytes, "")
@ -180,7 +184,9 @@ class MatrixFederationHttpClient(object):
) )
if retries_left and not timeout: if retries_left and not timeout:
yield sleep(2 ** (5 - retries_left)) delay = 5 ** (MAX_RETRIES + 1 - retries_left)
delay *= random.uniform(0.8, 1.4)
yield sleep(delay)
retries_left -= 1 retries_left -= 1
else: else:
raise raise

View file

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred from synapse.util.async import run_on_reactor, ObservableDeferred
@ -269,8 +271,8 @@ class Notifier(object):
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback, def wait_for_events(self, user, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
@ -279,11 +281,12 @@ class Notifier(object):
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user) appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user) rooms = yield self.store.get_rooms_for_user(user)
rooms = [room.room_id for room in rooms] room_ids = [room.room_id for room in rooms]
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user=user, user=user,
rooms=rooms, rooms=room_ids,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
@ -328,8 +331,9 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_for(self, user, rooms, pagination_config, timeout, def get_events_for(self, user, pagination_config, timeout,
only_room_events=False): only_room_events=False,
is_guest=False, guest_room_id=None):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
@ -342,6 +346,16 @@ class Notifier(object):
limit = pagination_config.limit limit = pagination_config.limit
room_ids = []
if is_guest:
if guest_room_id:
if not self._is_world_readable(guest_room_id):
raise AuthError(403, "Guest access not allowed")
room_ids = [guest_room_id]
else:
rooms = yield self.store.get_rooms_for_user(user.to_string())
room_ids = [room.room_id for room in rooms]
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
@ -349,6 +363,7 @@ class Notifier(object):
events = [] events = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name keyname = "%s_key" % name
before_id = getattr(before_token, keyname) before_id = getattr(before_token, keyname)
@ -357,9 +372,23 @@ class Notifier(object):
continue continue
if only_room_events and name != "room": if only_room_events and name != "room":
continue continue
new_events, new_key = yield source.get_new_events_for_user( new_events, new_key = yield source.get_new_events(
user, getattr(from_token, keyname), limit, user=user,
from_key=getattr(from_token, keyname),
limit=limit,
is_guest=is_guest,
room_ids=room_ids,
) )
if name == "room":
room_member_handler = self.hs.get_handlers().room_member_handler
new_events = yield room_member_handler._filter_events_for_client(
user.to_string(),
new_events,
is_guest=is_guest,
require_all_visible_for_guests=False
)
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
@ -369,7 +398,7 @@ class Notifier(object):
defer.returnValue(None) defer.returnValue(None)
result = yield self.wait_for_events( result = yield self.wait_for_events(
user, rooms, timeout, check_for_updates, from_token=from_token user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
) )
if result is None: if result is None:
@ -377,6 +406,17 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state(
room_id,
EventTypes.RoomHistoryVisibility
)
if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable")
else:
defer.returnValue(False)
@log_function @log_function
def remove_expired_streams(self): def remove_expired_streams(self):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()

View file

@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user) is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:

View file

@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
try: try:
# try to auth as a user # try to auth as a user
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
try: try:
user_id = user.to_string() user_id = user.to_string()
yield dir_handler.create_association( yield dir_handler.create_association(
@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS # fallback to default user behaviour if they aren't an AS
pass pass
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:

View file

@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
room_id = None
if is_guest:
if "room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
room_id = request.args["room_id"][0]
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet):
chunk = yield handler.get_stream( chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout, auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event as_client_event=as_client_event, affect_presence=(not is_guest),
room_id=room_id, is_guest=is_guest
) )
except: except:
logger.exception("Event stream failed") logger.exception("Event stream failed")
@ -71,7 +80,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id) event = yield handler.get_event(auth_user, event_id)

View file

@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler

View file

@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} state = {}
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):

View file

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:

View file

@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']: if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request): def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is

View file

@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)

View file

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
@ -26,7 +26,6 @@ from synapse.events.utils import serialize_event
import simplejson as json import simplejson as json
import logging import logging
import urllib import urllib
from synapse.util import third_party_invites
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,7 +61,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request) room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None) info = yield self.make_room(room_config, auth_user, None)
@ -125,7 +124,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -133,6 +132,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
event_type=event_type, event_type=event_type,
state_key=state_key, state_key=state_key,
is_guest=is_guest,
) )
if not data: if not data:
@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -220,7 +220,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
# the identifier could be a room alias or a room id. Try one then the # the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid # other if it fails to parse, without swallowing other valid
@ -242,16 +245,20 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
defer.returnValue((200, ret_dict)) defer.returnValue((200, ret_dict))
else: # room id else: # room id
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN}
if is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": {"membership": Membership.JOIN}, "content": content,
"room_id": identifier.to_string(), "room_id": identifier.to_string(),
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
}, },
token_id=token_id, token_id=token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=is_guest,
) )
defer.returnValue((200, {"room_id": identifier.to_string()})) defer.returnValue((200, {"room_id": identifier.to_string()}))
@ -289,7 +296,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
@ -319,13 +326,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
})) }))
# TODO: Needs unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$") PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
@ -334,6 +341,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
is_guest=is_guest,
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event
) )
@ -347,12 +355,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
is_guest=is_guest,
) )
defer.returnValue((200, events)) defer.returnValue((200, events))
@ -363,12 +372,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
is_guest=is_guest,
) )
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -408,12 +418,12 @@ class RoomEventContext(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0]) limit = int(request.args.get("limit", [10])[0])
results = yield self.handlers.room_context_handler.get_event_context( results = yield self.handlers.room_context_handler.get_event_context(
user, room_id, event_id, limit, user, room_id, event_id, limit, is_guest
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -443,21 +453,26 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None): def on_POST(self, request, room_id, membership_action, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
raise AuthError(403, "Guest access not allowed")
content = _parse_json(request) content = _parse_json(request)
# target user is you unless it is an invite # target user is you unless it is an invite
state_key = user.to_string() state_key = user.to_string()
if membership_action == "invite" and third_party_invites.has_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.handlers.room_member_handler.do_3pid_invite(
room_id, room_id,
user, user,
content["medium"], content["medium"],
content["address"], content["address"],
content["id_server"], content["id_server"],
content["display_name"],
token_id, token_id,
txn_id txn_id
) )
@ -477,29 +492,31 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event_content = { content = {"membership": unicode(membership_action)}
"membership": unicode(membership_action), if is_guest:
} content["kind"] = "guest"
if membership_action == "join" and third_party_invites.has_join_keys(content):
event_content["third_party_invite"] = (
third_party_invites.extract_join_keys(content)
)
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": event_content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": state_key, "state_key": state_key,
}, },
token_id=token_id, token_id=token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=is_guest,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
if key not in content:
return False
return True
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
try: try:
@ -524,7 +541,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -564,7 +581,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))
@ -597,11 +614,12 @@ class SearchRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
results = yield self.handlers.search_handler.search(auth_user, content) batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search(auth_user, content, batch)
defer.returnValue((200, results)) defer.returnValue((200, results))

View file

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret turnSecret = self.hs.config.turn_shared_secret

View file

@ -22,6 +22,7 @@ from . import (
receipts, receipts,
keys, keys,
tokenrefresh, tokenrefresh,
tags,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -44,3 +45,4 @@ class ClientV2AlphaRestResource(JsonResource):
receipts.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)

View file

@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet):
if LoginType.PASSWORD in result: if LoginType.PASSWORD in result:
# if using password, they should also be logged in # if using password, they should also be logged in
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]: if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN) raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string() user_id = auth_user.to_string()
@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
yield run_on_reactor() yield run_on_reactor()
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids( threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string() auth_user.to_string()
@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds'] threePidCreds = body['threePidCreds']
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)

View file

@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id): def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != auth_user:
raise AuthError(403, "Cannot get filters for other users") raise AuthError(403, "Cannot get filters for other users")
@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != auth_user:
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")

View file

@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = auth_user.to_string()
# TODO: Check that the device_id matches that in the authentication # TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead. # or derive the device_id from the authentication instead.
@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = auth_user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string() auth_user_id = auth_user.to_string()
user_id = user_id if user_id else auth_user_id user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else [] device_ids = [device_id] if device_id else []

View file

@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_pattern, parse_json_dict_from_request
@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
kind = "user"
if "kind" in request.args:
kind = request.args["kind"][0]
if kind == "guest":
ret = yield self._do_guest_registration()
defer.returnValue(ret)
return
elif kind != "user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
)
if '/register/email/requestToken' in request.path: if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request) ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret) defer.returnValue(ret)
@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet):
ret = yield self.identity_handler.requestEmailToken(**body) ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret)) defer.returnValue((200, ret))
@defer.inlineCallbacks
def _do_guest_registration(self):
if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled"))
user_id, _ = yield self.registration_handler.register(generate_token=False)
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
defer.returnValue((200, {
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View file

@ -20,6 +20,7 @@ from synapse.http.servlet import (
) )
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id, serialize_event, format_event_for_client_v2_without_event_id,
) )
@ -81,7 +82,7 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", default=0) timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since") since = parse_string(request, "since")
@ -165,6 +166,20 @@ class SyncRestServlet(RestServlet):
return {"events": filter.filter_presence(formatted)} return {"events": filter.filter_presence(formatted)}
def encode_joined(self, rooms, filter, time_now, token_id): def encode_joined(self, rooms, filter, time_now, token_id):
"""
Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the joined rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
@ -174,6 +189,20 @@ class SyncRestServlet(RestServlet):
return joined return joined
def encode_invited(self, rooms, filter, time_now, token_id): def encode_invited(self, rooms, filter, time_now, token_id):
"""
Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
invited = {} invited = {}
for room in rooms: for room in rooms:
invite = serialize_event( invite = serialize_event(
@ -189,6 +218,20 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, filter, time_now, token_id): def encode_archived(self, rooms, filter, time_now, token_id):
"""
Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
@ -199,8 +242,28 @@ class SyncRestServlet(RestServlet):
@staticmethod @staticmethod
def encode_room(room, filter, time_now, token_id, joined=True): def encode_room(room, filter, time_now, token_id, joined=True):
"""
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:param joined: True if the user is joined to this room - will mean
we handle ephemeral events
:return: the room, encoded in our response format
:rtype: dict[str, object]
"""
event_map = {} event_map = {}
state_events = filter.filter_room_state(room.state) state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events)
state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values())
state_event_ids = [] state_event_ids = []
for event in state_events: for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
@ -210,7 +273,6 @@ class SyncRestServlet(RestServlet):
) )
state_event_ids.append(event.event_id) state_event_ids.append(event.event_id)
timeline_events = filter.filter_room_timeline(room.timeline.events)
timeline_event_ids = [] timeline_event_ids = []
for event in timeline_events: for event in timeline_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
@ -220,6 +282,10 @@ class SyncRestServlet(RestServlet):
) )
timeline_event_ids.append(event.event_id) timeline_event_ids.append(event.event_id)
private_user_data = filter.filter_room_private_user_data(
room.private_user_data
)
result = { result = {
"event_map": event_map, "event_map": event_map,
"timeline": { "timeline": {
@ -228,6 +294,7 @@ class SyncRestServlet(RestServlet):
"limited": room.timeline.limited, "limited": room.timeline.limited,
}, },
"state": {"events": state_event_ids}, "state": {"events": state_event_ids},
"private_user_data": {"events": private_user_data},
} }
if joined: if joined:
@ -236,6 +303,63 @@ class SyncRestServlet(RestServlet):
return result return result
@staticmethod
def _rollback_state_for_timeline(state, timeline):
"""
Wind the state dictionary backwards, so that it represents the
state at the start of the timeline, rather than at the end.
:param dict[(str, str), synapse.events.EventBase] state: the
state dictionary. Will be updated to the state before the timeline.
:param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary
"""
logger.debug("Processing state dict %r; timeline %r", state,
[e.get_dict() for e in timeline])
result = state.copy()
for timeline_event in reversed(timeline):
if not timeline_event.is_state():
continue
event_key = (timeline_event.type, timeline_event.state_key)
logger.debug("Considering %s for removal", event_key)
state_event = result.get(event_key)
if (state_event is None or
state_event.event_id != timeline_event.event_id):
# the event in the timeline isn't present in the state
# dictionary.
#
# the most likely cause for this is that there was a fork in
# the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore
# it for now, however.
logger.warn("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event)
continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
logger.debug("Replacing %s with %s in state dict",
timeline_event.event_id, prev_event_id)
if prev_event_id is None:
del result[event_key]
else:
result[event_key] = FrozenEvent({
"type": timeline_event.type,
"state_key": timeline_event.state_key,
"content": timeline_event.unsigned['prev_content'],
"sender": timeline_event.unsigned['prev_sender'],
"event_id": prev_event_id,
"room_id": timeline_event.room_id,
})
logger.debug("New value: %r", result.get(event_key))
return result
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)

View file

@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_pattern
from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError
from twisted.internet import defer
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
)
def __init__(self, hs):
super(TagListServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id)
defer.returnValue((200, {"tags": tags}))
class TagServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
def __init__(self, hs):
super(TagServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid tag JSON")
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)

View file

@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def map_request_to_name(self, request): def map_request_to_name(self, request):
# auth the user # auth the user
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user # namespace all file uploads on the user
prefix = base64.urlsafe_b64encode( prefix = base64.urlsafe_b64encode(

View file

@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader("Content-Length")

View file

@ -29,7 +29,6 @@ from synapse.state import StateHandler
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util import Clock from synapse.util import Clock
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
@ -70,7 +69,6 @@ class BaseHomeServer(object):
'auth', 'auth',
'rest_servlet_factory', 'rest_servlet_factory',
'state_handler', 'state_handler',
'room_lock_manager',
'notifier', 'notifier',
'distributor', 'distributor',
'resource_for_client', 'resource_for_client',
@ -201,9 +199,6 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)
def build_room_lock_manager(self):
return LockManager()
def build_distributor(self): def build_distributor(self):
return Distributor() return Distributor()

View file

@ -71,7 +71,7 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
""" Returns the current state for the room as a list. This is done by """ Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts. event graph and then resolving any of the state conflicts.
@ -80,6 +80,8 @@ class StateHandler(object):
If `event_type` is specified, then the method returns only the one If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`. event (or None) with that `event_type` and `state_key`.
:returns map from (type, state_key) to event
""" """
event_ids = yield self.store.get_latest_event_ids_in_room(room_id) event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
@ -177,9 +179,10 @@ class StateHandler(object):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Return format is a tuple: (`state_group`, `state_events`), where the :returns a Deferred tuple of (`state_group`, `state`, `prev_state`).
first is the name of a state group if one and only one is involved, `state_group` is the name of a state group if one and only one is
otherwise `None`. involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
@ -255,6 +258,11 @@ class StateHandler(object):
return self._resolve_events(state_sets) return self._resolve_events(state_sets)
def _resolve_events(self, state_sets, event_type=None, state_key=""): def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
:returns a tuple (new_state, prev_states). new_state is a map
from (type, state_key) to event. prev_states is a list of event_ids.
:rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
"""
state = {} state = {}
for st in state_sets: for st in state_sets:
for e in st: for e in st:
@ -307,19 +315,23 @@ class StateHandler(object):
We resolve conflicts in the following order: We resolve conflicts in the following order:
1. power levels 1. power levels
2. memberships 2. join rules
3. other events. 3. memberships
4. other events.
""" """
resolved_state = {} resolved_state = {}
power_key = (EventTypes.PowerLevels, "") power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state.items(): if power_key in conflicted_state:
power_levels = conflicted_state[power_key] events = conflicted_state[power_key]
resolved_state[power_key] = self._resolve_auth_events(power_levels) logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = self._resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = self._resolve_auth_events(
events, events,
auth_events auth_events
@ -329,6 +341,7 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = self._resolve_auth_events(
events, events,
auth_events auth_events
@ -338,6 +351,7 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key not in resolved_state: if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = self._resolve_normal_events( resolved_state[key] = self._resolve_normal_events(
events, auth_events events, auth_events
) )

View file

@ -0,0 +1,50 @@
<html>
<head>
<title> Login </title>
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="style.css">
<script src="js/jquery-2.1.3.min.js"></script>
<script src="js/login.js"></script>
</head>
<body onload="matrixLogin.onLoad()">
<center>
<br/>
<h1>Log in with one of the following methods</h1>
<span id="feedback" style="color: #f00"></span>
<br/>
<br/>
<div id="loading">
<img src="spinner.gif" />
</div>
<div id="cas_flow" class="login_flow" style="display:none"
onclick="gotoCas(); return false;">
CAS Authentication: <button id="cas_button" style="margin: 10px">Log in</button>
</div>
<br/>
<form id="password_form" class="login_flow" style="display:none"
onsubmit="matrixLogin.password_login(); return false;">
<div>
Password Authentication:<br/>
<div style="text-align: center">
<input id="user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="password" size="32" type="password" placeholder="Password"/>
<br/>
<button type="submit" style="margin: 10px">Log in</button>
</div>
</div>
</form>
<div id="no_login_types" type="button" class="login_flow" style="display:none">
Log in currently unavailable.
</div>
</center>
</body>
</html>

View file

@ -0,0 +1,167 @@
window.matrixLogin = {
endpoint: location.origin + "/_matrix/client/api/v1/login",
serverAcceptsPassword: false,
serverAcceptsCas: false
};
var submitPassword = function(user, pwd) {
console.log("Logging in with password...");
var data = {
type: "m.login.password",
user: user,
password: pwd,
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var submitCas = function(ticket, service) {
console.log("Logging in with cas...");
var data = {
type: "m.login.cas",
ticket: ticket,
service: service,
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var errorFunc = function(err) {
show_login();
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
}
else {
setFeedbackString("Request failed: " + err.status);
}
};
var getCasURL = function(cb) {
$.get(matrixLogin.endpoint + "/cas", function(response) {
var cas_url = response.serverUrl;
cb(cas_url);
}).error(errorFunc);
};
var gotoCas = function() {
getCasURL(function(cas_url) {
var this_page = window.location.origin + window.location.pathname;
var redirect_url = cas_url + "/login?service=" + encodeURIComponent(this_page);
window.location.replace(redirect_url);
});
}
var setFeedbackString = function(text) {
$("#feedback").text(text);
};
var show_login = function() {
$("#loading").hide();
if (matrixLogin.serverAcceptsPassword) {
$("#password_form").show();
}
if (matrixLogin.serverAcceptsCas) {
$("#cas_flow").show();
}
if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas) {
$("#no_login_types").show();
}
};
var show_spinner = function() {
$("#password_form").hide();
$("#cas_flow").hide();
$("#no_login_types").hide();
$("#loading").show();
};
var fetch_info = function(cb) {
$.get(matrixLogin.endpoint, function(response) {
var serverAcceptsPassword = false;
var serverAcceptsCas = false;
for (var i=0; i<response.flows.length; i++) {
var flow = response.flows[i];
if ("m.login.cas" === flow.type) {
matrixLogin.serverAcceptsCas = true;
console.log("Server accepts CAS");
}
if ("m.login.password" === flow.type) {
matrixLogin.serverAcceptsPassword = true;
console.log("Server accepts password");
}
}
cb();
}).error(errorFunc);
}
matrixLogin.onLoad = function() {
fetch_info(function() {
if (!try_cas()) {
show_login();
}
});
};
matrixLogin.password_login = function() {
var user = $("#user_id").val();
var pwd = $("#password").val();
setFeedbackString("");
show_spinner();
submitPassword(user, pwd);
};
matrixLogin.onLogin = function(response) {
// clobber this function
console.log("onLogin - This function should be replaced to proceed.");
console.log(response);
};
var parseQsFromUrl = function(query) {
var result = {};
query.split("&").forEach(function(part) {
var item = part.split("=");
var key = item[0];
var val = item[1];
if (val) {
val = decodeURIComponent(val);
}
result[key] = val
});
return result;
};
var try_cas = function() {
var pos = window.location.href.indexOf("?");
if (pos == -1) {
return false;
}
var qs = parseQsFromUrl(window.location.href.substr(pos+1));
var ticket = qs.ticket;
if (!ticket) {
return false;
}
submitCas(ticket, location.origin);
return true;
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

View file

@ -0,0 +1,57 @@
html {
height: 100%;
}
body {
height: 100%;
font-family: "Myriad Pro", "Myriad", Helvetica, Arial, sans-serif;
font-size: 12pt;
margin: 0px;
}
h1 {
font-size: 20pt;
}
a:link { color: #666; }
a:visited { color: #666; }
a:hover { color: #000; }
a:active { color: #000; }
input {
width: 90%
}
textarea, input {
font-family: inherit;
font-size: inherit;
margin: 5px;
}
.smallPrint {
color: #888;
font-size: 9pt ! important;
font-style: italic ! important;
}
.g-recaptcha div {
margin: auto;
}
.login_flow {
text-align: left;
padding: 10px;
margin-bottom: 40px;
display: inline-block;
-webkit-border-radius: 10px;
-moz-border-radius: 10px;
border-radius: 10px;
-webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
-moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
background-color: #f8f8f8;
border: 1px #ccc solid;
}

File diff suppressed because one or more lines are too long

View file

@ -41,6 +41,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore from .search import SearchStore
from .tags import TagsStore
import logging import logging
@ -71,6 +72,7 @@ class DataStore(RoomMemberStore, RoomStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore, SearchStore,
TagsStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View file

@ -0,0 +1,256 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
import ujson as json
import logging
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance(object):
"""Tracks the how long a background update is taking to update its items"""
def __init__(self, name):
self.name = name
self.total_item_count = 0
self.total_duration_ms = 0
self.avg_item_count = 0
self.avg_duration_ms = 0
def update(self, item_count, duration_ms):
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
# Exponential moving averages for the number of items updated and
# the duration.
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
def average_items_per_ms(self):
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
"""
if self.total_item_count == 0:
return None
else:
# Use the exponential moving average so that we can adapt to
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
def total_items_per_ms(self):
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
"""
if self.total_item_count == 0:
return None
else:
return float(self.total_item_count) / float(self.total_duration_ms)
class BackgroundUpdateStore(SQLBaseStore):
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
process and autotuning the batch size.
"""
MINIMUM_BACKGROUND_BATCH_SIZE = 100
DEFAULT_BACKGROUND_BATCH_SIZE = 100
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs):
super(BackgroundUpdateStore, self).__init__(hs)
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
self._background_update_timer = None
@defer.inlineCallbacks
def start_doing_background_updates(self):
while True:
if self._background_update_timer is not None:
return
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
)
try:
yield sleep
finally:
self._background_update_timer = None
try:
result = yield self.do_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
if result is None:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
)
return
@defer.inlineCallbacks
def do_background_update(self, desired_duration_ms):
"""Does some amount of work on a background update
Args:
desired_duration_ms(float): How long we want to spend
updating.
Returns:
A deferred that completes once some amount of work is done.
The deferred will have a value of None if there is currently
no more work to do.
"""
if not self._background_update_queue:
updates = yield self._simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name",),
)
for update in updates:
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
defer.returnValue(None)
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
if performance is None:
performance = BackgroundUpdatePerformance(update_name)
self._background_update_performance[update_name] = performance
items_per_ms = performance.average_items_per_ms()
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress
batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
progress_json = yield self._simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json"
)
progress = json.loads(progress_json)
time_start = self._clock.time_msec()
items_updated = yield update_handler(progress, batch_size)
time_stop = self._clock.time_msec()
duration_ms = time_stop - time_start
logger.info(
"Updating %r. Updated %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
update_name, items_updated, duration_ms,
performance.total_items_per_ms(),
performance.average_items_per_ms(),
performance.total_item_count,
)
performance.update(items_updated, duration_ms)
defer.returnValue(len(self._background_update_performance))
def register_background_update_handler(self, update_name, update_handler):
"""Register a handler for doing a background update.
The handler should take two arguments:
* A dict of the current progress
* An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated.
The hander is responsible for updating the progress of the update.
Args:
update_name(str): The name of the update that this code handles.
update_handler(function): The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
def start_background_update(self, update_name, progress):
"""Starts a background update running.
Args:
update_name: The update to set running.
progress: The initial state of the progress of the update.
Returns:
A deferred that completes once the task has been added to the
queue.
"""
# Clear the background update queue so that we will pick up the new
# task on the next iteration of do_background_update.
self._background_update_queue = []
progress_json = json.dumps(progress)
return self._simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json}
)
def _end_background_update(self, update_name):
"""Removes a completed background update task from the queue.
Args:
update_name(str): The name of the completed task to remove
Returns:
A deferred that completes once the task is removed.
"""
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
return self._simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
def _background_update_progress_txn(self, txn, update_name, progress):
"""Update the progress of a background update
Args:
txn(cursor): The transaction.
update_name(str): The name of the background update task
progress(dict): The progress of the update.
"""
progress_json = json.dumps(progress)
self._simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
updatevalues={"progress_json": progress_json},
)

View file

@ -311,6 +311,10 @@ class EventsStore(SQLBaseStore):
self._store_room_message_txn(txn, event) self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
self._store_guest_access_txn(txn, event)
self._store_room_members_txn( self._store_room_members_txn(
txn, txn,
@ -827,7 +831,8 @@ class EventsStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
if prev: if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.content
ev.unsigned["prev_sender"] = prev.sender
self._get_event_cache.prefill( self._get_event_cache.prefill(
(ev.event_id, check_redacted, get_prev_content), ev (ev.event_id, check_redacted, get_prev_content), ev
@ -884,7 +889,8 @@ class EventsStore(SQLBaseStore):
get_prev_content=False, get_prev_content=False,
) )
if prev: if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.content
ev.unsigned["prev_sender"] = prev.sender
self._get_event_cache.prefill( self._get_event_cache.prefill(
(ev.event_id, check_redacted, get_prev_content), ev (ev.event_id, check_redacted, get_prev_content), ev

View file

@ -102,6 +102,7 @@ class RegistrationStore(SQLBaseStore):
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
if token:
# it's possible for this to get a conflict, but only for a single user # it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID # since tokens are namespaced based on their user ID
txn.execute( txn.execute(

View file

@ -99,34 +99,39 @@ class RoomStore(SQLBaseStore):
""" """
def f(txn): def f(txn):
topic_subquery = ( def subquery(table_name, column_name=None):
"SELECT topics.event_id as event_id, " column_name = column_name or table_name
"topics.room_id as room_id, topic " return (
"FROM topics " "SELECT %(table_name)s.event_id as event_id, "
"%(table_name)s.room_id as room_id, %(column_name)s "
"FROM %(table_name)s "
"INNER JOIN current_state_events as c " "INNER JOIN current_state_events as c "
"ON c.event_id = topics.event_id " "ON c.event_id = %(table_name)s.event_id " % {
"column_name": column_name,
"table_name": table_name,
}
) )
name_subquery = (
"SELECT room_names.event_id as event_id, "
"room_names.room_id as room_id, name "
"FROM room_names "
"INNER JOIN current_state_events as c "
"ON c.event_id = room_names.event_id "
)
# We use non printing ascii character US (\x1F) as a separator
sql = ( sql = (
"SELECT r.room_id, max(n.name), max(t.topic)" "SELECT"
" r.room_id,"
" max(n.name),"
" max(t.topic),"
" max(v.history_visibility),"
" max(g.guest_access)"
" FROM rooms AS r" " FROM rooms AS r"
" LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id" " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
" LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id" " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
" LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
" LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
" WHERE r.is_public = ?" " WHERE r.is_public = ?"
" GROUP BY r.room_id" " GROUP BY r.room_id" % {
) % { "topic": subquery("topics", "topic"),
"topic": topic_subquery, "name": subquery("room_names", "name"),
"name": name_subquery, "history_visibility": subquery("history_visibility"),
"guest_access": subquery("guest_access"),
} }
)
txn.execute(sql, (is_public,)) txn.execute(sql, (is_public,))
@ -156,10 +161,12 @@ class RoomStore(SQLBaseStore):
"room_id": r[0], "room_id": r[0],
"name": r[1], "name": r[1],
"topic": r[2], "topic": r[2],
"aliases": r[3], "world_readable": r[3] == "world_readable",
"guest_can_join": r[4] == "can_join",
"aliases": r[5],
} }
for r in rows for r in rows
if r[3] # We only return rooms that have at least one alias. if r[5] # We only return rooms that have at least one alias.
] ]
defer.returnValue(ret) defer.returnValue(ret)
@ -202,6 +209,25 @@ class RoomStore(SQLBaseStore):
txn, event, "content.body", event.content["body"] txn, event, "content.body", event.content["body"]
) )
def _store_history_visibility_txn(self, txn, event):
self._store_content_index_txn(txn, event, "history_visibility")
def _store_guest_access_txn(self, txn, event):
self._store_content_index_txn(txn, event, "guest_access")
def _store_content_index_txn(self, txn, event, key):
if hasattr(event, "content") and key in event.content:
sql = (
"INSERT INTO %(key)s"
" (event_id, room_id, %(key)s)"
" VALUES (?, ?, ?)" % {"key": key}
)
txn.execute(sql, (
event.event_id,
event.room_id,
event.content[key]
))
def _store_event_search_txn(self, txn, event, key, value): def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (

View file

@ -0,0 +1,21 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS background_updates(
update_name TEXT NOT NULL, -- The name of the background update.
progress_json TEXT NOT NULL, -- The current progress of the update as JSON.
CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
);

View file

@ -22,7 +22,7 @@ import ujson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
POSTGRES_SQL = """ POSTGRES_TABLE = """
CREATE TABLE IF NOT EXISTS event_search ( CREATE TABLE IF NOT EXISTS event_search (
event_id TEXT, event_id TEXT,
room_id TEXT, room_id TEXT,
@ -31,22 +31,6 @@ CREATE TABLE IF NOT EXISTS event_search (
vector tsvector vector tsvector
); );
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.body',
to_tsvector('english', json::json->'content'->>'body')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.message';
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.name',
to_tsvector('english', json::json->'content'->>'name')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.name';
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.topic',
to_tsvector('english', json::json->'content'->>'topic')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic';
CREATE INDEX event_search_fts_idx ON event_search USING gin(vector); CREATE INDEX event_search_fts_idx ON event_search USING gin(vector);
CREATE INDEX event_search_ev_idx ON event_search(event_id); CREATE INDEX event_search_ev_idx ON event_search(event_id);
CREATE INDEX event_search_ev_ridx ON event_search(room_id); CREATE INDEX event_search_ev_ridx ON event_search(room_id);
@ -61,67 +45,34 @@ SQLITE_TABLE = (
def run_upgrade(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine): if isinstance(database_engine, PostgresEngine):
run_postgres_upgrade(cur) for statement in get_statements(POSTGRES_TABLE.splitlines()):
return
if isinstance(database_engine, Sqlite3Engine):
run_sqlite_upgrade(cur)
return
def run_postgres_upgrade(cur):
for statement in get_statements(POSTGRES_SQL.splitlines()):
cur.execute(statement) cur.execute(statement)
elif isinstance(database_engine, Sqlite3Engine):
def run_sqlite_upgrade(cur):
cur.execute(SQLITE_TABLE) cur.execute(SQLITE_TABLE)
else:
raise Exception("Unrecognized database engine")
rowid = -1 cur.execute("SELECT MIN(stream_ordering) FROM events")
while True: rows = cur.fetchall()
cur.execute( min_stream_id = rows[0][0]
"SELECT rowid, json FROM event_json"
" WHERE rowid > ?" cur.execute("SELECT MAX(stream_ordering) FROM events")
" ORDER BY rowid ASC LIMIT 100", rows = cur.fetchall()
(rowid,) max_stream_id = rows[0][0]
if min_stream_id is not None and max_stream_id is not None:
progress = {
"target_min_stream_id_inclusive": min_stream_id,
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = ujson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
" VALUES (?, ?)"
) )
res = cur.fetchall() sql = database_engine.convert_param_style(sql)
if not res: cur.execute(sql, ("event_search", progress_json))
break
events = [
ujson.loads(js)
for _, js in res
]
rowid = max(rid for rid, _ in res)
rows = []
for ev in events:
content = ev.get("content", {})
body = content.get("body", None)
name = content.get("name", None)
topic = content.get("topic", None)
sender = ev.get("sender", None)
if ev["type"] == "m.room.message" and body:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.body", body
))
if ev["type"] == "m.room.name" and name:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.name", name
))
if ev["type"] == "m.room.topic" and topic:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.topic", topic
))
if rows:
logger.info(rows)
cur.executemany(
"INSERT INTO event_search (event_id, room_id, sender, key, value)"
" VALUES (?,?,?,?,?)",
rows
)

View file

@ -0,0 +1,25 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of guest_access content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS guest_access(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
guest_access TEXT NOT NULL,
UNIQUE (event_id)
);

View file

@ -0,0 +1,25 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of history_visibility content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS history_visibility(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
history_visibility TEXT NOT NULL,
UNIQUE (event_id)
);

View file

@ -0,0 +1,38 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS room_tags(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
tag TEXT NOT NULL, -- The name of the tag.
content TEXT NOT NULL, -- The JSON content of the tag.
CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag)
);
CREATE TABLE IF NOT EXISTS room_tags_revisions (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
stream_id BIGINT NOT NULL, -- The current version of the room tags.
CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id)
);
CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT NOT NULL,
CHECK (Lock='X')
);
INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0);

View file

@ -15,22 +15,116 @@
from twisted.internet import defer from twisted.internet import defer
from _base import SQLBaseStore from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from collections import namedtuple import logging
"""The result of a search.
Fields:
rank_map (dict): Mapping event_id -> rank
event_map (dict): Mapping event_id -> event
pagination_token (str): Pagination token
"""
SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
class SearchStore(SQLBaseStore): logger = logging.getLogger(__name__)
class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
def __init__(self, hs):
super(SearchStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids)
event_search_rows = []
for event in events:
try:
event_id = event.event_id
room_id = event.room_id
content = event.content
if event.type == "m.room.message":
key = "content.body"
value = content["body"]
elif event.type == "m.room.topic":
key = "content.topic"
value = content["topic"]
elif event.type == "m.room.name":
key = "content.name"
value = content["name"]
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
event_search_rows.append((event_id, room_id, key, value))
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(event_search_rows)
}
self._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
result = yield self.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys): def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
@ -42,7 +136,7 @@ class SearchStore(SQLBaseStore):
"content.body", "content.name", "content.topic" "content.body", "content.name", "content.topic"
Returns: Returns:
SearchResult list of dicts
""" """
clauses = [] clauses = []
args = [] args = []
@ -100,12 +194,114 @@ class SearchStore(SQLBaseStore):
for ev in events for ev in events
} }
defer.returnValue(SearchResult( defer.returnValue([
{ {
r["event_id"]: r["rank"] "event": event_map[r["event_id"]],
"rank": r["rank"],
}
for r in results for r in results
if r["event_id"] in event_map if r["event_id"] in event_map
}, ])
event_map,
None @defer.inlineCallbacks
)) def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
"""Performs a full text search over events with given keys.
Args:
room_id (str): The room_id to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
pagination_token (str): A pagination token previously returned
Returns:
list of dicts
"""
clauses = []
args = [search_term, room_id]
local_clauses = []
for key in keys:
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
if pagination_token:
try:
topo, stream = pagination_token.split(",")
topo = int(topo)
stream = int(stream)
except:
raise SynapseError(400, "Invalid pagination token")
clauses.append(
"(topological_ordering < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))"
)
args.extend([topo, topo, stream])
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) as rank,"
" topological_ordering, stream_ordering, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events"
" WHERE vector @@ query AND room_id = ?"
)
elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
# We want to use the full text search index on event_search to
# extract all possible matches first, then lookup those matches
# in the events table to get the topological ordering. We need
# to use the indexes in this order because sqlite refuses to
# MATCH unless it uses the full text search index
sql = (
"SELECT rank(matchinfo) as rank, room_id, event_id,"
" topological_ordering, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search"
" WHERE value MATCH ?"
" )"
" CROSS JOIN events USING (event_id)"
" WHERE room_id = ?"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for clause in clauses:
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
args.append(limit)
results = yield self._execute(
"search_rooms", self.cursor_to_dict, sql, *args
)
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
defer.returnValue([
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
"pagination_token": "%s,%s" % (
r["topological_ordering"], r["stream_ordering"]
),
}
for r in results
if r["event_id"] in event_map
])

View file

@ -237,6 +237,20 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned
:param list[(str, str)]|None types: List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
:return: a deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000) @cached(num_args=2, lru=True, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(

View file

@ -158,14 +158,40 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id, def get_room_events_stream(
limit=0): self,
user_id,
from_key,
to_key,
limit=0,
is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = ( current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m " "SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c" " INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id" " ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'" " WHERE m.user_id = ? AND m.membership = 'join'"
) )
current_room_membership_args = [user_id]
if room_ids:
current_room_membership_sql += " AND m.room_id in (%s)" % (
",".join(map(lambda _: "?", room_ids))
)
current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g. # We also want to get any membership events about that user, e.g.
# invites or leave notifications. # invites or leave notifications.
@ -174,6 +200,7 @@ class StreamStore(SQLBaseStore):
"INNER JOIN current_state_events as c ON m.event_id = c.event_id " "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? " "WHERE m.user_id = ? "
) )
membership_args = [user_id]
if limit: if limit:
limit = max(limit, MAX_STREAM_SIZE) limit = max(limit, MAX_STREAM_SIZE)
@ -200,7 +227,9 @@ class StreamStore(SQLBaseStore):
} }
def f(txn): def f(txn):
txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)

216
synapse/storage/tags.py Normal file
View file

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
self._private_user_data_id_gen = StreamIdGenerator(
"private_user_data_max_stream_id", "stream_id"
)
def get_max_private_user_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._private_user_data_id_gen.get_max_token(self)
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
Args:
user_id(str): The user to get the tags for.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings.
"""
deferred = self._simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@deferred.addCallback
def tags_by_room(rows):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = json.loads(row["content"])
return tags_by_room
return deferred
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
sql = (
"SELECT room_id from room_tags_revisions"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
defer.returnValue(results)
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
Args:
user_id(str): The user to get tags for
room_id(str): The room to get tags for
Returns:
A deferred list of string tags.
"""
return self._simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(lambda rows: {
row["tag"]: json.loads(row["content"]) for row in rows
})
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
tag(str): The tag name to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the tag has been added.
"""
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="room_tags",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"tag": tag,
},
values={
"content": content_json,
}
)
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
"""Remove a tag from a room for a user.
Returns:
A deferred that completes once the tag has been removed
"""
def remove_tag_txn(txn, next_id):
sql = (
"DELETE FROM room_tags "
" WHERE user_id = ? AND room_id = ? AND tag = ?"
)
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
user_id(str): The ID of the user.
room_id(str): The ID of the room.
next_id(int): The the revision to advance to.
"""
update_max_id_sql = (
"UPDATE private_user_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
update_sql = (
"UPDATE room_tags_revisions"
" SET stream_id = ?"
" WHERE user_id = ?"
" AND room_id = ?"
)
txn.execute(update_sql, (next_id, user_id, room_id))
if txn.rowcount == 0:
insert_sql = (
"INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
" VALUES (?, ?, ?)"
)
try:
txn.execute(insert_sql, (user_id, room_id, next_id))
except self.database_engine.module.IntegrityError:
# Ignore insertion errors. It doesn't matter if the row wasn't
# inserted because if two updates happend concurrently the one
# with the higher stream_id will not be reported to a client
# unless the previous update has completed. It doesn't matter
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass

View file

@ -59,7 +59,7 @@ class TransactionStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
if result and result.response_code: if result and result["response_code"]:
return result["response_code"], result["response_json"] return result["response_code"], result["response_json"]
else: else:
return None return None
@ -253,16 +253,6 @@ class TransactionStore(SQLBaseStore):
retry_interval (int) - how long until next retry in ms retry_interval (int) - how long until next retry in ms
""" """
# As this is the new value, we might as well prefill the cache
self.get_destination_retry_timings.prefill(
destination,
{
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval
},
)
# XXX: we could chose to not bother persisting this if our cache thinks # XXX: we could chose to not bother persisting this if our cache thinks
# this is a NOOP # this is a NOOP
return self.runInteraction( return self.runInteraction(
@ -275,25 +265,19 @@ class TransactionStore(SQLBaseStore):
def _set_destination_retry_timings(self, txn, destination, def _set_destination_retry_timings(self, txn, destination,
retry_last_ts, retry_interval): retry_last_ts, retry_interval):
query = ( txn.call_after(self.get_destination_retry_timings.invalidate, (destination,))
"UPDATE destinations"
" SET retry_last_ts = ?, retry_interval = ?"
" WHERE destination = ?"
)
txn.execute( self._simple_upsert_txn(
query,
(
retry_last_ts, retry_interval, destination,
)
)
if txn.rowcount == 0:
# destination wasn't already in table. Insert it.
self._simple_insert_txn(
txn, txn,
table="destinations", "destinations",
keyvalues={
"destination": destination,
},
values={ values={
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
insertion_values={
"destination": destination, "destination": destination,
"retry_last_ts": retry_last_ts, "retry_last_ts": retry_last_ts,
"retry_interval": retry_interval, "retry_interval": retry_interval,

View file

@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.private_user_data import PrivateUserDataEventSource
class EventSources(object): class EventSources(object):
@ -29,6 +30,7 @@ class EventSources(object):
"presence": PresenceEventSource, "presence": PresenceEventSource,
"typing": TypingNotificationEventSource, "typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource, "receipt": ReceiptEventSource,
"private_user_data": PrivateUserDataEventSource,
} }
def __init__(self, hs): def __init__(self, hs):
@ -52,5 +54,8 @@ class EventSources(object):
receipt_key=( receipt_key=(
yield self.sources["receipt"].get_current_key() yield self.sources["receipt"].get_current_key()
), ),
private_user_data_key=(
yield self.sources["private_user_data"].get_current_key()
),
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -98,10 +98,13 @@ class EventID(DomainSpecificString):
class StreamToken( class StreamToken(
namedtuple( namedtuple("Token", (
"Token", "room_key",
("room_key", "presence_key", "typing_key", "receipt_key") "presence_key",
) "typing_key",
"receipt_key",
"private_user_data_key",
))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -109,7 +112,7 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
if len(keys) == len(cls._fields) - 1: while len(keys) < len(cls._fields):
# i.e. old token from before receipt_key # i.e. old token from before receipt_key
keys.append("0") keys.append("0")
return cls(*keys) return cls(*keys)
@ -128,13 +131,14 @@ class StreamToken(
else: else:
return int(self.room_key[1:].split("-")[-1]) return int(self.room_key[1:].split("-")[-1])
def is_after(self, other_token): def is_after(self, other):
"""Does this token contain events that the other doesn't?""" """Does this token contain events that the other doesn't?"""
return ( return (
(other_token.room_stream_id < self.room_stream_id) (other.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key)) or (int(other.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key)) or (int(other.typing_key) < int(self.typing_key))
or (int(other_token.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.private_user_data_key) < int(self.private_user_data_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View file

@ -53,6 +53,14 @@ class Clock(object):
loop.stop() loop.stop()
def call_later(self, delay, callback, *args, **kwargs): def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
Args:
delay(float): How long to wait in seconds.
callback(function): Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(*args, **kwargs): def wrapped_callback(*args, **kwargs):

View file

@ -1,74 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class Lock(object):
def __init__(self, deferred, key):
self._deferred = deferred
self.released = False
self.key = key
def release(self):
self.released = True
self._deferred.callback(None)
def __del__(self):
if not self.released:
logger.critical("Lock was destructed but never released!")
self.release()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
logger.debug("Releasing lock for key=%r", self.key)
self.release()
class LockManager(object):
""" Utility class that allows us to lock based on a `key` """
def __init__(self):
self._lock_deferreds = {}
@defer.inlineCallbacks
def lock(self, key):
""" Allows us to block until it is our turn.
Args:
key (str)
Returns:
Lock
"""
new_deferred = defer.Deferred()
old_deferred = self._lock_deferreds.get(key)
self._lock_deferreds[key] = new_deferred
if old_deferred:
logger.debug("Queueing on lock for key=%r", key)
yield old_deferred
logger.debug("Obtained lock for key=%r", key)
else:
logger.debug("Entering uncontended lock for key=%r", key)
defer.returnValue(Lock(new_deferred, key))

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,8 +86,9 @@ def get_retry_limiter(destination, clock, store, **kwargs):
class RetryDestinationLimiter(object): class RetryDestinationLimiter(object):
def __init__(self, destination, clock, store, retry_interval, def __init__(self, destination, clock, store, retry_interval,
min_retry_interval=5000, max_retry_interval=60 * 60 * 1000, min_retry_interval=10 * 60 * 1000,
multiplier_retry_interval=2,): max_retry_interval=24 * 60 * 60 * 1000,
multiplier_retry_interval=5,):
"""Marks the destination as "down" if an exception is thrown in the """Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500. context, except for CodeMessageException with code < 500.
@ -140,6 +142,7 @@ class RetryDestinationLimiter(object):
# We couldn't connect. # We couldn't connect.
if self.retry_interval: if self.retry_interval:
self.retry_interval *= self.multiplier_retry_interval self.retry_interval *= self.multiplier_retry_interval
self.retry_interval *= int(random.uniform(0.8, 1.4))
if self.retry_interval >= self.max_retry_interval: if self.retry_interval >= self.max_retry_interval:
self.retry_interval = self.max_retry_interval self.retry_interval = self.max_retry_interval

View file

@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError
INVITE_KEYS = {"id_server", "medium", "address", "display_name"}
JOIN_KEYS = {
"token",
"public_key",
"key_validity_url",
"sender",
"signed",
}
def has_invite_keys(content):
for key in INVITE_KEYS:
if key not in content:
return False
return True
def has_join_keys(content):
for key in JOIN_KEYS:
if key not in content:
return False
return True
def join_has_third_party_invite(content):
if "third_party_invite" not in content:
return False
return has_join_keys(content["third_party_invite"])
def extract_join_keys(src):
return {
key: value
for key, value in src.items()
if key in JOIN_KEYS
}
@defer.inlineCallbacks
def check_key_valid(http_client, event):
try:
response = yield http_client.get_json(
event.content["third_party_invite"]["key_validity_url"],
{"public_key": event.content["third_party_invite"]["public_key"]}
)
except Exception:
raise AuthError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")

View file

@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id) self.assertEquals(user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase):
user = user_info["user"] user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
user_info = yield self.auth._get_user_from_macaroon(serialized)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self): def test_get_user_from_macaroon_user_db_mismatch(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(

View file

@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
{"presence": ONLINE} {"presence": ONLINE}
) )
# Apple sees self-reflection even without room_id
(events, _) = yield self.event_source.get_new_events(
user=self.u_apple,
from_key=0,
)
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(events,
[
{"type": "m.presence",
"content": {
"user_id": "@apple:test",
"presence": ONLINE,
"last_active_ago": 0,
}},
],
msg="Presence event should be visible to self-reflection"
)
# Apple sees self-reflection # Apple sees self-reflection
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Banana sees it because of presence subscription # Banana sees it because of presence subscription
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_banana, 0, None user=self.u_banana,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Elderberry sees it because of same room # Elderberry sees it because of same room
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_elderberry, 0, None user=self.u_elderberry,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Durian is not in the room, should not see this event # Durian is not in the room, should not see this event
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_durian, 0, None user=self.u_durian,
from_key=0,
room_ids=[],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
"accepted": True}, "accepted": True},
], presence) ], presence)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 1, None user=self.u_apple,
from_key=1,
) )
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
) )
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.room_members.append(self.u_clementine) self.room_members.append(self.u_clementine)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)

View file

@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
yield put_json.await_calls() yield put_json.await_calls()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=1,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [

View file

@ -47,7 +47,14 @@ class NullSource(object):
def __init__(self, hs): def __init__(self, hs):
pass pass
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
room_ids=None,
limit=None,
is_guest=None
):
return defer.succeed(([], from_key)) return defer.succeed(([], from_key))
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):
@ -86,10 +93,11 @@ class PresenceStateTestCase(unittest.TestCase):
return defer.succeed([]) return defer.succeed([])
self.datastore.get_presence_list = get_presence_list self.datastore.get_presence_list = get_presence_list
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -173,10 +181,11 @@ class PresenceListTestCase(unittest.TestCase):
) )
self.datastore.has_presence_state = has_presence_state self.datastore.has_presence_state = has_presence_state
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.handlers.room_member_handler = Mock( hs.handlers.room_member_handler = Mock(
@ -291,8 +300,8 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000 hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None): def _get_user_by_req(req=None, allow_guest=False):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req
@ -312,6 +321,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.handlers.room_member_handler.get_room_members = ( hs.handlers.room_member_handler.get_room_members = (
lambda r: self.room_members if r == "a-room" else [] lambda r: self.room_members if r == "a-room" else []
) )
hs.handlers.room_member_handler._filter_events_for_client = (
lambda user_id, events, **kwargs: events
)
self.mock_datastore = hs.get_datastore() self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None) self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
@ -369,7 +381,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# all be ours # all be ours
# I'll already get my own presence state change # I'll already get my own presence state change
self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []}, self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []},
response response
) )
@ -388,7 +400,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"/events?from=s0_1_0&timeout=0", None) "/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [ self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [
{"type": "m.presence", {"type": "m.presence",
"content": { "content": {
"user_id": "@banana:test", "user_id": "@banana:test",

View file

@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase):
replication_layer=Mock(), replication_layer=Mock(),
) )
def _get_user_by_req(request=None): def _get_user_by_req(request=None, allow_guest=False):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req

View file

@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase):
self.auth_user_id = self.user_id self.auth_user_id = self.user_id
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -987,3 +994,59 @@ class RoomInitialSyncTestCase(RestTestCase):
} }
self.assertTrue(self.user_id in presence_by_user) self.assertTrue(self.user_id in presence_by_user)
self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
class RoomMessageListTestCase(RestTestCase):
""" Tests /rooms/$room_id/messages REST events. """
user_id = "@sid1:red"
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id
hs = yield setup_test_homeserver(
"red",
http_client=None,
replication_layer=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
"is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
self.room_id = yield self.create_room_as(self.user_id)
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
self.assertEquals(200, code)
self.assertTrue("start" in response)
self.assertEquals(token, response['start'])
self.assertTrue("chunk" in response)
self.assertTrue("end" in response)
@defer.inlineCallbacks
def test_stream_token_is_rejected(self):
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=s0_0_0_0" %
self.room_id)
self.assertEquals(400, code)

View file

@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -115,7 +116,10 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.user, 0, None) events = yield self.event_source.get_new_events(
from_key=0,
room_ids=[self.room_id],
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [

View file

@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase):
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
) )
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token hs.get_auth()._get_user_by_access_token = _get_user_by_access_token

View file

@ -0,0 +1,76 @@
from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.types import UserID, RoomID, RoomAlias
from tests.utils import setup_test_homeserver
from mock import Mock
class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.update_handler = Mock()
yield self.store.register_background_update_handler(
"test_update", self.update_handler
)
@defer.inlineCallbacks
def test_do_background_update(self):
desired_count = 1000;
duration_ms = 42;
@defer.inlineCallbacks
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
progress = {"my_key": progress["my_key"] + 1}
yield self.store.runInteraction(
"update_progress",
self.store._background_update_progress_txn,
"test_update",
progress,
)
defer.returnValue(count)
self.update_handler.side_effect = update
yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with(
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
)
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
defer.returnValue(count)
self.update_handler.side_effect = update
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
self.update_handler.assert_called_once_with(
{"my_key": 2}, desired_count
)
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
duration_ms * desired_count
)
self.assertIsNone(result)
self.assertFalse(self.update_handler.called)

View file

@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))

View file

@ -73,6 +73,8 @@ class RoomStoreTestCase(unittest.TestCase):
"room_id": self.room.to_string(), "room_id": self.room.to_string(),
"topic": None, "topic": None,
"aliases": [self.alias.to_string()], "aliases": [self.alias.to_string()],
"world_readable": False,
"guest_can_join": False,
}, rooms[0]) }, rooms[0])

View file

@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(), self.u_bob.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_alice.to_string(), self.u_alice.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
self.assertEqual(1, len(results)) self.assertEqual(1, len(results))
@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(), self.u_bob.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
# We should not get the message, as it happened *after* bob left. # We should not get the message, as it happened *after* bob left.
@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase):
self.u_bob.to_string(), self.u_bob.to_string(),
start, start,
end, end,
None, # Is currently ignored
) )
# We should not get the message, as it happened *after* bob left. # We should not get the message, as it happened *after* bob left.

View file

@ -317,6 +317,99 @@ class StateTestCase(unittest.TestCase):
{e.event_id for e in context_store["E"].current_state.values()} {e.event_id for e in context_store["E"].current_state.values()}
) )
@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
userid1 = "@user_id:example.com"
userid2 = "@user_id2:example.com"
nodes = {
"A1": DictObj(
type=EventTypes.Create,
state_key="",
content={"creator": userid1},
depth=1,
),
"A2": DictObj(
type=EventTypes.Member,
state_key=userid1,
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
),
"A3": DictObj(
type=EventTypes.Member,
state_key=userid2,
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
),
"A4": DictObj(
type=EventTypes.PowerLevels,
state_key="",
content={
"events": {"m.room.name": 50},
"users": {userid1: 100,
userid2: 60},
},
),
"A5": DictObj(
type=EventTypes.Name,
state_key="",
),
"B": DictObj(
type=EventTypes.PowerLevels,
state_key="",
content={
"events": {"m.room.name": 50},
"users": {userid2: 30},
},
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
sender=userid2,
),
"D": DictObj(
type=EventTypes.Message,
),
}
edges = {
"A2": ["A1"],
"A3": ["A2"],
"A4": ["A3"],
"A5": ["A4"],
"B": ["A5"],
"C": ["A5"],
"D": ["B", "C"]
}
self._add_depths(nodes, edges)
graph = Graph(nodes, edges)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
{e.event_id for e in context_store["D"].current_state.values()}
)
def _add_depths(self, nodes, edges):
def _get_depth(ev):
node = nodes[ev]
if 'depth' not in node:
prevs = edges[ev]
depth = max(_get_depth(prev) for prev in prevs) + 1
node['depth'] = depth
return node['depth']
for n in nodes:
_get_depth(n)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_message(self): def test_annotate_with_old_message(self):
event = create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")

View file

@ -1,108 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.util.lockutils import LockManager
class LockManagerTestCase(unittest.TestCase):
def setUp(self):
self.lock_manager = LockManager()
@defer.inlineCallbacks
def test_one_lock(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
lock1.release()
self.assertTrue(lock1.released)
@defer.inlineCallbacks
def test_concurrent_locks(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
deferred_lock2 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
self.assertFalse(deferred_lock2.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
self.assertFalse(deferred_lock2.called)
lock1.release()
self.assertTrue(lock1.released)
self.assertTrue(deferred_lock2.called)
lock2 = yield deferred_lock2
lock2.release()
@defer.inlineCallbacks
def test_sequential_locks(self):
key = "test"
deferred_lock1 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock1.called)
lock1 = yield deferred_lock1
self.assertFalse(lock1.released)
lock1.release()
self.assertTrue(lock1.released)
deferred_lock2 = self.lock_manager.lock(key)
self.assertTrue(deferred_lock2.called)
lock2 = yield deferred_lock2
self.assertFalse(lock2.released)
lock2.release()
self.assertTrue(lock2.released)
@defer.inlineCallbacks
def test_with_statement(self):
key = "test"
with (yield self.lock_manager.lock(key)) as lock:
self.assertFalse(lock.released)
self.assertTrue(lock.released)
@defer.inlineCallbacks
def test_two_with_statement(self):
key = "test"
with (yield self.lock_manager.lock(key)):
pass
with (yield self.lock_manager.lock(key)):
pass

View file

@ -243,6 +243,9 @@ class MockClock(object):
else: else:
self.timers.append(t) self.timers.append(t)
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)
class SQLiteMemoryDbPool(ConnectionPool, object): class SQLiteMemoryDbPool(ConnectionPool, object):
def __init__(self): def __init__(self):
@ -335,7 +338,7 @@ class MemoryDataStore(object):
] ]
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
room_id=None, limit=0, with_feedback=False): limit=0, with_feedback=False):
return ([], from_key) # TODO return ([], from_key) # TODO
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):