Merge branch 'develop' into sh-cas-auth-via-homeserver

This commit is contained in:
Steven Hammerton 2015-11-05 20:59:45 +00:00
commit fece2f5c77
45 changed files with 1117 additions and 324 deletions

View file

@ -15,8 +15,9 @@ recursive-include scripts *
recursive-include scripts-dev * recursive-include scripts-dev *
recursive-include tests *.py recursive-include tests *.py
recursive-include static *.css recursive-include synapse/static *.css
recursive-include static *.html recursive-include synapse/static *.gif
recursive-include static *.js recursive-include synapse/static *.html
recursive-include synapse/static *.js
prune demo/etc prune demo/etc

View file

@ -24,7 +24,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID from synapse.types import RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util import third_party_invites
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -318,6 +317,11 @@ class Auth(object):
} }
) )
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
return True
if Membership.JOIN != membership: if Membership.JOIN != membership:
if (caller_invited if (caller_invited
and Membership.LEAVE == membership and Membership.LEAVE == membership
@ -361,8 +365,7 @@ class Auth(object):
pass pass
elif join_rule == JoinRules.INVITE: elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited: if not caller_in_room and not caller_invited:
if not self._verify_third_party_invite(event, auth_events): raise AuthError(403, "You are not invited to this room.")
raise AuthError(403, "You are not invited to this room.")
else: else:
# TODO (erikj): may_join list # TODO (erikj): may_join list
# TODO (erikj): private rooms # TODO (erikj): private rooms
@ -390,10 +393,10 @@ class Auth(object):
def _verify_third_party_invite(self, event, auth_events): def _verify_third_party_invite(self, event, auth_events):
""" """
Validates that the join event is authorized by a previous third-party invite. Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the invite, Checks that the public key, and keyserver, match those in the third party invite,
and that the join event has a signature issued using that public key. and that the invite event has a signature issued using that public key.
Args: Args:
event: The m.room.member join event being validated. event: The m.room.member join event being validated.
@ -404,35 +407,28 @@ class Auth(object):
True if the event fulfills the expectations of a previous third party True if the event fulfills the expectations of a previous third party
invite event. invite event.
""" """
if not third_party_invites.join_has_third_party_invite(event.content): if "third_party_invite" not in event.content:
return False return False
join_third_party_invite = event.content["third_party_invite"] if "signed" not in event.content["third_party_invite"]:
token = join_third_party_invite["token"] return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get( invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )
if not invite_event: if not invite_event:
logger.info("Failing 3pid invite because no invite found for token %s", token) return False
if event.user_id != invite_event.user_id:
return False return False
try: try:
public_key = join_third_party_invite["public_key"] public_key = invite_event.content["public_key"]
key_validity_url = join_third_party_invite["key_validity_url"] if signed["mxid"] != event.state_key:
if invite_event.content["public_key"] != public_key:
logger.info(
"Failing 3pid invite because public key invite: %s != join: %s",
invite_event.content["public_key"],
public_key
)
return False
if invite_event.content["key_validity_url"] != key_validity_url:
logger.info(
"Failing 3pid invite because key_validity_url invite: %s != join: %s",
invite_event.content["key_validity_url"],
key_validity_url
)
return False
signed = join_third_party_invite["signed"]
if signed["mxid"] != event.user_id:
return False return False
if signed["token"] != token: if signed["token"] != token:
return False return False
@ -445,6 +441,11 @@ class Auth(object):
decode_base64(public_key) decode_base64(public_key)
) )
verify_signed_json(signed, server, verify_key) verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True return True
return False return False
except (KeyError, SignatureVerifyException,): except (KeyError, SignatureVerifyException,):
@ -751,17 +752,19 @@ class Auth(object):
if e_type == Membership.JOIN: if e_type == Membership.JOIN:
if member_event and not is_public: if member_event and not is_public:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)
if third_party_invites.join_has_third_party_invite(event.content): else:
if member_event:
auth_ids.append(member_event.event_id)
if e_type == Membership.INVITE:
if "third_party_invite" in event.content:
key = ( key = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["token"] event.content["third_party_invite"]["token"]
) )
invite = current_state.get(key) third_party_invite = current_state.get(key)
if invite: if third_party_invite:
auth_ids.append(invite.event_id) auth_ids.append(third_party_invite.event_id)
else:
if member_event:
auth_ids.append(member_event.event_id)
elif member_event: elif member_event:
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)

View file

@ -50,11 +50,11 @@ class Filtering(object):
# many definitions. # many definitions.
top_level_definitions = [ top_level_definitions = [
"public_user_data", "private_user_data", "server_data" "presence"
] ]
room_level_definitions = [ room_level_definitions = [
"state", "timeline", "ephemeral" "state", "timeline", "ephemeral", "private_user_data"
] ]
for key in top_level_definitions: for key in top_level_definitions:
@ -114,22 +114,6 @@ class Filtering(object):
if not isinstance(event_type, basestring): if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string") raise SynapseError(400, "Event type should be a string")
if "format" in definition:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
if "select" in definition:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
class FilterCollection(object): class FilterCollection(object):
def __init__(self, filter_json): def __init__(self, filter_json):

View file

@ -132,7 +132,9 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_static_content(self): def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip # This is old and should go away: not going to bother adding gzip
return File("static") return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):
return ContentRepoResource( return ContentRepoResource(

View file

@ -26,7 +26,6 @@ from synapse.api.errors import (
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util import third_party_invites
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -358,7 +357,7 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_membership_event(self, destinations, room_id, user_id, membership, content): def make_membership_event(self, destinations, room_id, user_id, membership):
""" """
Creates an m.room.member event, with context, without participating in the room. Creates an m.room.member event, with context, without participating in the room.
@ -390,14 +389,9 @@ class FederationClient(FederationBase):
if destination == self.server_name: if destination == self.server_name:
continue continue
args = {}
if third_party_invites.join_has_third_party_invite(content):
args = third_party_invites.extract_join_keys(
content["third_party_invite"]
)
try: try:
ret = yield self.transport_layer.make_membership_event( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership, args destination, room_id, user_id, membership
) )
pdu_dict = ret["event"] pdu_dict = ret["event"]
@ -704,3 +698,26 @@ class FederationClient(FederationBase):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
yield self.transport_layer.exchange_third_party_invite(
destination=destination,
room_id=room_id,
event_dict=event_dict,
)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_third_party_invite via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")

View file

@ -23,12 +23,10 @@ from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
from synapse.api.errors import FederationError, SynapseError, Codes from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.util import third_party_invites
import simplejson as json import simplejson as json
import logging import logging
@ -230,19 +228,8 @@ class FederationServer(FederationBase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id):
threepid_details = {} pdu = yield self.handler.on_make_join_request(room_id, user_id)
if third_party_invites.has_join_keys(query):
for k in third_party_invites.JOIN_KEYS:
if not isinstance(query[k], list) or len(query[k]) != 1:
raise FederationError(
"FATAL",
Codes.MISSING_PARAM,
"key %s value %s" % (k, query[k],),
None
)
threepid_details[k] = query[k][0]
pdu = yield self.handler.on_make_join_request(room_id, user_id, threepid_details)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)}) defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@ -556,3 +543,15 @@ class FederationServer(FederationBase):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def exchange_third_party_invite(self, invite):
ret = yield self.handler.exchange_third_party_invite(invite)
defer.returnValue(ret)
@defer.inlineCallbacks
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
ret = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, event_dict
)
defer.returnValue(ret)

View file

@ -161,7 +161,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_membership_event(self, destination, room_id, user_id, membership, args={}): def make_membership_event(self, destination, room_id, user_id, membership):
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships: if membership not in valid_memberships:
raise RuntimeError( raise RuntimeError(
@ -173,7 +173,6 @@ class TransportLayerClient(object):
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=args,
retry_on_dns_fail=True, retry_on_dns_fail=True,
) )
@ -218,6 +217,19 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
response = yield self.client.put_json(
destination=destination,
path=path,
data=event_dict,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_event_auth(self, destination, room_id, event_id): def get_event_auth(self, destination, room_id, event_id):

View file

@ -292,7 +292,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_join_request(context, user_id, query) content = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -343,6 +343,17 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id):
content = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, content
)
defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query" PATH = "/user/keys/query"
@ -396,6 +407,30 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
@defer.inlineCallbacks
def on_POST(self, request):
content_bytes = request.content.read()
content = json.loads(content_bytes)
if "invites" in content:
last_exception = None
for invite in content["invites"]:
try:
yield self.handler.exchange_third_party_invite(invite)
except Exception as e:
last_exception = e
if last_exception:
raise last_exception
defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationPullServlet, FederationPullServlet,
FederationEventServlet, FederationEventServlet,
@ -413,4 +448,6 @@ SERVLET_CLASSES = (
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet, FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
) )

View file

@ -21,7 +21,6 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import third_party_invites
import logging import logging
@ -47,7 +46,8 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False): def _filter_events_for_client(self, user_id, events, is_guest=False,
require_all_visible_for_guests=True):
# Assumes that user has at some point joined the room if not is_guest. # Assumes that user has at some point joined the room if not is_guest.
def allowed(event, membership, visibility): def allowed(event, membership, visibility):
@ -100,7 +100,9 @@ class BaseHandler(object):
if should_include: if should_include:
events_to_return.append(event) events_to_return.append(event)
if is_guest and len(events_to_return) < len(events): if (require_all_visible_for_guests
and is_guest
and len(events_to_return) < len(events)):
# This indicates that some events in the requested range were not # This indicates that some events in the requested range were not
# visible to guest users. To be safe, we reject the entire request, # visible to guest users. To be safe, we reject the entire request,
# so that we don't have to worry about interpreting visibility # so that we don't have to worry about interpreting visibility
@ -189,16 +191,6 @@ class BaseHandler(object):
) )
) )
if (
event.type == EventTypes.Member and
event.content["membership"] == Membership.JOIN and
third_party_invites.join_has_third_party_invite(event.content)
):
yield third_party_invites.check_key_valid(
self.hs.get_simple_http_client(),
event
)
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member: if event.type == EventTypes.Member:

View file

@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler):
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True, as_client_event=True, affect_presence=True,
only_room_events=False): only_room_events=False, room_id=None, is_guest=False):
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned. If `only_room_events` is `True` only room events will be returned.
@ -119,9 +119,15 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
if is_guest:
yield self.distributor.fire(
"user_joined_room", user=auth_user, room_id=room_id
)
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,
only_room_events=only_room_events only_room_events=only_room_events,
is_guest=is_guest, guest_room_id=room_id
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -21,6 +21,7 @@ from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError, AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -39,7 +40,6 @@ from twisted.internet import defer
import itertools import itertools
import logging import logging
from synapse.util import third_party_invites
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,6 +58,8 @@ class FederationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(FederationHandler, self).__init__(hs) super(FederationHandler, self).__init__(hs)
self.hs = hs
self.distributor.observe( self.distributor.observe(
"user_joined_room", "user_joined_room",
self._on_user_joined self._on_user_joined
@ -68,7 +70,6 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
# self.auth_handler = gs.get_auth_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -563,7 +564,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content): def do_invite_join(self, target_hosts, room_id, joinee):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -583,8 +584,7 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
room_id, room_id,
joinee, joinee,
"join", "join"
content
) )
self.room_queues[room_id] = [] self.room_queues[room_id] = []
@ -661,16 +661,12 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
if third_party_invites.has_join_keys(query):
event_content["third_party_invite"] = (
third_party_invites.extract_join_keys(query)
)
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new({
"type": EventTypes.Member, "type": EventTypes.Member,
@ -686,9 +682,6 @@ class FederationHandler(BaseHandler):
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
if third_party_invites.join_has_third_party_invite(event.content):
third_party_invites.check_key_valid(self.hs.get_simple_http_client(), event)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -828,8 +821,7 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
"leave", "leave"
{}
) )
signed_event = self._sign_event(event) signed_event = self._sign_event(event)
@ -848,13 +840,12 @@ class FederationHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content): def _make_and_verify_event(self, target_hosts, room_id, user_id, membership):
origin, pdu = yield self.replication_layer.make_membership_event( origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
membership, membership
content
) )
logger.debug("Got response to make_%s: %s", membership, pdu) logger.debug("Got response to make_%s: %s", membership, pdu)
@ -1647,3 +1638,75 @@ class FederationHandler(BaseHandler):
}, },
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
}) })
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, invite):
sender = invite["sender"]
room_id = invite["room_id"]
event_dict = {
"type": EventTypes.Member,
"content": {
"membership": Membership.INVITE,
"third_party_invite": invite,
},
"room_id": room_id,
"sender": sender,
"state_key": invite["mxid"],
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
event, context = yield self._create_new_client_event(builder=builder)
self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite(
destinations,
room_id,
event_dict,
)
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
builder = self.event_builder_factory.new(event_dict)
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
@defer.inlineCallbacks
def _validate_keyserver(self, event, auth_events):
token = event.content["third_party_invite"]["signed"]["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
try:
response = yield self.hs.get_simple_http_client().get_json(
invite_event.content["key_validity_url"],
{"public_key": invite_event.content["public_key"]}
)
except Exception:
raise SynapseError(
502,
"Third party certificate could not be checked"
)
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, AuthError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -229,7 +229,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key=""): event_type=None, state_key="", is_guest=False):
""" Get data from a room. """ Get data from a room.
Args: Args:
@ -239,23 +239,42 @@ class MessageHandler(BaseHandler):
Raises: Raises:
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
data = yield self.state_handler.get_current_state( data = yield self.state_handler.get_current_state(
room_id, event_type, state_key room_id, event_type, state_key
) )
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], [key] [membership_event_id], [key]
) )
data = room_state[member_event.event_id].get(key) data = room_state[membership_event_id].get(key)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
if is_guest:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if visibility.content["history_visibility"] == "world_readable":
defer.returnValue((Membership.JOIN, None))
return
else:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
else:
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False):
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has joined to the room then return the current state. If the user has
left the room return the state events from when they left. left the room return the state events from when they left.
@ -266,15 +285,17 @@ class MessageHandler(BaseHandler):
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id) room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], None [membership_event_id], None
) )
room_state = room_state[member_event.event_id] room_state = room_state[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
defer.returnValue( defer.returnValue(

View file

@ -1142,8 +1142,9 @@ class PresenceEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, user, from_key, room_ids=None, **kwargs):
from_key = int(from_key) from_key = int(from_key)
room_ids = room_ids or []
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
@ -1161,7 +1162,6 @@ class PresenceEventSource(object):
user_ids_to_check |= set( user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list UserID.from_string(p["observed_user_id"]) for p in presence_list
) )
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials): for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key: if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id) joined = yield presence.get_joined_users_for_room_id(room_id)

View file

@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object):
return self.store.get_max_private_user_data_stream_id() return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key

View file

@ -164,17 +164,15 @@ class ReceiptEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key) from_key = int(from_key)
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
if from_key == to_key: if from_key == to_key:
defer.returnValue(([], to_key)) defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms( events = yield self.store.get_linearized_receipts_for_rooms(
rooms, room_ids,
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
) )

View file

@ -38,6 +38,8 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
@ -488,8 +490,7 @@ class RoomMemberHandler(BaseHandler):
yield handler.do_invite_join( yield handler.do_invite_join(
room_hosts, room_hosts,
room_id, room_id,
event.user_id, event.user_id
event.content # FIXME To get a non-frozen dict
) )
else: else:
logger.debug("Doing normal join") logger.debug("Doing normal join")
@ -632,7 +633,7 @@ class RoomMemberHandler(BaseHandler):
""" """
try: try:
data = yield self.hs.get_simple_http_client().get_json( data = yield self.hs.get_simple_http_client().get_json(
"https://%s/_matrix/identity/api/v1/lookup" % (id_server,), "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{ {
"medium": medium, "medium": medium,
"address": address, "address": address,
@ -655,8 +656,8 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json( key_data = yield self.hs.get_simple_http_client().get_json(
"https://%s/_matrix/identity/api/v1/pubkey/%s" % "%s%s/_matrix/identity/api/v1/pubkey/%s" %
(server_hostname, key_name,), (id_server_scheme, server_hostname, key_name,),
) )
if "public_key" not in key_data: if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" % raise AuthError(401, "No public key named %s from %s" %
@ -709,7 +710,9 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _ask_id_server_for_third_party_invite( def _ask_id_server_for_third_party_invite(
self, id_server, medium, address, room_id, sender): self, id_server, medium, address, room_id, sender):
is_url = "https://%s/_matrix/identity/api/v1/store-invite" % (id_server,) is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url, is_url,
{ {
@ -722,8 +725,8 @@ class RoomMemberHandler(BaseHandler):
# TODO: Check for success # TODO: Check for success
token = data["token"] token = data["token"]
public_key = data["public_key"] public_key = data["public_key"]
key_validity_url = "https://%s/_matrix/identity/api/v1/pubkey/isvalid" % ( key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server, id_server_scheme, id_server,
) )
defer.returnValue((token, public_key, key_validity_url)) defer.returnValue((token, public_key, key_validity_url))
@ -807,7 +810,14 @@ class RoomEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
limit,
room_ids,
is_guest,
):
# We just ignore the key for now. # We just ignore the key for now.
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
@ -828,6 +838,8 @@ class RoomEventSource(object):
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
limit=limit, limit=limit,
room_ids=room_ids,
is_guest=is_guest,
) )
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))

View file

@ -22,6 +22,8 @@ from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from unpaddedbase64 import decode_base64, encode_base64
import logging import logging
@ -34,27 +36,59 @@ class SearchHandler(BaseHandler):
super(SearchHandler, self).__init__(hs) super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
def search(self, user, content): def search(self, user, content, batch=None):
"""Performs a full text search for a user. """Performs a full text search for a user.
Args: Args:
user (UserID) user (UserID)
content (dict): Search parameters content (dict): Search parameters
batch (str): The next_batch parameter. Used for pagination.
Returns: Returns:
dict to be returned to the client with results of search dict to be returned to the client with results of search
""" """
batch_group = None
batch_group_key = None
batch_token = None
if batch:
try:
b = decode_base64(batch)
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
assert batch_group_key is not None
assert batch_token is not None
except:
raise SynapseError(400, "Invalid batch")
try: try:
search_term = content["search_categories"]["room_events"]["search_term"] room_cat = content["search_categories"]["room_events"]
keys = content["search_categories"]["room_events"].get("keys", [
# The actual thing to query in FTS
search_term = room_cat["search_term"]
# Which "keys" to search over in FTS query
keys = room_cat.get("keys", [
"content.body", "content.name", "content.topic", "content.body", "content.name", "content.topic",
]) ])
filter_dict = content["search_categories"]["room_events"].get("filter", {})
event_context = content["search_categories"]["room_events"].get( # Filter to apply to results
filter_dict = room_cat.get("filter", {})
# What to order results by (impacts whether pagination can be doen)
order_by = room_cat.get("order_by", "rank")
# Include context around each event?
event_context = room_cat.get(
"event_context", None "event_context", None
) )
# Group results together? May allow clients to paginate within a
# group
group_by = room_cat.get("groupings", {}).get("group_by", {})
group_keys = [g["key"] for g in group_by]
if event_context is not None: if event_context is not None:
before_limit = int(event_context.get( before_limit = int(event_context.get(
"before_limit", 5 "before_limit", 5
@ -65,6 +99,15 @@ class SearchHandler(BaseHandler):
except KeyError: except KeyError:
raise SynapseError(400, "Invalid search query") raise SynapseError(400, "Invalid search query")
if order_by not in ("rank", "recent"):
raise SynapseError(400, "Invalid order by: %r" % (order_by,))
if set(group_keys) - {"room_id", "sender"}:
raise SynapseError(
400,
"Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
)
search_filter = Filter(filter_dict) search_filter = Filter(filter_dict)
# TODO: Search through left rooms too # TODO: Search through left rooms too
@ -77,19 +120,130 @@ class SearchHandler(BaseHandler):
room_ids = search_filter.filter_rooms(room_ids) room_ids = search_filter.filter_rooms(room_ids)
rank_map, event_map, _ = yield self.store.search_msgs( if batch_group == "room_id":
room_ids, search_term, keys room_ids.intersection_update({batch_group_key})
)
filtered_events = search_filter.filter(event_map.values()) rank_map = {} # event_id -> rank of event
allowed_events = []
room_groups = {} # Holds result of grouping by room, if applicable
sender_group = {} # Holds result of grouping by sender, if applicable
allowed_events = yield self._filter_events_for_client( # Holds the next_batch for the entire result set if one of those exists
user.to_string(), filtered_events global_next_batch = None
)
allowed_events.sort(key=lambda e: -rank_map[e.event_id]) if order_by == "rank":
allowed_events = allowed_events[:search_filter.limit()] results = yield self.store.search_msgs(
room_ids, search_term, keys
)
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = events[:search_filter.limit()]
for e in allowed_events:
rm = room_groups.setdefault(e.room_id, {
"results": [],
"order": rank_map[e.event_id],
})
rm["results"].append(e.event_id)
s = sender_group.setdefault(e.sender, {
"results": [],
"order": rank_map[e.event_id],
})
s["results"].append(e.event_id)
elif order_by == "recent":
# In this case we specifically loop through each room as the given
# limit applies to each room, rather than a global list.
# This is not necessarilly a good idea.
for room_id in room_ids:
room_events = []
if batch_group == "room_id" and batch_group_key == room_id:
pagination_token = batch_token
else:
pagination_token = None
i = 0
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
results = yield self.store.search_room(
room_id, search_term, keys, search_filter.limit() * 2,
pagination_token=pagination_token,
)
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([
r["event"] for r in results
])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
room_events.extend(events)
room_events = room_events[:search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None
break
else:
pagination_token = results[-1]["pagination_token"]
if room_events:
res = results_map[room_events[-1].event_id]
pagination_token = res["pagination_token"]
group = room_groups.setdefault(room_id, {})
if pagination_token:
next_batch = encode_base64("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
))
group["next_batch"] = next_batch
if batch_token:
global_next_batch = next_batch
group["results"] = [e.event_id for e in room_events]
group["order"] = max(
e.origin_server_ts/1000 for e in room_events
if hasattr(e, "origin_server_ts")
)
allowed_events.extend(room_events)
# Normalize the group orders
if room_groups:
if len(room_groups) > 1:
mx = max(g["order"] for g in room_groups.values())
mn = min(g["order"] for g in room_groups.values())
for g in room_groups.values():
g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
else:
room_groups.values()[0]["order"] = 1
else:
# We should never get here due to the guard earlier.
raise NotImplementedError()
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None: if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
@ -144,11 +298,22 @@ class SearchHandler(BaseHandler):
logger.info("Found %d results", len(results)) logger.info("Found %d results", len(results))
rooms_cat_res = {
"results": results,
"count": len(results)
}
if room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
if sender_group and "sender" in group_keys:
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
defer.returnValue({ defer.returnValue({
"search_categories": { "search_categories": {
"room_events": { "room_events": rooms_cat_res
"results": results,
"count": len(results)
}
} }
}) })

View file

@ -295,11 +295,16 @@ class SyncHandler(BaseHandler):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
is_guest=False,
) )
now_token = now_token.copy_and_replace("typing_key", typing_key) now_token = now_token.copy_and_replace("typing_key", typing_key)
@ -312,10 +317,13 @@ class SyncHandler(BaseHandler):
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"] receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events_for_user( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
) )
now_token = now_token.copy_and_replace("receipt_key", receipt_key) now_token = now_token.copy_and_replace("receipt_key", receipt_key)
@ -360,11 +368,17 @@ class SyncHandler(BaseHandler):
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events_for_user( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter.presence_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)

View file

@ -246,17 +246,12 @@ class TypingNotificationEventSource(object):
}, },
} }
@defer.inlineCallbacks def get_new_events(self, from_key, room_ids, **kwargs):
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.handler()
joined_room_ids = (
yield self.room_member_handler().get_joined_rooms_for_user(user)
)
events = [] events = []
for room_id in joined_room_ids: for room_id in room_ids:
if room_id not in handler._room_serials: if room_id not in handler._room_serials:
continue continue
if handler._room_serials[room_id] <= from_key: if handler._room_serials[room_id] <= from_key:
@ -264,7 +259,7 @@ class TypingNotificationEventSource(object):
events.append(self._make_event_for(room_id)) events.append(self._make_event_for(room_id))
defer.returnValue((events, handler._latest_room_serial)) return events, handler._latest_room_serial
def get_current_key(self): def get_current_key(self):
return self.handler()._latest_room_serial return self.handler()._latest_room_serial

View file

@ -35,6 +35,7 @@ from signedjson.sign import sign_json
import simplejson as json import simplejson as json
import logging import logging
import random
import sys import sys
import urllib import urllib
import urlparse import urlparse
@ -55,6 +56,9 @@ incoming_responses_counter = metrics.register_counter(
) )
MAX_RETRIES = 4
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
def __init__(self, hs): def __init__(self, hs):
self.tls_server_context_factory = hs.tls_server_context_factory self.tls_server_context_factory = hs.tls_server_context_factory
@ -119,7 +123,7 @@ class MatrixFederationHttpClient(object):
# XXX: Would be much nicer to retry only at the transaction-layer # XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place) # (once we have reliable transactions in place)
retries_left = 5 retries_left = MAX_RETRIES
http_url_bytes = urlparse.urlunparse( http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "") ("", "", path_bytes, param_bytes, query_bytes, "")
@ -180,7 +184,9 @@ class MatrixFederationHttpClient(object):
) )
if retries_left and not timeout: if retries_left and not timeout:
yield sleep(2 ** (5 - retries_left)) delay = 5 ** (MAX_RETRIES + 1 - retries_left)
delay *= random.uniform(0.8, 1.4)
yield sleep(delay)
retries_left -= 1 retries_left -= 1
else: else:
raise raise

View file

@ -269,7 +269,7 @@ class Notifier(object):
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, timeout, callback, def wait_for_events(self, user, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
@ -279,11 +279,12 @@ class Notifier(object):
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user) appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(user) if room_ids is None:
rooms = [room.room_id for room in rooms] rooms = yield self.store.get_rooms_for_user(user)
room_ids = [room.room_id for room in rooms]
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user=user, user=user,
rooms=rooms, rooms=room_ids,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
@ -329,7 +330,8 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_for(self, user, pagination_config, timeout, def get_events_for(self, user, pagination_config, timeout,
only_room_events=False): only_room_events=False,
is_guest=False, guest_room_id=None):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
@ -342,6 +344,16 @@ class Notifier(object):
limit = pagination_config.limit limit = pagination_config.limit
room_ids = []
if is_guest:
# TODO(daniel): Deal with non-room events too
only_room_events = True
if guest_room_id:
room_ids = [guest_room_id]
else:
rooms = yield self.store.get_rooms_for_user(user.to_string())
room_ids = [room.room_id for room in rooms]
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
@ -357,9 +369,23 @@ class Notifier(object):
continue continue
if only_room_events and name != "room": if only_room_events and name != "room":
continue continue
new_events, new_key = yield source.get_new_events_for_user( new_events, new_key = yield source.get_new_events(
user, getattr(from_token, keyname), limit, user=user,
from_key=getattr(from_token, keyname),
limit=limit,
is_guest=is_guest,
room_ids=room_ids,
) )
if is_guest:
room_member_handler = self.hs.get_handlers().room_member_handler
new_events = yield room_member_handler._filter_events_for_client(
user.to_string(),
new_events,
is_guest=is_guest,
require_all_visible_for_guests=False
)
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
@ -369,7 +395,7 @@ class Notifier(object):
defer.returnValue(None) defer.returnValue(None)
result = yield self.wait_for_events( result = yield self.wait_for_events(
user, timeout, check_for_updates, from_token=from_token user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
) )
if result is None: if result is None:

View file

@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request) auth_user, _, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
room_id = None
if is_guest:
if "room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
room_id = request.args["room_id"][0]
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet):
chunk = yield handler.get_stream( chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout, auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event as_client_event=as_client_event, affect_presence=(not is_guest),
room_id=room_id, is_guest=is_guest
) )
except: except:
logger.exception("Event stream failed") logger.exception("Event stream failed")

View file

@ -26,7 +26,6 @@ from synapse.events.utils import serialize_event
import simplejson as json import simplejson as json
import logging import logging
import urllib import urllib
from synapse.util import third_party_invites
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -125,7 +124,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user, _, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -133,6 +132,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
event_type=event_type, event_type=event_type,
state_key=state_key, state_key=state_key,
is_guest=is_guest,
) )
if not data: if not data:
@ -348,12 +348,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
is_guest=is_guest,
) )
defer.returnValue((200, events)) defer.returnValue((200, events))
@ -451,7 +452,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
# target user is you unless it is an invite # target user is you unless it is an invite
state_key = user.to_string() state_key = user.to_string()
if membership_action == "invite" and third_party_invites.has_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.handlers.room_member_handler.do_3pid_invite(
room_id, room_id,
user, user,
@ -478,19 +479,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event_content = {
"membership": unicode(membership_action),
}
if membership_action == "join" and third_party_invites.has_join_keys(content):
event_content["third_party_invite"] = (
third_party_invites.extract_join_keys(content)
)
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": event_content, "content": {"membership": unicode(membership_action)},
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": state_key, "state_key": state_key,
@ -501,6 +493,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address", "display_name"}:
if key not in content:
return False
return True
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
try: try:
@ -602,7 +600,8 @@ class SearchRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
results = yield self.handlers.search_handler.search(auth_user, content) batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search(auth_user, content, batch)
defer.returnValue((200, results)) defer.returnValue((200, results))

View file

@ -0,0 +1,50 @@
<html>
<head>
<title> Login </title>
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="style.css">
<script src="js/jquery-2.1.3.min.js"></script>
<script src="js/login.js"></script>
</head>
<body onload="matrixLogin.onLoad()">
<center>
<br/>
<h1>Log in with one of the following methods</h1>
<span id="feedback" style="color: #f00"></span>
<br/>
<br/>
<div id="loading">
<img src="spinner.gif" />
</div>
<div id="cas_flow" class="login_flow" style="display:none"
onclick="gotoCas(); return false;">
CAS Authentication: <button id="cas_button" style="margin: 10px">Log in</button>
</div>
<br/>
<form id="password_form" class="login_flow" style="display:none"
onsubmit="matrixLogin.password_login(); return false;">
<div>
Password Authentication:<br/>
<div style="text-align: center">
<input id="user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="password" size="32" type="password" placeholder="Password"/>
<br/>
<button type="submit" style="margin: 10px">Log in</button>
</div>
</div>
</form>
<div id="no_login_types" type="button" class="login_flow" style="display:none">
Log in currently unavailable.
</div>
</center>
</body>
</html>

View file

@ -0,0 +1,167 @@
window.matrixLogin = {
endpoint: location.origin + "/_matrix/client/api/v1/login",
serverAcceptsPassword: false,
serverAcceptsCas: false
};
var submitPassword = function(user, pwd) {
console.log("Logging in with password...");
var data = {
type: "m.login.password",
user: user,
password: pwd,
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var submitCas = function(ticket, service) {
console.log("Logging in with cas...");
var data = {
type: "m.login.cas",
ticket: ticket,
service: service,
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var errorFunc = function(err) {
show_login();
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
}
else {
setFeedbackString("Request failed: " + err.status);
}
};
var getCasURL = function(cb) {
$.get(matrixLogin.endpoint + "/cas", function(response) {
var cas_url = response.serverUrl;
cb(cas_url);
}).error(errorFunc);
};
var gotoCas = function() {
getCasURL(function(cas_url) {
var this_page = window.location.origin + window.location.pathname;
var redirect_url = cas_url + "/login?service=" + encodeURIComponent(this_page);
window.location.replace(redirect_url);
});
}
var setFeedbackString = function(text) {
$("#feedback").text(text);
};
var show_login = function() {
$("#loading").hide();
if (matrixLogin.serverAcceptsPassword) {
$("#password_form").show();
}
if (matrixLogin.serverAcceptsCas) {
$("#cas_flow").show();
}
if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas) {
$("#no_login_types").show();
}
};
var show_spinner = function() {
$("#password_form").hide();
$("#cas_flow").hide();
$("#no_login_types").hide();
$("#loading").show();
};
var fetch_info = function(cb) {
$.get(matrixLogin.endpoint, function(response) {
var serverAcceptsPassword = false;
var serverAcceptsCas = false;
for (var i=0; i<response.flows.length; i++) {
var flow = response.flows[i];
if ("m.login.cas" === flow.type) {
matrixLogin.serverAcceptsCas = true;
console.log("Server accepts CAS");
}
if ("m.login.password" === flow.type) {
matrixLogin.serverAcceptsPassword = true;
console.log("Server accepts password");
}
}
cb();
}).error(errorFunc);
}
matrixLogin.onLoad = function() {
fetch_info(function() {
if (!try_cas()) {
show_login();
}
});
};
matrixLogin.password_login = function() {
var user = $("#user_id").val();
var pwd = $("#password").val();
setFeedbackString("");
show_spinner();
submitPassword(user, pwd);
};
matrixLogin.onLogin = function(response) {
// clobber this function
console.log("onLogin - This function should be replaced to proceed.");
console.log(response);
};
var parseQsFromUrl = function(query) {
var result = {};
query.split("&").forEach(function(part) {
var item = part.split("=");
var key = item[0];
var val = item[1];
if (val) {
val = decodeURIComponent(val);
}
result[key] = val
});
return result;
};
var try_cas = function() {
var pos = window.location.href.indexOf("?");
if (pos == -1) {
return false;
}
var qs = parseQsFromUrl(window.location.href.substr(pos+1));
var ticket = qs.ticket;
if (!ticket) {
return false;
}
submitCas(ticket, location.origin);
return true;
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

View file

@ -0,0 +1,57 @@
html {
height: 100%;
}
body {
height: 100%;
font-family: "Myriad Pro", "Myriad", Helvetica, Arial, sans-serif;
font-size: 12pt;
margin: 0px;
}
h1 {
font-size: 20pt;
}
a:link { color: #666; }
a:visited { color: #666; }
a:hover { color: #000; }
a:active { color: #000; }
input {
width: 90%
}
textarea, input {
font-family: inherit;
font-size: inherit;
margin: 5px;
}
.smallPrint {
color: #888;
font-size: 9pt ! important;
font-style: italic ! important;
}
.g-recaptcha div {
margin: auto;
}
.login_flow {
text-align: left;
padding: 10px;
margin-bottom: 40px;
display: inline-block;
-webkit-border-radius: 10px;
-moz-border-radius: 10px;
border-radius: 10px;
-webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
-moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
background-color: #f8f8f8;
border: 1px #ccc solid;
}

File diff suppressed because one or more lines are too long

View file

@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore):
self._store_room_message_txn(txn, event) self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
self._store_room_members_txn( self._store_room_members_txn(
txn, txn,

View file

@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore):
txn, event, "content.body", event.content["body"] txn, event, "content.body", event.content["body"]
) )
def _store_history_visibility_txn(self, txn, event):
if hasattr(event, "content") and "history_visibility" in event.content:
sql = (
"INSERT INTO history_visibility"
" (event_id, room_id, history_visibility)"
" VALUES (?, ?, ?)"
)
txn.execute(sql, (
event.event_id,
event.room_id,
event.content["history_visibility"]
))
def _store_event_search_txn(self, txn, event, key, value): def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (

View file

@ -0,0 +1,26 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of history_visibility content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS history_visibility(
id INTEGER PRIMARY KEY,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
history_visibility TEXT NOT NULL,
UNIQUE (event_id)
);

View file

@ -16,18 +16,13 @@
from twisted.internet import defer from twisted.internet import defer
from _base import SQLBaseStore from _base import SQLBaseStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from collections import namedtuple import logging
"""The result of a search.
Fields: logger = logging.getLogger(__name__)
rank_map (dict): Mapping event_id -> rank
event_map (dict): Mapping event_id -> event
pagination_token (str): Pagination token
"""
SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
class SearchStore(SQLBaseStore): class SearchStore(SQLBaseStore):
@ -42,7 +37,7 @@ class SearchStore(SQLBaseStore):
"content.body", "content.name", "content.topic" "content.body", "content.name", "content.topic"
Returns: Returns:
SearchResult list of dicts
""" """
clauses = [] clauses = []
args = [] args = []
@ -100,12 +95,103 @@ class SearchStore(SQLBaseStore):
for ev in events for ev in events
} }
defer.returnValue(SearchResult( defer.returnValue([
{ {
r["event_id"]: r["rank"] "event": event_map[r["event_id"]],
for r in results "rank": r["rank"],
if r["event_id"] in event_map }
}, for r in results
event_map, if r["event_id"] in event_map
None ])
))
@defer.inlineCallbacks
def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
"""Performs a full text search over events with given keys.
Args:
room_id (str): The room_id to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
pagination_token (str): A pagination token previously returned
Returns:
list of dicts
"""
clauses = []
args = [search_term, room_id]
local_clauses = []
for key in keys:
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
if pagination_token:
try:
topo, stream = pagination_token.split(",")
topo = int(topo)
stream = int(stream)
except:
raise SynapseError(400, "Invalid pagination token")
clauses.append(
"(topological_ordering < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))"
)
args.extend([topo, topo, stream])
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) as rank,"
" topological_ordering, stream_ordering, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events"
" WHERE vector @@ query AND room_id = ?"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
" topological_ordering, stream_ordering"
" FROM event_search"
" NATURAL JOIN events"
" WHERE value MATCH ? AND room_id = ?"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for clause in clauses:
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
args.append(limit)
results = yield self._execute(
"search_rooms", self.cursor_to_dict, sql, *args
)
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
defer.returnValue([
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
"pagination_token": "%s,%s" % (
r["topological_ordering"], r["stream_ordering"]
),
}
for r in results
if r["event_id"] in event_map
])

View file

@ -158,13 +158,40 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, limit=0): def get_room_events_stream(
current_room_membership_sql = ( self,
"SELECT m.room_id FROM room_memberships as m " user_id,
" INNER JOIN current_state_events as c" from_key,
" ON m.event_id = c.event_id AND c.state_key = m.user_id" to_key,
" WHERE m.user_id = ? AND m.membership = 'join'" limit=0,
) is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
if room_ids:
current_room_membership_sql += " AND m.room_id in (%s)" % (
",".join(map(lambda _: "?", room_ids))
)
current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g. # We also want to get any membership events about that user, e.g.
# invites or leave notifications. # invites or leave notifications.
@ -173,6 +200,7 @@ class StreamStore(SQLBaseStore):
"INNER JOIN current_state_events as c ON m.event_id = c.event_id " "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? " "WHERE m.user_id = ? "
) )
membership_args = [user_id]
if limit: if limit:
limit = max(limit, MAX_STREAM_SIZE) limit = max(limit, MAX_STREAM_SIZE)
@ -199,7 +227,9 @@ class StreamStore(SQLBaseStore):
} }
def f(txn): def f(txn):
txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,8 +86,9 @@ def get_retry_limiter(destination, clock, store, **kwargs):
class RetryDestinationLimiter(object): class RetryDestinationLimiter(object):
def __init__(self, destination, clock, store, retry_interval, def __init__(self, destination, clock, store, retry_interval,
min_retry_interval=5000, max_retry_interval=60 * 60 * 1000, min_retry_interval=10 * 60 * 1000,
multiplier_retry_interval=2,): max_retry_interval=24 * 60 * 60 * 1000,
multiplier_retry_interval=5,):
"""Marks the destination as "down" if an exception is thrown in the """Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500. context, except for CodeMessageException with code < 500.
@ -140,6 +142,7 @@ class RetryDestinationLimiter(object):
# We couldn't connect. # We couldn't connect.
if self.retry_interval: if self.retry_interval:
self.retry_interval *= self.multiplier_retry_interval self.retry_interval *= self.multiplier_retry_interval
self.retry_interval *= int(random.uniform(0.8, 1.4))
if self.retry_interval >= self.max_retry_interval: if self.retry_interval >= self.max_retry_interval:
self.retry_interval = self.max_retry_interval self.retry_interval = self.max_retry_interval

View file

@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError
INVITE_KEYS = {"id_server", "medium", "address", "display_name"}
JOIN_KEYS = {
"token",
"public_key",
"key_validity_url",
"sender",
"signed",
}
def has_invite_keys(content):
for key in INVITE_KEYS:
if key not in content:
return False
return True
def has_join_keys(content):
for key in JOIN_KEYS:
if key not in content:
return False
return True
def join_has_third_party_invite(content):
if "third_party_invite" not in content:
return False
return has_join_keys(content["third_party_invite"])
def extract_join_keys(src):
return {
key: value
for key, value in src.items()
if key in JOIN_KEYS
}
@defer.inlineCallbacks
def check_key_valid(http_client, event):
try:
response = yield http_client.get_json(
event.content["third_party_invite"]["key_validity_url"],
{"public_key": event.content["third_party_invite"]["public_key"]}
)
except Exception:
raise AuthError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")

View file

@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
{"presence": ONLINE} {"presence": ONLINE}
) )
# Apple sees self-reflection even without room_id
(events, _) = yield self.event_source.get_new_events(
user=self.u_apple,
from_key=0,
)
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(events,
[
{"type": "m.presence",
"content": {
"user_id": "@apple:test",
"presence": ONLINE,
"last_active_ago": 0,
}},
],
msg="Presence event should be visible to self-reflection"
)
# Apple sees self-reflection # Apple sees self-reflection
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Banana sees it because of presence subscription # Banana sees it because of presence subscription
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_banana, 0, None user=self.u_banana,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Elderberry sees it because of same room # Elderberry sees it because of same room
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_elderberry, 0, None user=self.u_elderberry,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Durian is not in the room, should not see this event # Durian is not in the room, should not see this event
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_durian, 0, None user=self.u_durian,
from_key=0,
room_ids=[],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
"accepted": True}, "accepted": True},
], presence) ], presence)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 1, None user=self.u_apple,
from_key=1,
) )
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
) )
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.room_members.append(self.u_clementine) self.room_members.append(self.u_clementine)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)

View file

@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
yield put_json.await_calls() yield put_json.await_calls()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=1,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [

View file

@ -47,7 +47,14 @@ class NullSource(object):
def __init__(self, hs): def __init__(self, hs):
pass pass
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
room_ids=None,
limit=None,
is_guest=None
):
return defer.succeed(([], from_key)) return defer.succeed(([], from_key))
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):

View file

@ -116,7 +116,10 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.user, 0, None) events = yield self.event_source.get_new_events(
from_key=0,
room_ids=[self.room_id],
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [