0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-06 14:48:56 +01:00

Merge branch 'develop' into markjh/v2_sync_typing

Conflicts:
	synapse/handlers/sync.py
This commit is contained in:
Mark Haines 2015-10-21 15:48:34 +01:00
commit 5201c66108
27 changed files with 1027 additions and 214 deletions

View file

@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
"""This module contains classes for authenticating the user.""" """This module contains classes for authenticating the user."""
from nacl.exceptions import BadSignatureError from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer from twisted.internet import defer
@ -26,7 +27,6 @@ from synapse.util import third_party_invites
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
import nacl.signing
import pymacaroons import pymacaroons
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -308,7 +308,11 @@ class Auth(object):
) )
if Membership.JOIN != membership: if Membership.JOIN != membership:
# JOIN is the only action you can perform if you're not in the room if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined if not caller_in_room: # caller isn't joined
raise AuthError( raise AuthError(
403, 403,
@ -416,16 +420,23 @@ class Auth(object):
key_validity_url key_validity_url
) )
return False return False
for _, signature_block in join_third_party_invite["signatures"].items(): signed = join_third_party_invite["signed"]
if signed["mxid"] != event.user_id:
return False
if signed["token"] != token:
return False
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items(): for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"): if not key_name.startswith("ed25519:"):
return False return False
verify_key = nacl.signing.VerifyKey(decode_base64(public_key)) verify_key = decode_verify_key_bytes(
signature = decode_base64(encoded_signature) key_name,
verify_key.verify(token, signature) decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
return True return True
return False return False
except (KeyError, BadSignatureError,): except (KeyError, SignatureVerifyException,):
return False return False
def _get_power_level_event(self, auth_events): def _get_power_level_event(self, auth_events):

View file

@ -66,7 +66,6 @@ def prune_event(event):
"users_default", "users_default",
"events", "events",
"events_default", "events_default",
"events_default",
"state_default", "state_default",
"ban", "ban",
"kick", "kick",

View file

@ -17,6 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from synapse.api.constants import Membership
from .units import Edu from .units import Edu
from synapse.api.errors import ( from synapse.api.errors import (
@ -357,7 +358,34 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id, content): def make_membership_event(self, destinations, room_id, user_id, membership, content):
"""
Creates an m.room.member event, with context, without participating in the room.
Does so by asking one of the already participating servers to create an
event with proper context.
Note that this does not append any events to any graphs.
Args:
destinations (str): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
membership (str): The "membership" property of the event. Must be
one of "join" or "leave".
content (object): Any additional data to put into the content field
of the event.
Return:
A tuple of (origin (str), event (object)) where origin is the remote
homeserver which generated the event.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
@ -368,13 +396,13 @@ class FederationClient(FederationBase):
content["third_party_invite"] content["third_party_invite"]
) )
try: try:
ret = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, args destination, room_id, user_id, membership, args
) )
pdu_dict = ret["event"] pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict) logger.debug("Got response to make_%s: %s", membership, pdu_dict)
defer.returnValue( defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict)) (destination, self.event_from_pdu_json(pdu_dict))
@ -384,8 +412,8 @@ class FederationClient(FederationBase):
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to make_join via %s: %s", "Failed to make_%s via %s: %s",
destination, e.message membership, destination, e.message
) )
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@ -491,6 +519,33 @@ class FederationClient(FederationBase):
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu):
for destination in destinations:
if destination == self.server_name:
continue
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):
""" """

View file

@ -267,6 +267,20 @@ class FederationServer(FederationBase):
], ],
})) }))
@defer.inlineCallbacks
def on_make_leave_request(self, room_id, user_id):
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id): def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -160,8 +161,14 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_join(self, destination, room_id, user_id, args={}): def make_membership_event(self, destination, room_id, user_id, membership, args={}):
path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
@ -185,6 +192,19 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_invite(self, destination, room_id, event_id, content): def send_invite(self, destination, room_id, event_id, content):

View file

@ -296,6 +296,24 @@ class FederationMakeJoinServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_leave_request(context, user_id)
defer.returnValue((200, content))
class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, txid):
content = yield self.handler.on_send_leave_request(origin, content)
defer.returnValue((200, content))
class FederationEventAuthServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/([^/]*)/([^/]*)" PATH = "/event_auth/([^/]*)/([^/]*)"
@ -385,8 +403,10 @@ SERVLET_CLASSES = (
FederationBackfillServlet, FederationBackfillServlet,
FederationQueryServlet, FederationQueryServlet,
FederationMakeJoinServlet, FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet, FederationEventServlet,
FederationSendJoinServlet, FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationInviteServlet, FederationInviteServlet,
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,

View file

@ -32,6 +32,7 @@ from .sync import SyncHandler
from .auth import AuthHandler from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler from .receipts import ReceiptsHandler
from .search import SearchHandler
class Handlers(object): class Handlers(object):
@ -68,3 +69,4 @@ class Handlers(object):
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs) self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)

View file

@ -565,7 +565,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot): def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -581,50 +581,19 @@ class FederationHandler(BaseHandler):
yield self.store.clean_room_for_join(room_id) yield self.store.clean_room_for_join(room_id)
origin, pdu = yield self.replication_layer.make_join( origin, event = yield self._make_and_verify_event(
target_hosts, target_hosts,
room_id, room_id,
joinee, joinee,
"join",
content content
) )
logger.debug("Got response to make_join: %s", pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)
event.internal_metadata.outlier = False
self.room_queues[room_id] = [] self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
handled_events = set() handled_events = set()
try: try:
builder.event_id = self.event_builder_factory.create_event_id() new_event = self._sign_event(event)
builder.origin = self.hs.hostname
builder.content = content
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
new_event = builder.build()
# Try the host we successfully got a response to /make_join/ # Try the host we successfully got a response to /make_join/
# request first. # request first.
try: try:
@ -632,11 +601,7 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin) target_hosts.insert(0, origin)
except ValueError: except ValueError:
pass pass
ret = yield self.replication_layer.send_join(target_hosts, new_event)
ret = yield self.replication_layer.send_join(
target_hosts,
new_event
)
origin = ret["origin"] origin = ret["origin"]
state = ret["state"] state = ret["state"]
@ -700,7 +665,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id, query):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
@ -859,6 +824,168 @@ class FederationHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave",
{}
)
signed_event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/
# request first.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
defer.returnValue(None)
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content):
origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts,
room_id,
user_id,
membership,
content
)
logger.debug("Got response to make_%s: %s", membership, pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == user_id)
assert(event.state_key == user_id)
assert(event.room_id == room_id)
defer.returnValue((origin, event))
def _sign_event(self, event):
event.internal_metadata.outlier = False
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
return builder.build()
@defer.inlineCallbacks
@log_function
def on_make_leave_request(self, room_id, user_id):
""" We've received a /make_leave/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
defer.returnValue(event)
@defer.inlineCallbacks
@log_function
def on_send_leave_request(self, origin, pdu):
""" We have received a leave event for a room. Fully process it."""
event = pdu
logger.debug(
"on_send_leave_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(
UserID.from_string(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
destinations.discard(origin)
logger.debug(
"on_send_leave_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
self.replication_layer.send_pdu(new_pdu, destinations)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor() yield run_on_reactor()

View file

@ -389,7 +389,22 @@ class RoomMemberHandler(BaseHandler):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
yield self._do_join(event, context, do_auth=do_auth) yield self._do_join(event, context, do_auth=do_auth)
else: else:
# This is not a JOIN, so we can handle it normally. if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context)
if not is_host_in_room:
# Rejecting an invite, rather than leaving a joined room
handler = self.hs.get_handlers().federation_handler
inviter = yield self.get_inviter(event)
if not inviter:
# return the same error as join_room_alias does
raise SynapseError(404, "No known servers")
yield handler.do_remotely_reject_invite(
[inviter.domain],
room_id,
event.user_id
)
defer.returnValue({"room_id": room_id})
return
# FIXME: This isn't idempotency. # FIXME: This isn't idempotency.
if prev_state and prev_state.membership == event.membership: if prev_state and prev_state.membership == event.membership:
@ -413,7 +428,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def join_room_alias(self, joinee, room_alias, do_auth=True, content={}): def join_room_alias(self, joinee, room_alias, content={}):
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
@ -447,8 +462,6 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True): def _do_join(self, event, context, room_hosts=None, do_auth=True):
joinee = UserID.from_string(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite # XXX: We don't do an auth check if we are doing an invite
@ -456,48 +469,18 @@ class RoomMemberHandler(BaseHandler):
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
is_host_in_room = yield self.auth.check_host_in_room( is_host_in_room = yield self.is_host_in_room(room_id, context)
event.room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
if is_host_in_room: if is_host_in_room:
should_do_dance = False should_do_dance = False
elif room_hosts: # TODO: Shouldn't this be remote_room_host? elif room_hosts: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True should_do_dance = True
else: else:
# TODO(markjh): get prev_state from snapshot inviter = yield self.get_inviter(event)
prev_state = yield self.store.get_room_member( if not inviter:
joinee.to_string(), room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
else:
# return the same error as join_room_alias does # return the same error as join_room_alias does
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
if should_do_dance: if should_do_dance:
handler = self.hs.get_handlers().federation_handler handler = self.hs.get_handlers().federation_handler
@ -505,8 +488,7 @@ class RoomMemberHandler(BaseHandler):
room_hosts, room_hosts,
room_id, room_id,
event.user_id, event.user_id,
event.content, # FIXME To get a non-frozen dict event.content # FIXME To get a non-frozen dict
context
) )
else: else:
logger.debug("Doing normal join") logger.debug("Doing normal join")
@ -523,6 +505,44 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id "user_joined_room", user=user, room_id=room_id
) )
@defer.inlineCallbacks
def get_inviter(self, event):
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
event.user_id, event.room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
defer.returnValue(UserID.from_string(prev_state.user_id))
return
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
defer.returnValue(inviter)
defer.returnValue(None)
@defer.inlineCallbacks
def is_host_in_room(self, room_id, context):
is_host_in_room = yield self.auth.check_host_in_room(
room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
defer.returnValue(is_host_in_room)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given

View file

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
import logging
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
def search(self, user, content):
"""Performs a full text search for a user.
Args:
user (UserID)
content (dict): Search parameters
Returns:
dict to be returned to the client with results of search
"""
try:
search_term = content["search_categories"]["room_events"]["search_term"]
keys = content["search_categories"]["room_events"].get("keys", [
"content.body", "content.name", "content.topic",
])
except KeyError:
raise SynapseError(400, "Invalid search query")
# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
room_ids = set(r.room_id for r in rooms)
# TODO: Apply room filter to rooms list
rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys)
allowed_events = yield self._filter_events_for_client(
user.to_string(), event_map.values()
)
# TODO: Filter allowed_events
# TODO: Add a limit
time_now = self.clock.time_msec()
results = {
e.event_id: {
"rank": rank_map[e.event_id],
"result": serialize_event(e, time_now)
}
for e in allowed_events
}
logger.info("Found %d results", len(results))
defer.returnValue({
"search_categories": {
"room_events": {
"results": results,
"count": len(results)
}
}
})

View file

@ -61,18 +61,37 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
return bool(self.timeline or self.state or self.ephemeral) return bool(self.timeline or self.state or self.ephemeral)
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id",
"timeline",
"state",
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.timeline or self.state)
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
"room_id", "room_id",
"invite", "invite",
])): ])):
__slots__ = [] __slots__ = []
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
class SyncResult(collections.namedtuple("SyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync "next_batch", # Token for the next sync
"presence", # List of presence events for the user. "presence", # List of presence events for the user.
"joined", # JoinedSyncResult for each joined room. "joined", # JoinedSyncResult for each joined room.
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
])): ])):
__slots__ = [] __slots__ = []
@ -160,11 +179,17 @@ class SyncHandler(BaseHandler):
) )
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
membership_list=[Membership.INVITE, Membership.JOIN] membership_list=(
Membership.INVITE,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN
)
) )
joined = [] joined = []
invited = [] invited = []
archived = []
for event in room_list: for event in room_list:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
room_sync = yield self.initial_sync_for_joined_room( room_sync = yield self.initial_sync_for_joined_room(
@ -177,11 +202,23 @@ class SyncHandler(BaseHandler):
room_id=event.room_id, room_id=event.room_id,
invite=invite, invite=invite,
)) ))
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync = yield self.initial_sync_for_archived_room(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
)
archived.append(room_sync)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived,
next_batch=now_token, next_batch=now_token,
)) ))
@ -240,6 +277,28 @@ class SyncHandler(BaseHandler):
defer.returnValue((now_token, typing_by_room)) defer.returnValue((now_token, typing_by_room))
@defer.inlineCallbacks
def initial_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred JoinedSyncResult.
"""
batch = yield self.load_filtered_recents(
room_id, sync_config, leave_token,
)
leave_state = yield self.store.get_state_for_events(
[leave_event_id], None
)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state[leave_event_id].values(),
))
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token): def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to """ Get the incremental delta needed to bring the client up to
@ -284,18 +343,22 @@ class SyncHandler(BaseHandler):
) )
joined = [] joined = []
archived = []
if len(room_events) <= timeline_limit: if len(room_events) <= timeline_limit:
# There is no gap in any of the rooms. Therefore we can just # There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them. # partition the new events by room and return them.
invite_events = [] invite_events = []
leave_events = []
events_by_room_id = {} events_by_room_id = {}
for event in room_events: for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event) events_by_room_id.setdefault(event.room_id, []).append(event)
if event.room_id not in joined_room_ids: if event.room_id not in joined_room_ids:
if (event.type == EventTypes.Member if (event.type == EventTypes.Member
and event.membership == Membership.INVITE
and event.state_key == sync_config.user.to_string()): and event.state_key == sync_config.user.to_string()):
if event.membership == Membership.INVITE:
invite_events.append(event) invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
for room_id in joined_room_ids: for room_id in joined_room_ids:
recents = events_by_room_id.get(room_id, []) recents = events_by_room_id.get(room_id, [])
@ -323,11 +386,16 @@ class SyncHandler(BaseHandler):
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
else: else:
invite_events = yield self.store.get_invites_for_user( invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids: for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id, sync_config, since_token, now_token,
@ -336,6 +404,12 @@ class SyncHandler(BaseHandler):
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token
)
archived.append(room_sync)
invited = [ invited = [
InvitedSyncResult(room_id=event.room_id, invite=event) InvitedSyncResult(room_id=event.room_id, invite=event)
for event in invite_events for event in invite_events
@ -345,6 +419,7 @@ class SyncHandler(BaseHandler):
presence=presence, presence=presence,
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived,
next_batch=now_token, next_batch=now_token,
)) ))
@ -443,6 +518,55 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token):
""" Get the incremental delta needed to bring the client up to date for
the archived room.
Returns:
A Deferred ArchivedSyncResult
"""
stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id
)
leave_token = since_token.copy_and_replace("room_key", stream_token)
batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token,
)
logging.debug("Recents %r", batch)
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
leave_state = yield self.store.get_state_for_events(
[leave_event.event_id], None
)
state_events_at_leave = leave_state[leave_event.event_id].values()
state_at_previous_sync = yield self.get_state_at_previous_sync(
leave_event.room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
)
room_sync = ArchivedSyncResult(
room_id=leave_event.room_id,
timeline=batch,
state=state_events_delta,
)
logging.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token): def get_state_at_previous_sync(self, room_id, since_token):
""" Get the room state at the previous sync the client made. """ Get the room state at the previous sync the client made.

