Merge branch 'markjh/eventstream_presence' into markjh/v2_sync_api

This commit is contained in:
Mark Haines 2015-10-09 19:18:09 +01:00
commit af7b214476
11 changed files with 274 additions and 125 deletions

View file

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID, EventID from synapse.types import RoomID, UserID, EventID
import logging import logging
import pymacaroons import pymacaroons
@ -80,6 +80,15 @@ class Auth(object):
"Room %r does not exist" % (event.room_id,) "Room %r does not exist" % (event.room_id,)
) )
creating_domain = RoomID.from_string(event.room_id).domain
originating_domain = UserID.from_string(event.sender).domain
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack # FIXME: Temp hack
if event.type == EventTypes.Aliases: if event.type == EventTypes.Aliases:
return True return True
@ -219,6 +228,11 @@ class Auth(object):
user_id, room_id, repr(member) user_id, room_id, repr(member)
)) ))
def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
@log_function @log_function
def is_membership_change_allowed(self, event, auth_events): def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"] membership = event.content["membership"]
@ -234,6 +248,15 @@ class Auth(object):
target_user_id = event.state_key target_user_id = event.state_key
creating_domain = RoomID.from_string(event.room_id).domain
target_domain = UserID.from_string(target_user_id).domain
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller # get info about the caller
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key) caller = auth_events.get(key)

View file

@ -83,3 +83,4 @@ class RejectedReason(object):
class RoomCreationPreset(object): class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat" PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat" PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat"

View file

