0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-17 07:21:37 +01:00

Merge branch 'develop' into paul/tiny-fixes

This commit is contained in:
Paul "LeoNerd" Evans 2015-11-03 16:08:48 +00:00
commit 8a0407c7e6
46 changed files with 1812 additions and 298 deletions

View file

@ -1,3 +1,20 @@
Changes in synapse v0.10.1-rc1 (2015-10-15)
===========================================
* Add support for CAS, thanks to Steven Hammerton (PR #295, #296)
* Add support for using macaroons for ``access_token`` (PR #256, #229)
* Add support for ``m.room.canonical_alias`` (PR #287)
* Add support for viewing the history of rooms that they have left. (PR #276,
#294)
* Add support for refresh tokens (PR #240)
* Add flag on creation which disables federation of the room (PR #279)
* Add some room state to invites. (PR #275)
* Atomically persist events when joining a room over federation (PR #283)
* Change default history visibility for private rooms (PR #271)
* Allow users to redact their own sent events (PR #262)
* Use tox for tests (PR #247)
* Split up syutil into separate libraries (PR #243)
Changes in synapse v0.10.0-r2 (2015-09-16) Changes in synapse v0.10.0-r2 (2015-09-16)
========================================== ==========================================

View file

@ -38,8 +38,12 @@ for port in 8080 8081 8082; do
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
echo "full_twisted_stacktraces: true" >> $DIR/etc/$port.config if ! grep -F "full_twisted_stacktraces" -q $DIR/etc/$port.config; then
echo "report_stats: false" >> $DIR/etc/$port.config echo "full_twisted_stacktraces: true" >> $DIR/etc/$port.config
fi
if ! grep -F "report_stats" -q $DIR/etc/$port.config ; then
echo "report_stats: false" >> $DIR/etc/$port.config
fi
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.10.0-r2" __version__ = "0.10.1-rc1"

View file

@ -14,19 +14,20 @@
# limitations under the License. # limitations under the License.
"""This module contains classes for authenticating the user.""" """This module contains classes for authenticating the user."""
from nacl.exceptions import BadSignatureError from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID from synapse.types import RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util import third_party_invites from synapse.util import third_party_invites
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
import nacl.signing
import pymacaroons import pymacaroons
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,6 +65,8 @@ class Auth(object):
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
self.check_size_limits(event)
try: try:
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event) raise AuthError(500, "Event has no room_id: %s" % event)
@ -131,6 +134,23 @@ class Auth(object):
logger.info("Denying! %s", event) logger.info("Denying! %s", event)
raise raise
def check_size_limits(self, event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None): def check_joined_room(self, room_id, user_id, current_state=None):
"""Check if the user is currently joined in the room """Check if the user is currently joined in the room
@ -308,7 +328,11 @@ class Auth(object):
) )
if Membership.JOIN != membership: if Membership.JOIN != membership:
# JOIN is the only action you can perform if you're not in the room if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined if not caller_in_room: # caller isn't joined
raise AuthError( raise AuthError(
403, 403,
@ -416,16 +440,23 @@ class Auth(object):
key_validity_url key_validity_url
) )
return False return False
for _, signature_block in join_third_party_invite["signatures"].items(): signed = join_third_party_invite["signed"]
if signed["mxid"] != event.user_id:
return False
if signed["token"] != token:
return False
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items(): for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"): if not key_name.startswith("ed25519:"):
return False return False
verify_key = nacl.signing.VerifyKey(decode_base64(public_key)) verify_key = decode_verify_key_bytes(
signature = decode_base64(encoded_signature) key_name,
verify_key.verify(token, signature) decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
return True return True
return False return False
except (KeyError, BadSignatureError,): except (KeyError, SignatureVerifyException,):
return False return False
def _get_power_level_event(self, auth_events): def _get_power_level_event(self, auth_events):

View file

@ -119,6 +119,15 @@ class AuthError(SynapseError):
super(AuthError, self).__init__(*args, **kwargs) super(AuthError, self).__init__(*args, **kwargs)
class EventSizeError(SynapseError):
"""An error raised when an event is too big."""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.TOO_LARGE
super(EventSizeError, self).__init__(413, *args, **kwargs)
class EventStreamError(SynapseError): class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream.""" """An error raised when there a problem with the event stream."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View file

@ -24,7 +24,7 @@ class Filtering(object):
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id) result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(Filter) result.addCallback(FilterCollection)
return result return result
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
@ -131,125 +131,126 @@ class Filtering(object):
raise SynapseError(400, "Bad bundle_updates: expected bool.") raise SynapseError(400, "Bad bundle_updates: expected bool.")
class FilterCollection(object):
def __init__(self, filter_json):
self.filter_json = filter_json
self.room_timeline_filter = Filter(
self.filter_json.get("room", {}).get("timeline", {})
)
self.room_state_filter = Filter(
self.filter_json.get("room", {}).get("state", {})
)
self.room_ephemeral_filter = Filter(
self.filter_json.get("room", {}).get("ephemeral", {})
)
self.presence_filter = Filter(
self.filter_json.get("presence", {})
)
def timeline_limit(self):
return self.room_timeline_filter.limit()
def presence_limit(self):
return self.presence_filter.limit()
def ephemeral_limit(self):
return self.room_ephemeral_filter.limit()
def filter_presence(self, events):
return self.presence_filter.filter(events)
def filter_room_state(self, events):
return self.room_state_filter.filter(events)
def filter_room_timeline(self, events):
return self.room_timeline_filter.filter(events)
def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events)
class Filter(object): class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self.filter_json = filter_json
def timeline_limit(self): def check(self, event):
return self.filter_json.get("room", {}).get("timeline", {}).get("limit", 10) """Checks whether the filter matches the given event.
def presence_limit(self):
return self.filter_json.get("presence", {}).get("limit", 10)
def ephemeral_limit(self):
return self.filter_json.get("room", {}).get("ephemeral", {}).get("limit", 10)
def filter_presence(self, events):
return self._filter_on_key(events, ["presence"])
def filter_room_state(self, events):
return self._filter_on_key(events, ["room", "state"])
def filter_room_timeline(self, events):
return self._filter_on_key(events, ["room", "timeline"])
def filter_room_ephemeral(self, events):
return self._filter_on_key(events, ["room", "ephemeral"])
def _filter_on_key(self, events, keys):
filter_json = self.filter_json
if not filter_json:
return events
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
except KeyError:
# return all events if definition isn't specified.
return events
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes the filter definition
Args:
definition(dict): The filter definition to check against
event(dict or Event): The event to check
Returns: Returns:
True if the event passes the filter in the definition bool: True if the event matches
""" """
if type(event) is dict: if isinstance(event, dict):
room_id = event.get("room_id") return self.check_fields(
sender = event.get("sender") event.get("room_id", None),
event_type = event["type"] event.get("sender", None),
event.get("type", None),
)
else: else:
room_id = getattr(event, "room_id", None) return self.check_fields(
sender = getattr(event, "sender", None) getattr(event, "room_id", None),
event_type = event.type getattr(event, "sender", None),
return self._event_passes_definition( event.type,
definition, room_id, sender, event_type )
)
def _event_passes_definition(self, definition, room_id, sender, def check_fields(self, room_id, sender, event_type):
event_type): """Checks whether the filter matches the given event fields.
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
room_id(str): The id of the room this event is in or None.
sender(str): The sender of the event
event_type(str): The type of the event.
Returns: Returns:
True if the event passes through the filter. bool: True if the event fields match
""" """
# Algorithm notes: literal_keys = {
# For each key in the definition, check the event meets the criteria: "rooms": lambda v: room_id == v,
# * For types: Literal match or prefix match (if ends with wildcard) "senders": lambda v: sender == v,
# * For senders/rooms: Literal match only "types": lambda v: _matches_wildcard(event_type, v)
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types' }
# and 'not_types' then it is treated as only being in 'not_types')
# room checks for name, match_func in literal_keys.items():
if room_id is not None: not_name = "not_%s" % (name,)
allow_rooms = definition.get("rooms", None) disallowed_values = self.filter_json.get(not_name, [])
reject_rooms = definition.get("not_rooms", None) if any(map(match_func, disallowed_values)):
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False return False
# sender checks allowed_values = self.filter_json.get(name, None)
if sender is not None: if allowed_values is not None:
allow_senders = definition.get("senders", None) if not any(map(match_func, allowed_values)):
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event_type, def_type):
return False return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event_type, def_type):
included = True
break
if not included:
return False
return True return True
def _event_matches_type(self, event_type, def_type): def filter_rooms(self, room_ids):
if def_type.endswith("*"): """Apply the 'rooms' filter to a given list of rooms.
type_prefix = def_type[:-1]
return event_type.startswith(type_prefix) Args:
else: room_ids (list): A list of room_ids.
return event_type == def_type
Returns:
list: A list of room_ids that match the filter
"""
room_ids = set(room_ids)
disallowed_rooms = set(self.filter_json.get("not_rooms", []))
room_ids -= disallowed_rooms
allowed_rooms = self.filter_json.get("rooms", None)
if allowed_rooms is not None:
room_ids &= set(allowed_rooms)
return room_ids
def filter(self, events):
return filter(self.check, events)
def limit(self):
return self.filter_json.get("limit", 10)
def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"):
type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix)
else:
return actual_value == filter_value

View file

@ -224,8 +224,8 @@ class _Recoverer(object):
self.clock.call_later((2 ** self.backoff_counter), self.retry) self.clock.call_later((2 ** self.backoff_counter), self.retry)
def _backoff(self): def _backoff(self):
# cap the backoff to be around 18h => (2^16) = 65536 secs # cap the backoff to be around 8.5min => (2^9) = 512 secs
if self.backoff_counter < 16: if self.backoff_counter < 9:
self.backoff_counter += 1 self.backoff_counter += 1
self.recover() self.recover()

View file

@ -25,7 +25,7 @@ class CasConfig(Config):
def read_config(self, config): def read_config(self, config):
cas_config = config.get("cas_config", None) cas_config = config.get("cas_config", None)
if cas_config: if cas_config:
self.cas_enabled = True self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"] self.cas_server_url = cas_config["server_url"]
self.cas_required_attributes = cas_config.get("required_attributes", {}) self.cas_required_attributes = cas_config.get("required_attributes", {})
else: else:
@ -37,6 +37,7 @@ class CasConfig(Config):
return """ return """
# Enable CAS for registration and login. # Enable CAS for registration and login.
#cas_config: #cas_config:
# enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# #required_attributes: # #required_attributes:
# # name: value # # name: value

View file

@ -27,12 +27,14 @@ from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
from .cas import CasConfig from .cas import CasConfig
from .password import PasswordConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig): AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,):
pass pass

View file

@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class PasswordConfig(Config):
"""Password login configuration
"""
def read_config(self, config):
password_config = config.get("password_config", {})
self.password_enabled = password_config.get("enabled", True)
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable password for login.
password_config:
enabled: true
"""

View file

@ -33,7 +33,7 @@ class SAML2Config(Config):
def read_config(self, config): def read_config(self, config):
saml2_config = config.get("saml2_config", None) saml2_config = config.get("saml2_config", None)
if saml2_config: if saml2_config:
self.saml2_enabled = True self.saml2_enabled = saml2_config.get("enabled", True)
self.saml2_config_path = saml2_config["config_path"] self.saml2_config_path = saml2_config["config_path"]
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"] self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
else: else:
@ -49,6 +49,7 @@ class SAML2Config(Config):
# the user back to /login/saml2 with proper info. # the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config. # See pysaml2 docs for format of config.
#saml2_config: #saml2_config:
# enabled: true
# config_path: "%s/sp_conf.py" # config_path: "%s/sp_conf.py"
# idp_redirect_url: "http://%s/idp" # idp_redirect_url: "http://%s/idp"
""" % (config_dir_path, server_name) """ % (config_dir_path, server_name)

View file

@ -154,7 +154,8 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if "redacted_because" in e.unsigned: if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event( d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms e.unsigned["redacted_because"], time_now_ms,
event_format=event_format
) )
if token_id is not None: if token_id is not None:

View file

@ -17,6 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from synapse.api.constants import Membership
from .units import Edu from .units import Edu
from synapse.api.errors import ( from synapse.api.errors import (
@ -357,7 +358,34 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id, content): def make_membership_event(self, destinations, room_id, user_id, membership, content):
"""
Creates an m.room.member event, with context, without participating in the room.
Does so by asking one of the already participating servers to create an
event with proper context.
Note that this does not append any events to any graphs.
Args:
destinations (str): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
membership (str): The "membership" property of the event. Must be
one of "join" or "leave".
content (object): Any additional data to put into the content field
of the event.
Return:
A tuple of (origin (str), event (object)) where origin is the remote
homeserver which generated the event.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
@ -368,13 +396,13 @@ class FederationClient(FederationBase):
content["third_party_invite"] content["third_party_invite"]
) )
try: try:
ret = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, args destination, room_id, user_id, membership, args
) )
pdu_dict = ret["event"] pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict) logger.debug("Got response to make_%s: %s", membership, pdu_dict)
defer.returnValue( defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict)) (destination, self.event_from_pdu_json(pdu_dict))
@ -384,8 +412,8 @@ class FederationClient(FederationBase):
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to make_join via %s: %s", "Failed to make_%s via %s: %s",
destination, e.message membership, destination, e.message
) )
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@ -491,6 +519,33 @@ class FederationClient(FederationBase):
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu):
for destination in destinations:
if destination == self.server_name:
continue
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):
""" """

View file

@ -267,6 +267,20 @@ class FederationServer(FederationBase):
], ],
})) }))
@defer.inlineCallbacks
def on_make_leave_request(self, room_id, user_id):
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id): def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -160,8 +161,14 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_join(self, destination, room_id, user_id, args={}): def make_membership_event(self, destination, room_id, user_id, membership, args={}):
path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
@ -185,6 +192,19 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_invite(self, destination, room_id, event_id, content): def send_invite(self, destination, room_id, event_id, content):

View file

@ -296,6 +296,24 @@ class FederationMakeJoinServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_leave_request(context, user_id)
defer.returnValue((200, content))
class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, txid):
content = yield self.handler.on_send_leave_request(origin, content)
defer.returnValue((200, content))
class FederationEventAuthServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/([^/]*)/([^/]*)" PATH = "/event_auth/([^/]*)/([^/]*)"
@ -385,8 +403,10 @@ SERVLET_CLASSES = (
FederationBackfillServlet, FederationBackfillServlet,
FederationQueryServlet, FederationQueryServlet,
FederationMakeJoinServlet, FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet, FederationEventServlet,
FederationSendJoinServlet, FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationInviteServlet, FederationInviteServlet,
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,

View file

@ -17,7 +17,7 @@ from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler from .register import RegistrationHandler
from .room import ( from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler,
) )
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler from .events import EventStreamHandler, EventHandler
@ -32,6 +32,7 @@ from .sync import SyncHandler
from .auth import AuthHandler from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler from .receipts import ReceiptsHandler
from .search import SearchHandler
class Handlers(object): class Handlers(object):
@ -68,3 +69,5 @@ class Handlers(object):
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs) self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs)

View file

@ -565,7 +565,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot): def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -581,50 +581,19 @@ class FederationHandler(BaseHandler):
yield self.store.clean_room_for_join(room_id) yield self.store.clean_room_for_join(room_id)
origin, pdu = yield self.replication_layer.make_join( origin, event = yield self._make_and_verify_event(
target_hosts, target_hosts,
room_id, room_id,
joinee, joinee,
"join",
content content
) )
logger.debug("Got response to make_join: %s", pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)
event.internal_metadata.outlier = False
self.room_queues[room_id] = [] self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
handled_events = set() handled_events = set()
try: try:
builder.event_id = self.event_builder_factory.create_event_id() new_event = self._sign_event(event)
builder.origin = self.hs.hostname
builder.content = content
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
new_event = builder.build()
# Try the host we successfully got a response to /make_join/ # Try the host we successfully got a response to /make_join/
# request first. # request first.
try: try:
@ -632,11 +601,7 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin) target_hosts.insert(0, origin)
except ValueError: except ValueError:
pass pass
ret = yield self.replication_layer.send_join(target_hosts, new_event)
ret = yield self.replication_layer.send_join(
target_hosts,
new_event
)
origin = ret["origin"] origin = ret["origin"]
state = ret["state"] state = ret["state"]
@ -700,7 +665,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id, query):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
@ -859,6 +824,168 @@ class FederationHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave",
{}
)
signed_event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/
# request first.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
defer.returnValue(None)
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content):
origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts,
room_id,
user_id,
membership,
content
)
logger.debug("Got response to make_%s: %s", membership, pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == user_id)
assert(event.state_key == user_id)
assert(event.room_id == room_id)
defer.returnValue((origin, event))
def _sign_event(self, event):
event.internal_metadata.outlier = False
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
return builder.build()
@defer.inlineCallbacks
@log_function
def on_make_leave_request(self, room_id, user_id):
""" We've received a /make_leave/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
defer.returnValue(event)
@defer.inlineCallbacks
@log_function
def on_send_leave_request(self, origin, pdu):
""" We have received a leave event for a room. Fully process it."""
event = pdu
logger.debug(
"on_send_leave_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(
UserID.from_string(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
destinations.discard(origin)
logger.debug(
"on_send_leave_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
self.replication_layer.send_pdu(new_pdu, destinations)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor() yield run_on_reactor()

View file

@ -156,13 +156,7 @@ class ReceiptsHandler(BaseHandler):
if not result: if not result:
defer.returnValue([]) defer.returnValue([])
event = { defer.returnValue(result)
"type": "m.receipt",
"room_id": room_id,
"content": result,
}
defer.returnValue([event])
class ReceiptEventSource(object): class ReceiptEventSource(object):

View file

@ -33,6 +33,7 @@ from collections import OrderedDict
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
import math
import string import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -389,7 +390,22 @@ class RoomMemberHandler(BaseHandler):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
yield self._do_join(event, context, do_auth=do_auth) yield self._do_join(event, context, do_auth=do_auth)
else: else:
# This is not a JOIN, so we can handle it normally. if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context)
if not is_host_in_room:
# Rejecting an invite, rather than leaving a joined room
handler = self.hs.get_handlers().federation_handler
inviter = yield self.get_inviter(event)
if not inviter:
# return the same error as join_room_alias does
raise SynapseError(404, "No known servers")
yield handler.do_remotely_reject_invite(
[inviter.domain],
room_id,
event.user_id
)
defer.returnValue({"room_id": room_id})
return
# FIXME: This isn't idempotency. # FIXME: This isn't idempotency.
if prev_state and prev_state.membership == event.membership: if prev_state and prev_state.membership == event.membership:
@ -413,7 +429,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def join_room_alias(self, joinee, room_alias, do_auth=True, content={}): def join_room_alias(self, joinee, room_alias, content={}):
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
@ -447,8 +463,6 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True): def _do_join(self, event, context, room_hosts=None, do_auth=True):
joinee = UserID.from_string(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite # XXX: We don't do an auth check if we are doing an invite
@ -456,48 +470,18 @@ class RoomMemberHandler(BaseHandler):
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
is_host_in_room = yield self.auth.check_host_in_room( is_host_in_room = yield self.is_host_in_room(room_id, context)
event.room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
if is_host_in_room: if is_host_in_room:
should_do_dance = False should_do_dance = False
elif room_hosts: # TODO: Shouldn't this be remote_room_host? elif room_hosts: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True should_do_dance = True
else: else:
# TODO(markjh): get prev_state from snapshot inviter = yield self.get_inviter(event)
prev_state = yield self.store.get_room_member( if not inviter:
joinee.to_string(), room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
else:
# return the same error as join_room_alias does # return the same error as join_room_alias does
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
if should_do_dance: if should_do_dance:
handler = self.hs.get_handlers().federation_handler handler = self.hs.get_handlers().federation_handler
@ -505,8 +489,7 @@ class RoomMemberHandler(BaseHandler):
room_hosts, room_hosts,
room_id, room_id,
event.user_id, event.user_id,
event.content, # FIXME To get a non-frozen dict event.content # FIXME To get a non-frozen dict
context
) )
else: else:
logger.debug("Doing normal join") logger.debug("Doing normal join")
@ -523,6 +506,44 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id "user_joined_room", user=user, room_id=room_id
) )
@defer.inlineCallbacks
def get_inviter(self, event):
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
event.user_id, event.room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
defer.returnValue(UserID.from_string(prev_state.user_id))
return
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
defer.returnValue(inviter)
defer.returnValue(None)
@defer.inlineCallbacks
def is_host_in_room(self, room_id, context):
is_host_in_room = yield self.auth.check_host_in_room(
room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
defer.returnValue(is_host_in_room)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given
@ -727,6 +748,60 @@ class RoomListHandler(BaseHandler):
defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit):
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
user (UserID)
room_id (str)
event_id (str)
limit (int): The maximum number of events to return in total
(excluding state).
Returns:
dict
"""
before_limit = math.floor(limit/2.)
after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token()
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit
)
results["events_before"] = yield self._filter_events_for_client(
user.to_string(), results["events_before"]
)
results["events_after"] = yield self._filter_events_for_client(
user.to_string(), results["events_after"]
)
if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
else:
last_event_id = event_id
state = yield self.store.get_state_for_events(
[last_event_id], None
)
results["state"] = state[last_event_id].values()
results["start"] = now_token.copy_and_replace(
"room_key", results["start"]
).to_string()
results["end"] = now_token.copy_and_replace(
"room_key", results["end"]
).to_string()
defer.returnValue(results)
class RoomEventSource(object): class RoomEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()

154
synapse/handlers/search.py Normal file
View file

@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import Membership
from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
import logging
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
def search(self, user, content):
"""Performs a full text search for a user.
Args:
user (UserID)
content (dict): Search parameters
Returns:
dict to be returned to the client with results of search
"""
try:
search_term = content["search_categories"]["room_events"]["search_term"]
keys = content["search_categories"]["room_events"].get("keys", [
"content.body", "content.name", "content.topic",
])
filter_dict = content["search_categories"]["room_events"].get("filter", {})
event_context = content["search_categories"]["room_events"].get(
"event_context", None
)
if event_context is not None:
before_limit = int(event_context.get(
"before_limit", 5
))
after_limit = int(event_context.get(
"after_limit", 5
))
except KeyError:
raise SynapseError(400, "Invalid search query")
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
room_ids = set(r.room_id for r in rooms)
room_ids = search_filter.filter_rooms(room_ids)
rank_map, event_map, _ = yield self.store.search_msgs(
room_ids, search_term, keys
)
filtered_events = search_filter.filter(event_map.values())
allowed_events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
allowed_events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = allowed_events[:search_filter.limit()]
if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)
res["events_before"] = yield self._filter_events_for_client(
user.to_string(), res["events_before"]
)
res["events_after"] = yield self._filter_events_for_client(
user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
"room_key", res["start"]
).to_string()
res["end"] = now_token.copy_and_replace(
"room_key", res["end"]
).to_string()
contexts[event.event_id] = res
else:
contexts = {}
# TODO: Add a limit
time_now = self.clock.time_msec()
for context in contexts.values():
context["events_before"] = [
serialize_event(e, time_now)
for e in context["events_before"]
]
context["events_after"] = [
serialize_event(e, time_now)
for e in context["events_after"]
]
results = {
e.event_id: {
"rank": rank_map[e.event_id],
"result": serialize_event(e, time_now),
"context": contexts.get(e.event_id, {}),
}
for e in allowed_events
}
logger.info("Found %d results", len(results))
defer.returnValue({
"search_categories": {
"room_events": {
"results": results,
"count": len(results)
}
}
})

View file

@ -61,18 +61,37 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
return bool(self.timeline or self.state or self.ephemeral) return bool(self.timeline or self.state or self.ephemeral)
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id",
"timeline",
"state",
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.timeline or self.state)
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
"room_id", "room_id",
"invite", "invite",
])): ])):
__slots__ = [] __slots__ = []
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
class SyncResult(collections.namedtuple("SyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync "next_batch", # Token for the next sync
"presence", # List of presence events for the user. "presence", # List of presence events for the user.
"joined", # JoinedSyncResult for each joined room. "joined", # JoinedSyncResult for each joined room.
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
])): ])):
__slots__ = [] __slots__ = []
@ -94,15 +113,20 @@ class SyncHandler(BaseHandler):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0): def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
"""Get the sync for a client if we have new data for it now. Otherwise """Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result. return an empty sync result.
Returns: Returns:
A Deferred SyncResult. A Deferred SyncResult.
""" """
if timeout == 0 or since_token is None:
result = yield self.current_sync_for_user(sync_config, since_token) if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result = yield self.current_sync_for_user(sync_config, since_token,
full_state=full_state)
defer.returnValue(result) defer.returnValue(result)
else: else:
def current_sync_callback(before_token, after_token): def current_sync_callback(before_token, after_token):
@ -127,24 +151,33 @@ class SyncHandler(BaseHandler):
) )
defer.returnValue(result) defer.returnValue(result)
def current_sync_for_user(self, sync_config, since_token=None): def current_sync_for_user(self, sync_config, since_token=None,
full_state=False):
"""Get the sync for client needed to match what the server has now. """Get the sync for client needed to match what the server has now.
Returns: Returns:
A Deferred SyncResult. A Deferred SyncResult.
""" """
if since_token is None: if since_token is None or full_state:
return self.initial_sync(sync_config) return self.full_state_sync(sync_config, since_token)
else: else:
return self.incremental_sync_with_gap(sync_config, since_token) return self.incremental_sync_with_gap(sync_config, since_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def initial_sync(self, sync_config): def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state """Get a sync for a client which is starting without any state.
If a 'message_since_token' is given, only timeline events which have
happened since that token will be returned.
Returns: Returns:
A Deferred SyncResult. A Deferred SyncResult.
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
now_token, typing_by_room = yield self.typing_by_room(
sync_config, now_token
)
presence_stream = self.event_sources.sources["presence"] presence_stream = self.event_sources.sources["presence"]
# TODO (mjark): This looks wrong, shouldn't we be getting the presence # TODO (mjark): This looks wrong, shouldn't we be getting the presence
# UP to the present rather than after the present? # UP to the present rather than after the present?
@ -156,15 +189,25 @@ class SyncHandler(BaseHandler):
) )
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
membership_list=[Membership.INVITE, Membership.JOIN] membership_list=(
Membership.INVITE,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN
)
) )
joined = [] joined = []
invited = [] invited = []
archived = []
for event in room_list: for event in room_list:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
room_sync = yield self.initial_sync_for_joined_room( room_sync = yield self.full_state_sync_for_joined_room(
event.room_id, sync_config, now_token, room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
typing_by_room=typing_by_room
) )
joined.append(room_sync) joined.append(room_sync)
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
@ -173,23 +216,38 @@ class SyncHandler(BaseHandler):
room_id=event.room_id, room_id=event.room_id,
invite=invite, invite=invite,
)) ))
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync = yield self.full_state_sync_for_archived_room(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
)
archived.append(room_sync)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived,
next_batch=now_token, next_batch=now_token,
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def initial_sync_for_joined_room(self, room_id, sync_config, now_token): def full_state_sync_for_joined_room(self, room_id, sync_config,
now_token, timeline_since_token,
typing_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
""" """
batch = yield self.load_filtered_recents( batch = yield self.load_filtered_recents(
room_id, sync_config, now_token, room_id, sync_config, now_token, since_token=timeline_since_token
) )
current_state = yield self.state_handler.get_current_state( current_state = yield self.state_handler.get_current_state(
@ -201,7 +259,61 @@ class SyncHandler(BaseHandler):
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=current_state_events, state=current_state_events,
ephemeral=[], ephemeral=typing_by_room.get(room_id, []),
))
@defer.inlineCallbacks
def typing_by_room(self, sync_config, now_token, since_token=None):
"""Get the typing events for each room the user is in
Args:
sync_config (SyncConfig): The flags, filters and user for the sync.
now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client
last synced.
Returns:
A tuple of the now StreamToken, updated to reflect the which typing
events are included, and a dict mapping from room_id to a list of
typing events for that room.
"""
typing_key = since_token.typing_key if since_token else "0"
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user(
user=sync_config.user,
from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(),
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
defer.returnValue((now_token, typing_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token,
timeline_since_token):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred JoinedSyncResult.
"""
batch = yield self.load_filtered_recents(
room_id, sync_config, leave_token, since_token=timeline_since_token
)
leave_state = yield self.store.get_state_for_events(
[leave_event_id], None
)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state[leave_event_id].values(),
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -221,18 +333,9 @@ class SyncHandler(BaseHandler):
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)
typing_source = self.event_sources.sources["typing"] now_token, typing_by_room = yield self.typing_by_room(
typing, typing_key = yield typing_source.get_new_events_for_user( sync_config, now_token, since_token
user=sync_config.user,
from_key=since_token.typing_key,
limit=sync_config.filter.ephemeral_limit(),
) )
now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id( app_service = yield self.store.get_app_service_by_user_id(
@ -257,18 +360,22 @@ class SyncHandler(BaseHandler):
) )
joined = [] joined = []
archived = []
if len(room_events) <= timeline_limit: if len(room_events) <= timeline_limit:
# There is no gap in any of the rooms. Therefore we can just # There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them. # partition the new events by room and return them.
invite_events = [] invite_events = []
leave_events = []
events_by_room_id = {} events_by_room_id = {}
for event in room_events: for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event) events_by_room_id.setdefault(event.room_id, []).append(event)
if event.room_id not in joined_room_ids: if event.room_id not in joined_room_ids:
if (event.type == EventTypes.Member if (event.type == EventTypes.Member
and event.membership == Membership.INVITE
and event.state_key == sync_config.user.to_string()): and event.state_key == sync_config.user.to_string()):
invite_events.append(event) if event.membership == Membership.INVITE:
invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
for room_id in joined_room_ids: for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) recents = events_by_room_id.get(room_id, [])
@ -280,7 +387,7 @@ class SyncHandler(BaseHandler):
else: else:
prev_batch = now_token prev_batch = now_token
state = yield self.check_joined_room( state, limited = yield self.check_joined_room(
sync_config, room_id, state sync_config, room_id, state
) )
@ -289,18 +396,23 @@ class SyncHandler(BaseHandler):
timeline=TimelineBatch( timeline=TimelineBatch(
events=recents, events=recents,
prev_batch=prev_batch, prev_batch=prev_batch,
limited=False, limited=limited,
), ),
state=state, state=state,
ephemeral=typing_by_room.get(room_id, []) ephemeral=typing_by_room.get(room_id, [])
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
else: else:
invite_events = yield self.store.get_invites_for_user( invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids: for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id, sync_config, since_token, now_token,
@ -309,6 +421,12 @@ class SyncHandler(BaseHandler):
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token
)
archived.append(room_sync)
invited = [ invited = [
InvitedSyncResult(room_id=event.room_id, invite=event) InvitedSyncResult(room_id=event.room_id, invite=event)
for event in invite_events for event in invite_events
@ -318,6 +436,7 @@ class SyncHandler(BaseHandler):
presence=presence, presence=presence,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived,
next_batch=now_token, next_batch=now_token,
)) ))
@ -401,7 +520,7 @@ class SyncHandler(BaseHandler):
current_state=current_state_events, current_state=current_state_events,
) )
state_events_delta = yield self.check_joined_room( state_events_delta, _ = yield self.check_joined_room(
sync_config, room_id, state_events_delta sync_config, room_id, state_events_delta
) )
@ -416,6 +535,55 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token):
""" Get the incremental delta needed to bring the client up to date for
the archived room.
Returns:
A Deferred ArchivedSyncResult
"""
stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id
)
leave_token = since_token.copy_and_replace("room_key", stream_token)
batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token,
)
logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
leave_state = yield self.store.get_state_for_events(
[leave_event.event_id], None
)
state_events_at_leave = leave_state[leave_event.event_id].values()
state_at_previous_sync = yield self.get_state_at_previous_sync(
leave_event.room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
)
room_sync = ArchivedSyncResult(
room_id=leave_event.room_id,
timeline=batch,
state=state_events_delta,
)
logging.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token): def get_state_at_previous_sync(self, room_id, since_token):
""" Get the room state at the previous sync the client made. """ Get the room state at the previous sync the client made.
@ -459,6 +627,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, sync_config, room_id, state_delta): def check_joined_room(self, sync_config, room_id, state_delta):
joined = False joined = False
limited = False
for event in state_delta: for event in state_delta:
if ( if (
event.type == EventTypes.Member event.type == EventTypes.Member
@ -470,5 +639,6 @@ class SyncHandler(BaseHandler):
if joined: if joined:
res = yield self.state_handler.get_current_state(room_id) res = yield self.state_handler.get_current_state(room_id)
state_delta = res.values() state_delta = res.values()
limited = True
defer.returnValue(state_delta) defer.returnValue((state_delta, limited))

View file

@ -43,6 +43,7 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
@ -50,11 +51,13 @@ class LoginRestServlet(ClientV1RestServlet):
self.servername = hs.config.server_name self.servername = hs.config.server_name
def on_GET(self, request): def on_GET(self, request):
flows = [{"type": LoginRestServlet.PASS_TYPE}] flows = []
if self.saml2_enabled: if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE}) flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled: if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE}) flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE})
return (200, {"flows": flows}) return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
@ -65,6 +68,9 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission = _parse_json(request) login_submission = _parse_json(request)
try: try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if login_submission["type"] == LoginRestServlet.PASS_TYPE:
if not self.password_enabled:
raise SynapseError(400, "Password login has been disabled.")
result = yield self.do_password_login(login_submission) result = yield self.do_password_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] == elif self.saml2_enabled and (login_submission["type"] ==
@ -101,6 +107,8 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], login_submission['address'] login_submission['medium'], login_submission['address']
) )
if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
else: else:
user_id = login_submission['user'] user_id = login_submission['user']

View file

@ -397,6 +397,41 @@ class RoomTriggerBackfill(ClientV1RestServlet):
defer.returnValue((200, res)) defer.returnValue((200, res))
class RoomEventContext(ClientV1RestServlet):
PATTERN = client_path_pattern(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
user, _ = yield self.auth.get_user_by_req(request)
limit = int(request.args.get("limit", [10])[0])
results = yield self.handlers.room_context_handler.get_event_context(
user, room_id, event_id, limit,
)
time_now = self.clock.time_msec()
results["events_before"] = [
serialize_event(event, time_now) for event in results["events_before"]
]
results["events_after"] = [
serialize_event(event, time_now) for event in results["events_after"]
]
results["state"] = [
serialize_event(event, time_now) for event in results["state"]
]
logger.info("Responding with %r", results)
defer.returnValue((200, results))
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet): class RoomMembershipRestServlet(ClientV1RestServlet):
@ -555,6 +590,22 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern(
"/search$"
)
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
results = yield self.handlers.search_handler.search(auth_user, content)
defer.returnValue((200, results))
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -611,3 +662,5 @@ def register_servlets(hs, http_server):
RoomInitialSyncRestServlet(hs).register(http_server) RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
RoomEventContext(hs).register(http_server)

View file

@ -16,14 +16,14 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_string, parse_integer RestServlet, parse_string, parse_integer, parse_boolean
) )
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id, serialize_event, format_event_for_client_v2_without_event_id,
) )
from synapse.api.filtering import Filter from synapse.api.filtering import FilterCollection
from ._base import client_v2_pattern from ._base import client_v2_pattern
import copy import copy
@ -90,6 +90,7 @@ class SyncRestServlet(RestServlet):
allowed_values=self.ALLOWED_PRESENCE allowed_values=self.ALLOWED_PRESENCE
) )
filter_id = parse_string(request, "filter", default=None) filter_id = parse_string(request, "filter", default=None)
full_state = parse_boolean(request, "full_state", default=False)
logger.info( logger.info(
"/sync: user=%r, timeout=%r, since=%r," "/sync: user=%r, timeout=%r, since=%r,"
@ -103,7 +104,7 @@ class SyncRestServlet(RestServlet):
user.localpart, filter_id user.localpart, filter_id
) )
except: except:
filter = Filter({}) filter = FilterCollection({})
sync_config = SyncConfig( sync_config = SyncConfig(
user=user, user=user,
@ -120,7 +121,8 @@ class SyncRestServlet(RestServlet):
try: try:
sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout sync_config, since_token=since_token, timeout=timeout,
full_state=full_state
) )
finally: finally:
if set_presence == "online": if set_presence == "online":
@ -136,6 +138,10 @@ class SyncRestServlet(RestServlet):
sync_result.invited, filter, time_now, token_id sync_result.invited, filter, time_now, token_id
) )
archived = self.encode_archived(
sync_result.archived, filter, time_now, token_id
)
response_content = { response_content = {
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, filter, time_now sync_result.presence, filter, time_now
@ -143,7 +149,7 @@ class SyncRestServlet(RestServlet):
"rooms": { "rooms": {
"joined": joined, "joined": joined,
"invited": invited, "invited": invited,
"archived": {}, "archived": archived,
}, },
"next_batch": sync_result.next_batch.to_string(), "next_batch": sync_result.next_batch.to_string(),
} }
@ -182,14 +188,20 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, filter, time_now, token_id):
joined = {}
for room in rooms:
joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id, joined=False
)
return joined
@staticmethod @staticmethod
def encode_room(room, filter, time_now, token_id): def encode_room(room, filter, time_now, token_id, joined=True):
event_map = {} event_map = {}
state_events = filter.filter_room_state(room.state) state_events = filter.filter_room_state(room.state)
timeline_events = filter.filter_room_timeline(room.timeline.events)
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
state_event_ids = [] state_event_ids = []
timeline_event_ids = []
for event in state_events: for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event( event_map[event.event_id] = serialize_event(
@ -198,6 +210,8 @@ class SyncRestServlet(RestServlet):
) )
state_event_ids.append(event.event_id) state_event_ids.append(event.event_id)
timeline_events = filter.filter_room_timeline(room.timeline.events)
timeline_event_ids = []
for event in timeline_events: for event in timeline_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event( event_map[event.event_id] = serialize_event(
@ -205,6 +219,7 @@ class SyncRestServlet(RestServlet):
event_format=format_event_for_client_v2_without_event_id, event_format=format_event_for_client_v2_without_event_id,
) )
timeline_event_ids.append(event.event_id) timeline_event_ids.append(event.event_id)
result = { result = {
"event_map": event_map, "event_map": event_map,
"timeline": { "timeline": {
@ -213,8 +228,12 @@ class SyncRestServlet(RestServlet):
"limited": room.timeline.limited, "limited": room.timeline.limited,
}, },
"state": {"events": state_event_ids}, "state": {"events": state_event_ids},
"ephemeral": {"events": ephemeral_events},
} }
if joined:
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
result["ephemeral"] = {"events": ephemeral_events}
return result return result

View file

@ -40,6 +40,7 @@ from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore
import logging import logging
@ -69,6 +70,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventsStore, EventsStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View file

@ -519,7 +519,7 @@ class SQLBaseStore(object):
allow_none=False, allow_none=False,
desc="_simple_select_one_onecol"): desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it." return a single row, returning a single column from it.
Args: Args:
table : string giving the table name table : string giving the table name

View file

@ -17,6 +17,8 @@ from synapse.storage.prepare_database import (
prepare_database, prepare_sqlite3_database prepare_database, prepare_sqlite3_database
) )
import struct
class Sqlite3Engine(object): class Sqlite3Engine(object):
single_threaded = True single_threaded = True
@ -32,6 +34,7 @@ class Sqlite3Engine(object):
def on_new_connection(self, db_conn): def on_new_connection(self, db_conn):
self.prepare_database(db_conn) self.prepare_database(db_conn)
db_conn.create_function("rank", 1, _rank)
def prepare_database(self, db_conn): def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn) prepare_sqlite3_database(db_conn)
@ -45,3 +48,27 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table): def lock_table(self, txn, table):
return return
# Following functions taken from: https://github.com/coleifer/peewee
def _parse_match_info(buf):
bufsize = len(buf)
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
"""Handle match_info called w/default args 'pcx' - based on the example rank
function http://sqlite.org/fts3.html#appendix_a
"""
match_info = _parse_match_info(raw_match_info)
score = 0.0
p, c = match_info[:2]
for phrase_num in range(p):
phrase_info_idx = 2 + (phrase_num * c * 3)
for col_num in range(c):
col_idx = phrase_info_idx + (col_num * 3)
x1, x2 = match_info[col_idx:col_idx + 2]
if x1 > 0:
score += float(x1) / x2
return score

View file

@ -307,6 +307,8 @@ class EventsStore(SQLBaseStore):
self._store_room_name_txn(txn, event) self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic: elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event) self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 24 SCHEMA_VERSION = 25
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -19,6 +19,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections import collections
import logging import logging
@ -175,6 +176,10 @@ class RoomStore(SQLBaseStore):
}, },
) )
self._store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn, event): def _store_room_name_txn(self, txn, event):
if hasattr(event, "content") and "name" in event.content: if hasattr(event, "content") and "name" in event.content:
self._simple_insert_txn( self._simple_insert_txn(
@ -187,6 +192,33 @@ class RoomStore(SQLBaseStore):
} }
) )
self._store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self._store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
txn.execute(sql, (event.event_id, event.room_id, key, value,))
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):

View file

@ -124,6 +124,19 @@ class RoomMemberStore(SQLBaseStore):
invites.event_id for invite in invites invites.event_id for invite in invites
])) ]))
def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user
Args:
user_id (str): The user ID.
Returns:
A deferred list of event objects.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, (Membership.LEAVE, Membership.BAN)
).addCallback(lambda leaves: self._get_events([
leave.event_id for leave in leaves
]))
def get_rooms_for_user_where_membership_is(self, user_id, membership_list): def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user """ Get all the rooms for this user where the membership for this user
matches one in the membership list. matches one in the membership list.

View file

@ -0,0 +1,127 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import ujson
logger = logging.getLogger(__name__)
POSTGRES_SQL = """
CREATE TABLE IF NOT EXISTS event_search (
event_id TEXT,
room_id TEXT,
sender TEXT,
key TEXT,
vector tsvector
);
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.body',
to_tsvector('english', json::json->'content'->>'body')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.message';
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.name',
to_tsvector('english', json::json->'content'->>'name')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.name';
INSERT INTO event_search SELECT
event_id, room_id, json::json->>'sender', 'content.topic',
to_tsvector('english', json::json->'content'->>'topic')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic';
CREATE INDEX event_search_fts_idx ON event_search USING gin(vector);
CREATE INDEX event_search_ev_idx ON event_search(event_id);
CREATE INDEX event_search_ev_ridx ON event_search(room_id);
"""
SQLITE_TABLE = (
"CREATE VIRTUAL TABLE IF NOT EXISTS event_search"
" USING fts4 ( event_id, room_id, sender, key, value )"
)
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
run_postgres_upgrade(cur)
return
if isinstance(database_engine, Sqlite3Engine):
run_sqlite_upgrade(cur)
return
def run_postgres_upgrade(cur):
for statement in get_statements(POSTGRES_SQL.splitlines()):
cur.execute(statement)
def run_sqlite_upgrade(cur):
cur.execute(SQLITE_TABLE)
rowid = -1
while True:
cur.execute(
"SELECT rowid, json FROM event_json"
" WHERE rowid > ?"
" ORDER BY rowid ASC LIMIT 100",
(rowid,)
)
res = cur.fetchall()
if not res:
break
events = [
ujson.loads(js)
for _, js in res
]
rowid = max(rid for rid, _ in res)
rows = []
for ev in events:
content = ev.get("content", {})
body = content.get("body", None)
name = content.get("name", None)
topic = content.get("topic", None)
sender = ev.get("sender", None)
if ev["type"] == "m.room.message" and body:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.body", body
))
if ev["type"] == "m.room.name" and name:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.name", name
))
if ev["type"] == "m.room.topic" and topic:
rows.append((
ev["event_id"], ev["room_id"], sender, "content.topic", topic
))
if rows:
logger.info(rows)
cur.executemany(
"INSERT INTO event_search (event_id, room_id, sender, key, value)"
" VALUES (?,?,?,?,?)",
rows
)

111
synapse/storage/search.py Normal file
View file

@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from _base import SQLBaseStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from collections import namedtuple
"""The result of a search.
Fields:
rank_map (dict): Mapping event_id -> rank
event_map (dict): Mapping event_id -> event
pagination_token (str): Pagination token
"""
SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
class SearchStore(SQLBaseStore):
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
room_ids (list): List of room ids to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
Returns:
SearchResult
"""
clauses = []
args = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids)
local_clauses = []
for key in keys:
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
" FROM event_search"
" WHERE value MATCH ?"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for clause in clauses:
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
)
results = filter(lambda row: row["room_id"] in room_ids, results)
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
defer.returnValue(SearchResult(
{
r["event_id"]: r["rank"]
for r in results
if r["event_id"] in event_map
},
event_map,
None
))

View file

@ -214,7 +214,6 @@ class StateStore(SQLBaseStore):
that are in the `types` list. that are in the `types` list.
Args: Args:
room_id (str)
event_ids (list) event_ids (list)
types (list): List of (type, state_key) tuples which are used to types (list): List of (type, state_key) tuples which are used to
filter the state fetched. `state_key` may be None, which matches filter the state fetched. `state_key` may be None, which matches

View file

@ -23,7 +23,7 @@ paginate bacwards.
This is implemented by keeping two ordering columns: stream_ordering and This is implemented by keeping two ordering columns: stream_ordering and
topological_ordering. Stream ordering is basically insertion/received order topological_ordering. Stream ordering is basically insertion/received order
(except for events from backfill requests). The topolgical_ordering is a (except for events from backfill requests). The topological_ordering is a
weak ordering of events based on the pdu graph. weak ordering of events based on the pdu graph.
This means that we have to have two different types of tokens, depending on This means that we have to have two different types of tokens, depending on
@ -436,3 +436,138 @@ class StreamStore(SQLBaseStore):
internal = event.internal_metadata internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1)) internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):
"""Retrieve events and pagination tokens around a given event in a
room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
Returns:
dict
"""
results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn,
room_id, event_id, before_limit, after_limit
)
events_before = yield self._get_events(
[e for e in results["before"]["event_ids"]],
get_prev_content=True
)
events_after = yield self._get_events(
[e for e in results["after"]["event_ids"]],
get_prev_content=True
)
defer.returnValue({
"events_before": events_before,
"events_after": events_after,
"start": results["before"]["token"],
"end": results["after"]["token"],
})
def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
Returns:
dict
"""
results = self._simple_select_one_txn(
txn,
"events",
keyvalues={
"event_id": event_id,
"room_id": room_id,
},
retcols=["stream_ordering", "topological_ordering"],
)
stream_ordering = results["stream_ordering"]
topological_ordering = results["topological_ordering"]
query_before = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" WHERE room_id = ? AND (topological_ordering < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))"
" ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?"
)
query_after = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" WHERE room_id = ? AND (topological_ordering > ?"
" OR (topological_ordering = ? AND stream_ordering > ?))"
" ORDER BY topological_ordering ASC, stream_ordering ASC"
" LIMIT ?"
)
txn.execute(
query_before,
(
room_id, topological_ordering, topological_ordering,
stream_ordering, before_limit,
)
)
rows = self.cursor_to_dict(txn)
events_before = [r["event_id"] for r in rows]
if rows:
start_token = str(RoomStreamToken(
rows[0]["topological_ordering"],
rows[0]["stream_ordering"] - 1,
))
else:
start_token = str(RoomStreamToken(
topological_ordering,
stream_ordering - 1,
))
txn.execute(
query_after,
(
room_id, topological_ordering, topological_ordering,
stream_ordering, after_limit,
)
)
rows = self.cursor_to_dict(txn)
events_after = [r["event_id"] for r in rows]
if rows:
end_token = str(RoomStreamToken(
rows[-1]["topological_ordering"],
rows[-1]["stream_ordering"],
))
else:
end_token = str(RoomStreamToken(
topological_ordering,
stream_ordering,
))
return {
"before": {
"event_ids": events_before,
"token": start_token,
},
"after": {
"event_ids": events_after,
"token": end_token,
},
}

View file

@ -47,7 +47,7 @@ class DomainSpecificString(
@classmethod @classmethod
def from_string(cls, s): def from_string(cls, s):
"""Parse the string given by 's' into a structure object.""" """Parse the string given by 's' into a structure object."""
if s[0] != cls.SIGIL: if len(s) < 1 or s[0] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % ( raise SynapseError(400, "Expected %s string to start with '%s'" % (
cls.__name__, cls.SIGIL, cls.__name__, cls.SIGIL,
)) ))

View file

@ -23,8 +23,8 @@ JOIN_KEYS = {
"token", "token",
"public_key", "public_key",
"key_validity_url", "key_validity_url",
"signatures",
"sender", "sender",
"signed",
} }

View file

@ -23,10 +23,17 @@ from tests.utils import (
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.api.filtering import Filter from synapse.api.filtering import FilterCollection, Filter
user_localpart = "test_user" user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id") # MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs):
ev = NonCallableMock(spec_set=kwargs.keys())
ev.configure_mock(**kwargs)
return ev
class FilteringTestCase(unittest.TestCase): class FilteringTestCase(unittest.TestCase):
@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase):
) )
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.filter = Filter({})
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase):
type="m.room.message", type="m.room.message",
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_types_works_with_wildcards(self): def test_definition_types_works_with_wildcards(self):
@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_types_works_with_unknowns(self): def test_definition_types_works_with_unknowns(self):
@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_works_with_literals(self): def test_definition_not_types_works_with_literals(self):
@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_works_with_wildcards(self): def test_definition_not_types_works_with_wildcards(self):
@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_works_with_unknowns(self): def test_definition_not_types_works_with_unknowns(self):
@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_takes_priority_over_types(self): def test_definition_not_types_takes_priority_over_types(self):
@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_senders_works_with_literals(self): def test_definition_senders_works_with_literals(self):
@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_senders_works_with_unknowns(self): def test_definition_senders_works_with_unknowns(self):
@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_senders_works_with_literals(self): def test_definition_not_senders_works_with_literals(self):
@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_senders_works_with_unknowns(self): def test_definition_not_senders_works_with_unknowns(self):
@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_senders_takes_priority_over_senders(self): def test_definition_not_senders_takes_priority_over_senders(self):
@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_rooms_works_with_literals(self): def test_definition_rooms_works_with_literals(self):
@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown" room_id="!secretbase:unknown"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_rooms_works_with_unknowns(self): def test_definition_rooms_works_with_unknowns(self):
@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_rooms_works_with_literals(self): def test_definition_not_rooms_works_with_literals(self):
@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_rooms_works_with_unknowns(self): def test_definition_not_rooms_works_with_unknowns(self):
@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_rooms_takes_priority_over_rooms(self): def test_definition_not_rooms_takes_priority_over_rooms(self):
@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown" room_id="!secretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event(self): def test_definition_combined_event(self):
@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event_bad_sender(self): def test_definition_combined_event_bad_sender(self):
@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event_bad_room(self): def test_definition_combined_event_bad_room(self):
@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!piggyshouse:muppets" # nope room_id="!piggyshouse:muppets" # nope
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event_bad_type(self): def test_definition_combined_event_bad_type(self):
@ -341,7 +348,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.profile", type="m.profile",
room_id="!foo:bar"
) )
events = [event] events = [event]
@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="custom.avatar.3d.crazy", type="custom.avatar.3d.crazy",
room_id="!foo:bar"
) )
events = [event] events = [event]

15
tests/crypto/__init__.py Normal file
View file

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from synapse.events.builder import EventBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from unpaddedbase64 import decode_base64
import nacl.signing
# Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against.
SIGNING_KEY_SEED = decode_base64(
"YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
)
KEY_ALG = "ed25519"
KEY_VER = 1
KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER)
HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase):
def setUp(self):
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
self.signing_key.alg = KEY_ALG
self.signing_key.version = KEY_VER
def test_sign_minimal(self):
builder = EventBuilder(
{
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'signatures': {},
'type': "X",
'unsigned': {'age_ts': 1000000},
},
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
event = builder.build()
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
self.assertEquals(
event.hashes['sha256'],
"6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI",
)
self.assertTrue(hasattr(event, 'signatures'))
self.assertIn(HOSTNAME, event.signatures)
self.assertIn(KEY_NAME, event.signatures["domain"])
self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME],
"2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+"
"aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA",
)
def test_sign_message(self):
builder = EventBuilder(
{
'content': {
'body': "Here is the message content",
},
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'type': "m.room.message",
'room_id': "!r:domain",
'sender': "@u:domain",
'signatures': {},
'unsigned': {'age_ts': 1000000},
}
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
event = builder.build()
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
self.assertEquals(
event.hashes['sha256'],
"onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g",
)
self.assertTrue(hasattr(event, 'signatures'))
self.assertIn(HOSTNAME, event.signatures)
self.assertIn(KEY_NAME, event.signatures["domain"])
self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME],
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
)

0
tests/events/__init__.py Normal file
View file

115
tests/events/test_utils.py Normal file
View file

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import unittest
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
class PruneEventTestCase(unittest.TestCase):
""" Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted. """
def run_test(self, evdict, matchdict):
self.assertEquals(
prune_event(FrozenEvent(evdict)).get_dict(),
matchdict
)
def test_minimal(self):
self.run_test(
{'type': 'A'},
{
'type': 'A',
'content': {},
'signatures': {},
'unsigned': {},
}
)
def test_basic_keys(self):
self.run_test(
{
'type': 'A',
'room_id': '!1:domain',
'sender': '@2:domain',
'event_id': '$3:domain',
'origin': 'domain',
},
{
'type': 'A',
'room_id': '!1:domain',
'sender': '@2:domain',
'event_id': '$3:domain',
'origin': 'domain',
'content': {},
'signatures': {},
'unsigned': {},
}
)
def test_unsigned_age_ts(self):
self.run_test(
{
'type': 'B',
'unsigned': {'age_ts': 20},
},
{
'type': 'B',
'content': {},
'signatures': {},
'unsigned': {'age_ts': 20},
}
)
self.run_test(
{
'type': 'B',
'unsigned': {'other_key': 'here'},
},
{
'type': 'B',
'content': {},
'signatures': {},
'unsigned': {},
}
)
def test_content(self):
self.run_test(
{
'type': 'C',
'content': {'things': 'here'},
},
{
'type': 'C',
'content': {},
'signatures': {},
'unsigned': {},
}
)
self.run_test(
{
'type': 'm.room.create',
'content': {'creator': '@2:domain', 'other_field': 'here'},
},
{
'type': 'm.room.create',
'content': {'creator': '@2:domain'},
'signatures': {},
'unsigned': {},
}
)

View file

@ -277,10 +277,10 @@ class RoomPermissionsTestCase(RestTestCase):
expect_code=403) expect_code=403)
# set [invite/join/left] of self, set [invite/join/left] of other, # set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 403s # expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]: for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=404) yield self.join(room=room, user=usr, expect_code=404)
yield self.leave(room=room, user=usr, expect_code=403) yield self.leave(room=room, user=usr, expect_code=404)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_membership_private_room_perms(self): def test_membership_private_room_perms(self):

View file

@ -15,13 +15,14 @@
from tests import unittest from tests import unittest
from synapse.api.errors import SynapseError
from synapse.server import BaseHomeServer from synapse.server import BaseHomeServer
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias
mock_homeserver = BaseHomeServer(hostname="my.domain") mock_homeserver = BaseHomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):
class UserIDTestCase(unittest.TestCase):
def test_parse(self): def test_parse(self):
user = UserID.from_string("@1234abcd:my.domain") user = UserID.from_string("@1234abcd:my.domain")
@ -29,6 +30,11 @@ class UserIDTestCase(unittest.TestCase):
self.assertEquals("my.domain", user.domain) self.assertEquals("my.domain", user.domain)
self.assertEquals(True, mock_homeserver.is_mine(user)) self.assertEquals(True, mock_homeserver.is_mine(user))
def test_pase_empty(self):
with self.assertRaises(SynapseError):
UserID.from_string("")
def test_build(self): def test_build(self):
user = UserID("5678efgh", "my.domain") user = UserID("5678efgh", "my.domain")
@ -44,7 +50,6 @@ class UserIDTestCase(unittest.TestCase):
class RoomAliasTestCase(unittest.TestCase): class RoomAliasTestCase(unittest.TestCase):
def test_parse(self): def test_parse(self):
room = RoomAlias.from_string("#channel:my.domain") room = RoomAlias.from_string("#channel:my.domain")

View file

@ -19,6 +19,7 @@ commands =
check-manifest check-manifest
[testenv:pep8] [testenv:pep8]
skip_install = True
basepython = python2.7 basepython = python2.7
deps = deps =
flake8 flake8