View file

@ -101,6 +101,8 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], login_submission['address'] login_submission['medium'], login_submission['address']
) )
if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
else: else:
user_id = login_submission['user'] user_id = login_submission['user']
@ -192,36 +194,6 @@ class LoginRestServlet(ClientV1RestServlet):
return (user, attributes) return (user, attributes)
class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
def on_GET(self, request):
# TODO(kegan): This should be returning some HTML which is capable of
# hitting LoginRestServlet
return (200, {})
class PasswordResetRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks
def on_POST(self, request):
reset_info = _parse_json(request)
try:
email = reset_info["email"]
user_id = reset_info["user_id"]
handler = self.handlers.login_handler
yield handler.reset_password(user_id, email)
# purposefully give no feedback to avoid people hammering different
# combinations.
defer.returnValue((200, {}))
except KeyError:
raise SynapseError(
400,
"Missing keys. Requires 'email' and 'user_id'."
)
class SAML2RestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/saml2") PATTERN = client_path_pattern("/login/saml2")

View file

@ -555,6 +555,22 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern(
"/search$"
)
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
results = yield self.handlers.search_handler.search(auth_user, content)
defer.returnValue((200, results))
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -611,3 +627,4 @@ def register_servlets(hs, http_server):
RoomInitialSyncRestServlet(hs).register(http_server) RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)

View file

@ -136,6 +136,10 @@ class SyncRestServlet(RestServlet):
sync_result.invited, filter, time_now, token_id sync_result.invited, filter, time_now, token_id
) )
archived = self.encode_archived(
sync_result.archived, filter, time_now, token_id
)
response_content = { response_content = {
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, filter, time_now sync_result.presence, filter, time_now
@ -143,7 +147,7 @@ class SyncRestServlet(RestServlet):
"rooms": { "rooms": {
"joined": joined, "joined": joined,
"invited": invited, "invited": invited,
"archived": {}, "archived": archived,
}, },
"next_batch": sync_result.next_batch.to_string(), "next_batch": sync_result.next_batch.to_string(),
} }
@ -182,14 +186,20 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, filter, time_now, token_id):
joined = {}
for room in rooms:
joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id, joined=False
)
return joined
@staticmethod @staticmethod
def encode_room(room, filter, time_now, token_id): def encode_room(room, filter, time_now, token_id, joined=True):
event_map = {} event_map = {}
state_events = filter.filter_room_state(room.state) state_events = filter.filter_room_state(room.state)
timeline_events = filter.filter_room_timeline(room.timeline.events)
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
state_event_ids = [] state_event_ids = []
timeline_event_ids = []
for event in state_events: for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event( event_map[event.event_id] = serialize_event(
@ -198,6 +208,8 @@ class SyncRestServlet(RestServlet):
) )
state_event_ids.append(event.event_id) state_event_ids.append(event.event_id)
timeline_events = filter.filter_room_timeline(room.timeline.events)
timeline_event_ids = []
for event in timeline_events: for event in timeline_events:
# TODO(mjark): Respect formatting requirements in the filter. # TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event( event_map[event.event_id] = serialize_event(
@ -205,6 +217,7 @@ class SyncRestServlet(RestServlet):
event_format=format_event_for_client_v2_without_event_id, event_format=format_event_for_client_v2_without_event_id,
) )
timeline_event_ids.append(event.event_id) timeline_event_ids.append(event.event_id)
result = { result = {
"event_map": event_map, "event_map": event_map,
"timeline": { "timeline": {
@ -213,8 +226,12 @@ class SyncRestServlet(RestServlet):
"limited": room.timeline.limited, "limited": room.timeline.limited,
}, },
"state": {"events": state_event_ids}, "state": {"events": state_event_ids},
"ephemeral": {"events": ephemeral_events},
} }
if joined:
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
result["ephemeral"] = {"events": ephemeral_events}
return result return result

View file

@ -40,6 +40,7 @@ from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore
import logging import logging
@ -69,6 +70,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventsStore, EventsStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View file

@ -519,7 +519,7 @@ class SQLBaseStore(object):
allow_none=False, allow_none=False,
desc="_simple_select_one_onecol"): desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it." return a single row, returning a single column from it.
Args: Args:
table : string giving the table name table : string giving the table name

View file

@ -307,6 +307,8 @@ class EventsStore(SQLBaseStore):
self._store_room_name_txn(txn, event) self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic: elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event) self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)

View file

@ -19,6 +19,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections import collections
import logging import logging
@ -175,6 +176,10 @@ class RoomStore(SQLBaseStore):
}, },
) )
self._store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn, event): def _store_room_name_txn(self, txn, event):
if hasattr(event, "content") and "name" in event.content: if hasattr(event, "content") and "name" in event.content:
self._simple_insert_txn( self._simple_insert_txn(
@ -187,6 +192,33 @@ class RoomStore(SQLBaseStore):
} }
) )
self._store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self._store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
txn.execute(sql, (event.event_id, event.room_id, key, value,))
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):

View file

@ -124,6 +124,19 @@ class RoomMemberStore(SQLBaseStore):
invites.event_id for invite in invites invites.event_id for invite in invites
])) ]))
def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user
Args:
user_id (str): The user ID.
Returns:
A deferred list of event objects.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, (Membership.LEAVE, Membership.BAN)
).addCallback(lambda leaves: self._get_events([
leave.event_id for leave in leaves
]))
def get_rooms_for_user_where_membership_is(self, user_id, membership_list): def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user """ Get all the rooms for this user where the membership for this user
matches one in the membership list. matches one in the membership list.

View file

@ -0,0 +1,123 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import ujson
logger = logging.getLogger(__name__)
POSTGRES_SQL = """
CREATE TABLE event_search (
event_id TEXT,
room_id TEXT,
key TEXT,
vector tsvector
);
INSERT INTO event_search SELECT
event_id, room_id, 'content.body',
to_tsvector('english', json::json->'content'->>'body')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.message';
INSERT INTO event_search SELECT
event_id, room_id, 'content.name',
to_tsvector('english', json::json->'content'->>'name')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.name';
INSERT INTO event_search SELECT
event_id, room_id, 'content.topic',
to_tsvector('english', json::json->'content'->>'topic')
FROM events NATURAL JOIN event_json WHERE type = 'm.room.topic';
CREATE INDEX event_search_fts_idx ON event_search USING gin(vector);
CREATE INDEX event_search_ev_idx ON event_search(event_id);
CREATE INDEX event_search_ev_ridx ON event_search(room_id);
"""
SQLITE_TABLE = (
"CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)"
)
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
run_postgres_upgrade(cur)
return
if isinstance(database_engine, Sqlite3Engine):
run_sqlite_upgrade(cur)
return
def run_postgres_upgrade(cur):
for statement in get_statements(POSTGRES_SQL.splitlines()):
cur.execute(statement)
def run_sqlite_upgrade(cur):
cur.execute(SQLITE_TABLE)
rowid = -1
while True:
cur.execute(
"SELECT rowid, json FROM event_json"
" WHERE rowid > ?"
" ORDER BY rowid ASC LIMIT 100",
(rowid,)
)
res = cur.fetchall()
if not res:
break
events = [
ujson.loads(js)
for _, js in res
]
rowid = max(rid for rid, _ in res)
rows = []
for ev in events:
if ev["type"] == "m.room.message":
rows.append((
ev["event_id"], ev["room_id"], "content.body",
ev["content"]["body"]
))
if ev["type"] == "m.room.name":
rows.append((
ev["event_id"], ev["room_id"], "content.name",
ev["content"]["name"]
))
if ev["type"] == "m.room.topic":
rows.append((
ev["event_id"], ev["room_id"], "content.topic",
ev["content"]["topic"]
))
if rows:
logger.info(rows)
cur.executemany(
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)",
rows
)

93
synapse/storage/search.py Normal file
View file

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from _base import SQLBaseStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
class SearchStore(SQLBaseStore):
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
room_ids (list): List of room ids to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
Returns:
2-tuple of (dict event_id -> rank, dict event_id -> event)
"""
clauses = []
args = []
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids)
local_clauses = []
for key in keys:
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"SELECT 0 as rank, event_id FROM event_search"
" WHERE value MATCH ?"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for clause in clauses:
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
)
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
defer.returnValue((
{
r["event_id"]: r["rank"]
for r in results
if r["event_id"] in event_map
},
event_map
))

View file

@ -214,7 +214,6 @@ class StateStore(SQLBaseStore):
that are in the `types` list. that are in the `types` list.
Args: Args:
room_id (str)
event_ids (list) event_ids (list)
types (list): List of (type, state_key) tuples which are used to types (list): List of (type, state_key) tuples which are used to
filter the state fetched. `state_key` may be None, which matches filter the state fetched. `state_key` may be None, which matches

View file

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module allows you to send out emails.
"""
import email.utils
import smtplib
import twisted.python.log
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import logging
logger = logging.getLogger(__name__)
class EmailException(Exception):
pass
def send_email(smtp_server, from_addr, to_addr, subject, body):
"""Sends an email.
Args:
smtp_server(str): The SMTP server to use.
from_addr(str): The address to send from.
to_addr(str): The address to send to.
subject(str): The subject of the email.
body(str): The plain text body of the email.
Raises:
EmailException if there was a problem sending the mail.
"""
if not smtp_server or not from_addr or not to_addr:
raise EmailException("Need SMTP server, from and to addresses. Check"
" the config to set these.")
msg = MIMEMultipart('alternative')
msg['Subject'] = subject
msg['From'] = from_addr
msg['To'] = to_addr
plain_part = MIMEText(body)
msg.attach(plain_part)
raw_from = email.utils.parseaddr(from_addr)[1]
raw_to = email.utils.parseaddr(to_addr)[1]
if not raw_from or not raw_to:
raise EmailException("Couldn't parse from/to address.")
logger.info("Sending email to %s on server %s with subject %s",
to_addr, smtp_server, subject)
try:
smtp = smtplib.SMTP(smtp_server)
smtp.sendmail(raw_from, raw_to, msg.as_string())
smtp.quit()
except Exception as origException:
twisted.python.log.err()
ese = EmailException()
ese.cause = origException
raise ese

View file

@ -23,8 +23,8 @@ JOIN_KEYS = {
"token", "token",
"public_key", "public_key",
"key_validity_url", "key_validity_url",
"signatures",
"sender", "sender",
"signed",
} }

15
tests/crypto/__init__.py Normal file
View file

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from synapse.events.builder import EventBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from unpaddedbase64 import decode_base64
import nacl.signing
# Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against.
SIGNING_KEY_SEED = decode_base64(
"YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
)
KEY_ALG = "ed25519"
KEY_VER = 1
KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER)
HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase):
def setUp(self):
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
self.signing_key.alg = KEY_ALG
self.signing_key.version = KEY_VER
def test_sign_minimal(self):
builder = EventBuilder(
{
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'signatures': {},
'type': "X",
'unsigned': {'age_ts': 1000000},
},
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
event = builder.build()
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
self.assertEquals(
event.hashes['sha256'],
"6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI",
)
self.assertTrue(hasattr(event, 'signatures'))
self.assertIn(HOSTNAME, event.signatures)
self.assertIn(KEY_NAME, event.signatures["domain"])
self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME],
"2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+"
"aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA",
)
def test_sign_message(self):
builder = EventBuilder(
{
'content': {
'body': "Here is the message content",
},
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'type': "m.room.message",
'room_id': "!r:domain",
'sender': "@u:domain",
'signatures': {},
'unsigned': {'age_ts': 1000000},
}
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
event = builder.build()
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
self.assertEquals(
event.hashes['sha256'],
"onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g",
)
self.assertTrue(hasattr(event, 'signatures'))
self.assertIn(HOSTNAME, event.signatures)
self.assertIn(KEY_NAME, event.signatures["domain"])
self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME],
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
)

View file

@ -277,10 +277,10 @@ class RoomPermissionsTestCase(RestTestCase):
expect_code=403) expect_code=403)
# set [invite/join/left] of self, set [invite/join/left] of other, # set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 403s # expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]: for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=404) yield self.join(room=room, user=usr, expect_code=404)
yield self.leave(room=room, user=usr, expect_code=403) yield self.leave(room=room, user=usr, expect_code=404)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_membership_private_room_perms(self): def test_membership_private_room_perms(self):