@ -32,9 +32,9 @@ def start(configfile):
print "Starting ...", print "Starting ...",
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", configfile]) args.extend(["--daemonize", "-c", configfile])
cwd = os.path.dirname(os.path.abspath(__file__))
try: try:
subprocess.check_call(args, cwd=cwd) subprocess.check_call(args)
print GREEN + "started" + NORMAL print GREEN + "started" + NORMAL
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print ( print (

View file

@ -103,7 +103,10 @@ def format_event_raw(d):
def format_event_for_client_v1(d): def format_event_for_client_v1(d):
d["user_id"] = d.pop("sender", None) d["user_id"] = d.pop("sender", None)
move_keys = ("age", "redacted_because", "replaces_state", "prev_content") move_keys = (
"age", "redacted_because", "replaces_state", "prev_content",
"invite_room_state",
)
for key in move_keys: for key in move_keys:
if key in d["unsigned"]: if key in d["unsigned"]:
d[key] = d["unsigned"][key] d[key] = d["unsigned"][key]

View file

@ -123,24 +123,39 @@ class BaseHandler(object):
) )
) )
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE: if event.content["membership"] == Membership.INVITE:
event.unsigned["invite_room_state"] = [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for k, e in context.current_state.items()
if e.type in (
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
)
]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee): if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer # TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need # way? If we have been invited by a remote server, we need
# to get them to sign the event. # to get them to sign the event.
returned_invite = yield federation_handler.send_invite( returned_invite = yield federation_handler.send_invite(
invitee.domain, invitee.domain,
event, event,
) )
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update( event.signatures.update(
returned_invite.signatures returned_invite.signatures
@ -161,6 +176,10 @@ class BaseHandler(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
destinations = set(extra_destinations) destinations = set(extra_destinations)
for k, s in context.current_state.items(): for k, s in context.current_state.items():
try: try:
@ -189,6 +208,9 @@ class BaseHandler(object):
notify_d.addErrback(log_failure) notify_d.addErrback(log_failure)
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event( federation_handler.handle_new_event(
event, destinations=destinations, event, destinations=destinations,
) )

View file

@ -46,6 +46,56 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def started_stream(self, user):
"""Tells the presence handler that we have started an eventstream for
the user:
Args:
user (User): The user who started a stream.
Returns:
A deferred that completes once their presence has been updated.
"""
if user not in self._streams_per_user:
self._streams_per_user[user] = 0
if user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire("started_user_eventstream", user)
self._streams_per_user[user] += 1
def stopped_stream(self, user):
"""If there are no streams for a user this starts a timer that will
notify the presence handler that we haven't got an event stream for
the user unless the user starts a new stream in 30 seconds.
Args:
user (User): The user who stopped a stream.
"""
self._streams_per_user[user] -= 1
if not self._streams_per_user[user]:
del self._streams_per_user[user]
# 30 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug("_later stopped_user_eventstream %s", user)
self._stop_timer_per_user.pop(user, None)
return self.distributor.fire("stopped_user_eventstream", user)
logger.debug("Scheduling _later: for %s", user)
self._stop_timer_per_user[user] = (
self.clock.call_later(30, _later)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
@ -59,20 +109,7 @@ class EventStreamHandler(BaseHandler):
try: try:
if affect_presence: if affect_presence:
if auth_user not in self._streams_per_user: yield self.started_stream(auth_user)
self._streams_per_user[auth_user] = 0
if auth_user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire(
"started_user_eventstream", auth_user
)
self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
@ -114,27 +151,7 @@ class EventStreamHandler(BaseHandler):
finally: finally:
if affect_presence: if affect_presence:
self._streams_per_user[auth_user] -= 1 self.stopped_stream(auth_user)
if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user]
# 10 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug(
"_later stopped_user_eventstream %s", auth_user
)
self._stop_timer_per_user.pop(auth_user, None)
return self.distributor.fire(
"stopped_user_eventstream", auth_user
)
logger.debug("Scheduling _later: for %s", auth_user)
self._stop_timer_per_user[auth_user] = (
self.clock.call_later(30, _later)
)
class EventHandler(BaseHandler): class EventHandler(BaseHandler):

View file

@ -125,8 +125,20 @@ class FederationHandler(BaseHandler):
) )
if not is_in_room and not event.internal_metadata.is_outlier(): if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
current_state = state
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
auth_chain, state, event
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
else:
event_ids = set() event_ids = set()
if state: if state:
event_ids |= {e.event_id for e in state} event_ids |= {e.event_id for e in state}
@ -150,7 +162,7 @@ class FederationHandler(BaseHandler):
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
auth = { auth = {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids or e.type == EventTypes.Create
} }
event_infos.append({ event_infos.append({
"event": e, "event": e,
@ -649,35 +661,8 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
ev_infos = [] event_stream_id, max_stream_id = yield self._persist_auth_tree(
for e in itertools.chain(state, auth_chain): auth_chain, state, event
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
ev_infos.append({
"event": e,
"auth_events": {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
})
yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
new_event,
state=state,
current_state=state,
auth_events=auth_events,
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -1026,6 +1011,76 @@ class FederationHandler(BaseHandler):
is_new_state=(not outliers and not backfilled), is_new_state=(not outliers and not backfilled),
) )
@defer.inlineCallbacks
def _persist_auth_tree(self, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
"""
events_to_context = {}
for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True
event_map = {
e.event_id: e
for e in auth_events
}
create_event = None
for e in auth_events:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
try:
self.auth.check(e, auth_events=auth_for_e)
except AuthError as err:
logger.warn(
"Rejecting %s because %s",
e.event_id, err.msg
)
if e == event:
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
yield self.store.persist_events(
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
],
is_new_state=False,
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False,
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
backfilled=False,
is_new_state=True,
current_state=state,
)
defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False, def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None): current_state=None, auth_events=None):
@ -1166,7 +1221,7 @@ class FederationHandler(BaseHandler):
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
auth = { auth = {
(e.type, e.state_key): e for e in remote_auth_chain (e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids or e.type == EventTypes.Create
} }
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
@ -1284,6 +1339,7 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e (e.type, e.state_key): e
for e in result["auth_chain"] for e in result["auth_chain"]
if e.event_id in auth_ids if e.event_id in auth_ids
or event.type == EventTypes.Create
} }
ev.internal_metadata.outlier = True ev.internal_metadata.outlier = True

View file

@ -383,8 +383,12 @@ class MessageHandler(BaseHandler):
} }
if event.membership == Membership.INVITE: if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d) rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE): if event.membership not in (Membership.JOIN, Membership.LEAVE):

View file

@ -41,6 +41,11 @@ class RoomCreationHandler(BaseHandler):
"history_visibility": "shared", "history_visibility": "shared",
"original_invitees_have_ops": False, "original_invitees_have_ops": False,
}, },
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "shared",
"original_invitees_have_ops": True,
},
RoomCreationPreset.PUBLIC_CHAT: { RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC, "join_rules": JoinRules.PUBLIC,
"history_visibility": "shared", "history_visibility": "shared",
@ -149,12 +154,16 @@ class RoomCreationHandler(BaseHandler):
for val in raw_initial_state: for val in raw_initial_state:
initial_state[(val["type"], val.get("state_key", ""))] = val["content"] initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
creation_content = config.get("creation_content", {})
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room( creation_events = self._create_events_for_new_room(
user, room_id, user, room_id,
preset_config=preset_config, preset_config=preset_config,
invite_list=invite_list, invite_list=invite_list,
initial_state=initial_state, initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
@ -202,7 +211,8 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result) defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, preset_config, def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state): invite_list, initial_state, creation_content,
room_alias):
config = RoomCreationHandler.PRESETS_DICT[preset_config] config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.to_string() creator_id = creator.to_string()
@ -224,9 +234,10 @@ class RoomCreationHandler(BaseHandler):
return e return e
creation_content.update({"creator": creator.to_string()})
creation_event = create( creation_event = create(
etype=EventTypes.Create, etype=EventTypes.Create,
content={"creator": creator.to_string()}, content=creation_content,
) )
join_event = create( join_event = create(
@ -271,6 +282,14 @@ class RoomCreationHandler(BaseHandler):
returned_events.append(power_levels_event) returned_events.append(power_levels_event)
if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
room_alias_event = create(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
returned_events.append(room_alias_event)
if (EventTypes.JoinRules, '') not in initial_state: if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create( join_rules_event = create(
etype=EventTypes.JoinRules, etype=EventTypes.JoinRules,

View file

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern from ._base import client_v2_pattern
@ -41,6 +42,9 @@ class ReceiptRestServlet(RestServlet):
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
user, _ = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.receipts_handler.received_client_receipt( yield self.receipts_handler.received_client_receipt(
room_id, room_id,
receipt_type, receipt_type,

View file

@ -35,7 +35,7 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
if not event_id: if not event_id:
_next_event_id += 1 _next_event_id += 1
event_id = str(_next_event_id) event_id = "$%s:test" % (_next_event_id,)
if not name: if not name:
if state_key is not None: if state_key is not None: