0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-17 04:33:50 +01:00

Merge branch 'develop' into rav/guest_access_after_room_join

This commit is contained in:
Richard van der Hoff 2016-02-19 12:00:16 +00:00
commit 05aee12652
46 changed files with 2059 additions and 3732 deletions

View file

@ -32,7 +32,6 @@ class PresenceState(object):
OFFLINE = u"offline" OFFLINE = u"offline"
UNAVAILABLE = u"unavailable" UNAVAILABLE = u"unavailable"
ONLINE = u"online" ONLINE = u"online"
FREE_FOR_CHAT = u"free_for_chat"
class JoinRules(object): class JoinRules(object):

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias, Requester
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
@ -177,7 +177,7 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder):
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id, builder.room_id,
) )
@ -186,7 +186,10 @@ class BaseHandler(object):
else: else:
depth = 1 depth = 1
prev_events = [(e, h) for e, h, _ in latest_ret] prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
@ -195,6 +198,31 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder) context = yield state_handler.compute_event_context(builder)
# If we've received an invite over federation, there are no latest
# events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite. So we
# forcibly set our context to the invite we received over federation.
if (
not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member
):
prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id
)
if prev_member_event:
builder.prev_events = (
prev_member_event.event_id,
prev_member_event.prev_events
)
context = yield state_handler.compute_event_context(
builder,
old_state=(prev_member_event,),
outlier=True
)
if builder.is_state(): if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes( builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events context.prev_state_events
@ -217,10 +245,33 @@ class BaseHandler(object):
(event, context,) (event, context,)
) )
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
for ((event_type, state_key), event) in current_state.items()
if event_type == EventTypes.Member
]
if len(room_members) == 0:
# Have we just created the room, and is this about to be the very
# first member event?
create_event = current_state.get(("m.room.create", ""))
if create_event:
return True
for (state_key, membership) in room_members:
if (
UserID.from_string(state_key).domain == self.hs.hostname
and membership == Membership.JOIN
):
return True
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_users=[]): def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(event.sender)
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context.current_state.values())
@ -346,7 +397,8 @@ class BaseHandler(object):
if member_event.type != EventTypes.Member: if member_event.type != EventTypes.Member:
continue continue
if not self.hs.is_mine(UserID.from_string(member_event.state_key)): target_user = UserID.from_string(member_event.state_key)
if not self.hs.is_mine(target_user):
continue continue
if member_event.content["membership"] not in { if member_event.content["membership"] not in {
@ -368,18 +420,13 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more # and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having # of that decision-making and control local to the guest-having
# homeserver. # homeserver.
message_handler = self.hs.get_handlers().message_handler requester = Requester(target_user, "", True)
yield message_handler.create_and_send_event( handler = self.hs.get_handlers().room_member_handler
{ yield handler.update_membership(
"type": EventTypes.Member, requester,
"state_key": member_event.state_key, target_user,
"content": { member_event.room_id,
"membership": Membership.LEAVE, "leave",
"kind": "guest"
},
"room_id": member_event.room_id,
"sender": member_event.state_key
},
ratelimit=False, ratelimit=False,
) )
except Exception as e: except Exception as e:

View file

@ -216,7 +216,7 @@ class DirectoryHandler(BaseHandler):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event({ yield msg_handler.create_and_send_nonmember_event({
"type": EventTypes.Aliases, "type": EventTypes.Aliases,
"state_key": self.hs.hostname, "state_key": self.hs.hostname,
"room_id": room_id, "room_id": room_id,

View file

@ -19,6 +19,8 @@ from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.api.constants import Membership, EventTypes
from synapse.events import EventBase
from ._base import BaseHandler from ._base import BaseHandler
@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down. If `only_keys` is not None, events from keys will be sent down.
""" """
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_handlers().presence_handler
try: context = yield presence_handler.user_syncing(
if affect_presence: auth_user_id, affect_presence=affect_presence,
yield self.started_stream(auth_user) )
with context:
if timeout: if timeout:
# If they've set a timeout set a minimum limit. # If they've set a timeout set a minimum limit.
timeout = max(timeout, 500) timeout = max(timeout, 500)
@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler):
is_guest=is_guest, explicit_room_id=room_id is_guest=is_guest, explicit_room_id=room_id
) )
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
for event in events:
if not isinstance(event, EventBase):
continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.store.get_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,
)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
UserID.from_string(event.state_key),
as_event=True,
)
to_add.append(ev)
events.extend(to_add)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunks = [ chunks = [
@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
finally:
if affect_presence:
self.stopped_stream(auth_user)
class EventHandler(BaseHandler): class EventHandler(BaseHandler):

View file

@ -1658,7 +1658,7 @@ class FederationHandler(BaseHandler):
self.auth.check(event, context.current_state) self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state) yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context) yield member_handler.send_membership_event(event, context, from_client=False)
else: else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite( yield self.replication_layer.forward_third_party_invite(
@ -1687,7 +1687,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context) yield member_handler.send_membership_event(event, context, from_client=False)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context): def add_display_name_to_third_party_invite(self, event_dict, event, context):

View file

@ -16,12 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken from synapse.types import UserID, RoomStreamToken, StreamToken
@ -216,7 +215,7 @@ class MessageHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def send_event(self, event, context, ratelimit=True, is_guest=False): def send_nonmember_event(self, event, context, ratelimit=True):
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -226,55 +225,68 @@ class MessageHandler(BaseHandler):
ratelimit (bool): Whether to rate limit this send. ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest. is_guest (bool): Whether the sender is a guest.
""" """
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender) user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if ratelimit:
self.ratelimit(event.sender)
if event.is_state(): if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = self.deduplicate_state_event(event, context)
if prev_state and event.user_id == prev_state.user_id: if prev_state is not None:
prev_content = encode_canonical_json(prev_state.content) defer.returnValue(prev_state)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
# Duplicate suppression for state updates with same sender
# and content.
defer.returnValue(prev_state)
if event.type == EventTypes.Member: yield self.handle_new_client_event(
member_handler = self.hs.get_handlers().room_member_handler event=event,
yield member_handler.send_membership_event(event, context, is_guest=is_guest) context=context,
else: ratelimit=ratelimit,
yield self.handle_new_client_event( )
event=event,
context=context,
)
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
with PreserveLoggingContext(): yield presence.bump_presence_active_time(user)
presence.bump_presence_active_time(user)
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event = context.current_state.get((event.type, event.state_key))
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
return prev_event
return None
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_nonmember_event(
token_id=None, txn_id=None, is_guest=False): self,
event_dict,
ratelimit=True,
token_id=None,
txn_id=None
):
""" """
Creates an event, then sends it. Creates an event, then sends it.
See self.create_event and self.send_event. See self.create_event and self.send_nonmember_event.
""" """
event, context = yield self.create_event( event, context = yield self.create_event(
event_dict, event_dict,
token_id=token_id, token_id=token_id,
txn_id=txn_id txn_id=txn_id
) )
yield self.send_event( yield self.send_nonmember_event(
event, event,
context, context,
ratelimit=ratelimit, ratelimit=ratelimit,
is_guest=is_guest
) )
defer.returnValue(event) defer.returnValue(event)
@ -660,10 +672,6 @@ class MessageHandler(BaseHandler):
room_id=room_id, room_id=room_id,
) )
# TODO(paul): I wish I was called with user objects not user_id
# strings...
auth_user = UserID.from_string(user_id)
# TODO: These concurrently # TODO: These concurrently
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
state = [ state = [
@ -688,13 +696,11 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members], [m.user_id for m in room_members],
auth_user=auth_user,
as_event=True, as_event=True,
check_auth=False,
) )
defer.returnValue(states.values()) defer.returnValue(states)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_receipts(): def get_receipts():

File diff suppressed because it is too large Load diff

View file

@ -16,8 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, Requester
from synapse.types import UserID
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from ._base import BaseHandler from ._base import BaseHandler
@ -49,6 +48,9 @@ class ProfileHandler(BaseHandler):
distributor = hs.get_distributor() distributor = hs.get_distributor()
self.distributor = distributor self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user) distributor.observe("registered_user", self.registered_user)
distributor.observe( distributor.observe(
@ -208,21 +210,18 @@ class ProfileHandler(BaseHandler):
) )
for j in joins: for j in joins:
content = { handler = self.hs.get_handlers().room_member_handler
"membership": Membership.JOIN,
}
yield collect_presencelike_data(self.distributor, user, content)
msg_handler = self.hs.get_handlers().message_handler
try: try:
yield msg_handler.create_and_send_event({ # Assume the user isn't a guest because we don't let guests set
"type": EventTypes.Member, # profile or avatar data.
"room_id": j.room_id, requester = Requester(user, "", False)
"state_key": user.to_string(), yield handler.update_membership(
"content": content, requester,
"sender": user.to_string() user,
}, ratelimit=False) j.room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to update join event for room %s - %s", "Failed to update join event for room %s - %s",

View file

@ -24,7 +24,6 @@ from synapse.api.constants import (
) )
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
@ -42,10 +41,6 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://" id_server_scheme = "https://"
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
return preserve_context_over_fn( return preserve_context_over_fn(
distributor.fire, distributor.fire,
@ -173,9 +168,14 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {}) creation_content = config.get("creation_content", {})
user = UserID.from_string(user_id) msg_handler = self.hs.get_handlers().message_handler
creation_events = self._create_events_for_new_room( room_member_handler = self.hs.get_handlers().room_member_handler
user, room_id,
yield self._send_events_for_new_room(
requester,
room_id,
msg_handler,
room_member_handler,
preset_config=preset_config, preset_config=preset_config,
invite_list=invite_list, invite_list=invite_list,
initial_state=initial_state, initial_state=initial_state,
@ -183,14 +183,9 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias, room_alias=room_alias,
) )
msg_handler = self.hs.get_handlers().message_handler
for event in creation_events:
yield msg_handler.create_and_send_event(event, ratelimit=False)
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
yield msg_handler.create_and_send_event({ yield msg_handler.create_and_send_nonmember_event({
"type": EventTypes.Name, "type": EventTypes.Name,
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
@ -200,7 +195,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
yield msg_handler.create_and_send_event({ yield msg_handler.create_and_send_nonmember_event({
"type": EventTypes.Topic, "type": EventTypes.Topic,
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
@ -209,13 +204,13 @@ class RoomCreationHandler(BaseHandler):
}, ratelimit=False) }, ratelimit=False)
for invitee in invite_list: for invitee in invite_list:
yield msg_handler.create_and_send_event({ room_member_handler.update_membership(
"type": EventTypes.Member, requester,
"state_key": invitee, UserID.from_string(invitee),
"room_id": room_id, room_id,
"sender": user_id, "invite",
"content": {"membership": Membership.INVITE}, ratelimit=False,
}, ratelimit=False) )
for invite_3pid in invite_3pid_list: for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"] id_server = invite_3pid["id_server"]
@ -223,11 +218,11 @@ class RoomCreationHandler(BaseHandler):
medium = invite_3pid["medium"] medium = invite_3pid["medium"]
yield self.hs.get_handlers().room_member_handler.do_3pid_invite( yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
room_id, room_id,
user, requester.user,
medium, medium,
address, address,
id_server, id_server,
token_id=None, requester,
txn_id=None, txn_id=None,
) )
@ -241,19 +236,19 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result) defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, preset_config, @defer.inlineCallbacks
invite_list, initial_state, creation_content, def _send_events_for_new_room(
room_alias): self,
config = RoomCreationHandler.PRESETS_DICT[preset_config] creator, # A Requester object.
room_id,
creator_id = creator.to_string() msg_handler,
room_member_handler,
event_keys = { preset_config,
"room_id": room_id, invite_list,
"sender": creator_id, initial_state,
"state_key": "", creation_content,
} room_alias
):
def create(etype, content, **kwargs): def create(etype, content, **kwargs):
e = { e = {
"type": etype, "type": etype,
@ -265,26 +260,39 @@ class RoomCreationHandler(BaseHandler):
return e return e
creation_content.update({"creator": creator.to_string()}) @defer.inlineCallbacks
creation_event = create( def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False)
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.user.to_string()
event_keys = {
"room_id": room_id,
"sender": creator_id,
"state_key": "",
}
creation_content.update({"creator": creator_id})
yield send(
etype=EventTypes.Create, etype=EventTypes.Create,
content=creation_content, content=creation_content,
) )
join_event = create( yield room_member_handler.update_membership(
etype=EventTypes.Member, creator,
state_key=creator_id, creator.user,
content={ room_id,
"membership": Membership.JOIN, "join",
}, ratelimit=False,
) )
returned_events = [creation_event, join_event]
if (EventTypes.PowerLevels, '') not in initial_state: if (EventTypes.PowerLevels, '') not in initial_state:
power_level_content = { power_level_content = {
"users": { "users": {
creator.to_string(): 100, creator_id: 100,
}, },
"users_default": 0, "users_default": 0,
"events": { "events": {
@ -306,45 +314,35 @@ class RoomCreationHandler(BaseHandler):
for invitee in invite_list: for invitee in invite_list:
power_level_content["users"][invitee] = 100 power_level_content["users"][invitee] = 100
power_levels_event = create( yield send(
etype=EventTypes.PowerLevels, etype=EventTypes.PowerLevels,
content=power_level_content, content=power_level_content,
) )
returned_events.append(power_levels_event)
if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
room_alias_event = create( yield send(
etype=EventTypes.CanonicalAlias, etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()}, content={"alias": room_alias.to_string()},
) )
returned_events.append(room_alias_event)
if (EventTypes.JoinRules, '') not in initial_state: if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create( yield send(
etype=EventTypes.JoinRules, etype=EventTypes.JoinRules,
content={"join_rule": config["join_rules"]}, content={"join_rule": config["join_rules"]},
) )
returned_events.append(join_rules_event)
if (EventTypes.RoomHistoryVisibility, '') not in initial_state: if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
history_event = create( yield send(
etype=EventTypes.RoomHistoryVisibility, etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]} content={"history_visibility": config["history_visibility"]}
) )
returned_events.append(history_event)
for (etype, state_key), content in initial_state.items(): for (etype, state_key), content in initial_state.items():
returned_events.append(create( yield send(
etype=etype, etype=etype,
state_key=state_key, state_key=state_key,
content=content, content=content,
)) )
return returned_events
class RoomMemberHandler(BaseHandler): class RoomMemberHandler(BaseHandler):
@ -392,7 +390,16 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain) remotedomains.add(member.domain)
@defer.inlineCallbacks @defer.inlineCallbacks
def update_membership(self, requester, target, room_id, action, txn_id=None): def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
ratelimit=True,
):
effective_membership_state = action effective_membership_state = action
if action in ["kick", "unban"]: if action in ["kick", "unban"]:
effective_membership_state = "leave" effective_membership_state = "leave"
@ -401,7 +408,7 @@ class RoomMemberHandler(BaseHandler):
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
content = {"membership": unicode(effective_membership_state)} content = {"membership": effective_membership_state}
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
@ -412,6 +419,9 @@ class RoomMemberHandler(BaseHandler):
"room_id": room_id, "room_id": room_id,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
"state_key": target.to_string(), "state_key": target.to_string(),
# For backwards compatibility:
"membership": effective_membership_state,
}, },
token_id=requester.access_token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
@ -432,90 +442,165 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE errcode=Codes.BAD_STATE
) )
yield msg_handler.send_event( member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(
event, event,
context, context,
ratelimit=True, is_guest=requester.is_guest,
is_guest=requester.is_guest ratelimit=ratelimit,
remote_room_hosts=remote_room_hosts,
from_client=True,
) )
if action == "forget": if action == "forget":
yield self.forget(requester.user, room_id) yield self.forget(requester.user, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_membership_event(self, event, context, is_guest=False): def send_membership_event(
""" Change the membership status of a user in a room. self,
event,
context,
is_guest=False,
remote_room_hosts=None,
ratelimit=True,
from_client=True,
):
"""
Change the membership status of a user in a room.
Args: Args:
event (SynapseEvent): The membership event event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
from_client (bool): Whether this request is the result of a local
client request (rather than over federation). If so, we will
perform extra checks, like that this homeserver can act as this
client.
Raises: Raises:
SynapseError if there was a problem changing the membership. SynapseError if there was a problem changing the membership.
""" """
target_user_id = event.state_key target_user = UserID.from_string(event.state_key)
room_id = event.room_id
prev_state = context.current_state.get( if from_client:
(EventTypes.Member, target_user_id), sender = UserID.from_string(event.sender)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
action = "send"
if event.membership == Membership.JOIN:
if is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
do_remote_join_dance, remote_room_hosts = self._should_do_dance(
context,
(self.get_inviter(event.state_key, context.current_state)),
remote_room_hosts,
)
if do_remote_join_dance:
action = "remote_join"
elif event.membership == Membership.LEAVE:
is_host_in_room = self.is_host_in_room(context.current_state)
if not is_host_in_room:
action = "remote_reject"
federation_handler = self.hs.get_handlers().federation_handler
if action == "remote_join":
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield federation_handler.do_invite_join(
remote_room_hosts,
event.room_id,
event.user_id,
event.content,
)
elif action == "remote_reject":
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "No known servers")
yield federation_handler.do_remotely_reject_invite(
[inviter.domain],
room_id,
event.user_id
)
else:
yield self.handle_new_client_event(
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None None
) )
room_id = event.room_id
# If we're trying to join a room then we have to do this differently
# if this HS is not currently in the room, i.e. we have to do the
# invite/join dance.
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if is_guest: if not prev_member_event or prev_member_event.membership != Membership.JOIN:
guest_access = context.current_state.get( # Only fire user_joined_room if the user has acutally joined the
(EventTypes.GuestAccess, ""), # room. Don't bother if the user is just changing their profile
None # info.
) yield user_joined_room(self.distributor, target_user, room_id)
is_guest_access_allowed = ( elif event.membership == Membership.LEAVE:
guest_access if prev_member_event and prev_member_event.membership == Membership.JOIN:
and guest_access.content user_left_room(self.distributor, target_user, room_id)
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed")
yield self._do_join(event, context) def _can_guest_join(self, current_state):
else: """
if event.membership == Membership.LEAVE: Returns whether a guest can join a room based on its current state.
is_host_in_room = yield self.is_host_in_room(room_id, context) """
if not is_host_in_room: guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
# Rejecting an invite, rather than leaving a joined room return (
handler = self.hs.get_handlers().federation_handler guest_access
inviter = yield self.get_inviter(event) and guest_access.content
if not inviter: and "guest_access" in guest_access.content
# return the same error as join_room_alias does and guest_access.content["guest_access"] == "can_join"
raise SynapseError(404, "No known servers") )
yield handler.do_remotely_reject_invite(
[inviter.domain],
room_id,
event.user_id
)
defer.returnValue({"room_id": room_id})
return
# FIXME: This isn't idempotency. def _should_do_dance(self, context, inviter, room_hosts=None):
if prev_state and prev_state.membership == event.membership: # TODO: Shouldn't this be remote_room_host?
# double same action, treat this event as a NOOP. room_hosts = room_hosts or []
defer.returnValue({})
return
yield self._do_local_membership_update( is_host_in_room = self.is_host_in_room(context.current_state)
event, if is_host_in_room:
context=context, return False, room_hosts
)
if prev_state and prev_state.membership == Membership.JOIN: if inviter and not self.hs.is_mine(inviter):
user = UserID.from_string(event.user_id) room_hosts.append(inviter.domain)
user_left_room(self.distributor, user, event.room_id)
defer.returnValue({"room_id": room_id}) return True, room_hosts
@defer.inlineCallbacks @defer.inlineCallbacks
def join_room_alias(self, joinee, room_alias, content={}): def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
@ -523,111 +608,15 @@ class RoomMemberHandler(BaseHandler):
raise SynapseError(404, "No such room alias") raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"] room_id = mapping["room_id"]
hosts = mapping["servers"] servers = mapping["servers"]
if not hosts:
raise SynapseError(404, "No known servers")
# If event doesn't include a display name, add one. defer.returnValue((RoomID.from_string(room_id), servers))
yield collect_presencelike_data(self.distributor, joinee, content)
content.update({"membership": Membership.JOIN})
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"state_key": joinee.to_string(),
"room_id": room_id,
"sender": joinee.to_string(),
"membership": Membership.JOIN,
"content": content,
})
event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_hosts=hosts)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None):
room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
is_host_in_room = yield self.is_host_in_room(room_id, context)
if is_host_in_room:
should_do_dance = False
elif room_hosts: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True
else:
inviter = yield self.get_inviter(event)
if not inviter:
# return the same error as join_room_alias does
raise SynapseError(404, "No known servers")
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
if should_do_dance:
handler = self.hs.get_handlers().federation_handler
yield handler.do_invite_join(
room_hosts,
room_id,
event.user_id,
event.content,
)
else:
logger.debug("Doing normal join")
yield self._do_local_membership_update(
event,
context=context,
)
prev_state = context.current_state.get((event.type, event.state_key))
if not prev_state or prev_state.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
user = UserID.from_string(event.user_id)
yield user_joined_room(self.distributor, user, room_id)
@defer.inlineCallbacks
def get_inviter(self, event):
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
event.user_id, event.room_id
)
def get_inviter(self, user_id, current_state):
prev_state = current_state.get((EventTypes.Member, user_id))
if prev_state and prev_state.membership == Membership.INVITE: if prev_state and prev_state.membership == Membership.INVITE:
defer.returnValue(UserID.from_string(prev_state.user_id)) return UserID.from_string(prev_state.user_id)
return return None
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
defer.returnValue(inviter)
defer.returnValue(None)
@defer.inlineCallbacks
def is_host_in_room(self, room_id, context):
is_host_in_room = yield self.auth.check_host_in_room(
room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
defer.returnValue(is_host_in_room)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
@ -644,18 +633,6 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(room_ids) defer.returnValue(room_ids)
@defer.inlineCallbacks
def _do_local_membership_update(self, event, context):
yield run_on_reactor()
target_user = UserID.from_string(event.state_key)
yield self.handle_new_client_event(
event,
context,
extra_users=[target_user],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_3pid_invite( def do_3pid_invite(
self, self,
@ -664,7 +641,7 @@ class RoomMemberHandler(BaseHandler):
medium, medium,
address, address,
id_server, id_server,
token_id, requester,
txn_id txn_id
): ):
invitee = yield self._lookup_3pid( invitee = yield self._lookup_3pid(
@ -672,19 +649,12 @@ class RoomMemberHandler(BaseHandler):
) )
if invitee: if invitee:
# make sure it looks like a user ID; it'll throw if it's invalid. handler = self.hs.get_handlers().room_member_handler
UserID.from_string(invitee) yield handler.update_membership(
yield self.hs.get_handlers().message_handler.create_and_send_event( requester,
{ UserID.from_string(invitee),
"type": EventTypes.Member, room_id,
"content": { "invite",
"membership": unicode("invite")
},
"room_id": room_id,
"sender": inviter.to_string(),
"state_key": invitee,
},
token_id=token_id,
txn_id=txn_id, txn_id=txn_id,
) )
else: else:
@ -694,7 +664,7 @@ class RoomMemberHandler(BaseHandler):
address, address,
room_id, room_id,
inviter, inviter,
token_id, requester.access_token_id,
txn_id=txn_id txn_id=txn_id
) )
@ -805,7 +775,7 @@ class RoomMemberHandler(BaseHandler):
) )
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_nonmember_event(
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,
"content": { "content": {

View file

@ -582,6 +582,28 @@ class SyncHandler(BaseHandler):
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
# For each newly joined room, we want to send down presence of
# existing users.
presence_handler = self.hs.get_handlers().presence_handler
extra_presence_users = set()
for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(event.room_id)
extra_presence_users.update(users)
# For each new member, send down presence.
for joined_sync in joined:
it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
for event in it:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
extra_presence_users.add(event.state_key)
states = yield presence_handler.get_states(
[u for u in extra_presence_users if u != user_id],
as_event=True,
)
presence.extend(states)
account_data_for_user = sync_config.filter_collection.filter_account_data( account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data) self.account_data_for_user(account_data)
) )
@ -821,15 +843,17 @@ class SyncHandler(BaseHandler):
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
current_state = yield self.get_state_at(
room_id, stream_position=now_token
)
if full_state: if full_state:
if batch: if batch:
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
else: else:
state = yield self.get_state_at( state = current_state
room_id, stream_position=now_token
)
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event
@ -840,6 +864,7 @@ class SyncHandler(BaseHandler):
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state, timeline_start=state,
previous={}, previous={},
current=current_state,
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
@ -859,6 +884,7 @@ class SyncHandler(BaseHandler):
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
previous=state_at_previous_sync, previous=state_at_previous_sync,
current=current_state,
) )
else: else:
state = {} state = {}
@ -918,7 +944,7 @@ def _action_has_highlight(actions):
return False return False
def _calculate_state(timeline_contains, timeline_start, previous): def _calculate_state(timeline_contains, timeline_start, previous, current):
"""Works out what state to include in a sync response. """Works out what state to include in a sync response.
Args: Args:
@ -926,6 +952,7 @@ def _calculate_state(timeline_contains, timeline_start, previous):
timeline_start (dict): state at the start of the timeline timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync) if this is an initial sync)
current (dict): state at the end of the timeline
Returns: Returns:
dict dict
@ -936,14 +963,16 @@ def _calculate_state(timeline_contains, timeline_start, previous):
timeline_contains.values(), timeline_contains.values(),
previous.values(), previous.values(),
timeline_start.values(), timeline_start.values(),
current.values(),
) )
} }
c_ids = set(e.event_id for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values()) tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values()) p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values()) ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids) evs = (event_id_to_state[e] for e in state_ids)
return { return {

View file

@ -47,14 +47,13 @@ class Pusher(object):
MAX_BACKOFF = 60 * 60 * 1000 MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000
def __init__(self, _hs, profile_tag, user_id, app_id, def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
self.hs = _hs self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.profile_tag = profile_tag
self.user_id = user_id self.user_id = user_id
self.app_id = app_id self.app_id = app_id
self.app_display_name = app_display_name self.app_display_name = app_display_name
@ -186,8 +185,8 @@ class Pusher(object):
processed = False processed = False
rule_evaluator = yield \ rule_evaluator = yield \
push_rule_evaluator.evaluator_for_user_id_and_profile_tag( push_rule_evaluator.evaluator_for_user_id(
self.user_id, self.profile_tag, single_event['room_id'], self.store self.user_id, single_event['room_id'], self.store
) )
actions = yield rule_evaluator.actions_for_event(single_event) actions = yield rule_evaluator.actions_for_event(single_event)

View file

@ -44,5 +44,5 @@ class ActionGenerator:
) )
context.push_actions = [ context.push_actions = [
(uid, None, actions) for uid, actions in actions_by_user.items() (uid, actions) for uid, actions in actions_by_user.items()
] ]

View file

@ -152,7 +152,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
elif res is True: elif res is True:
continue continue
res = evaluator.matches(cond, uid, display_name, None) res = evaluator.matches(cond, uid, display_name)
if _id: if _id:
cache[_id] = bool(res) cache[_id] = bool(res)

View file

@ -23,12 +23,11 @@ logger = logging.getLogger(__name__)
class HttpPusher(Pusher): class HttpPusher(Pusher):
def __init__(self, _hs, profile_tag, user_id, app_id, def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__( super(HttpPusher, self).__init__(
_hs, _hs,
profile_tag,
user_id, user_id,
app_id, app_id,
app_display_name, app_display_name,

View file

@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): def evaluator_for_user_id(user_id, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id) rawrules = yield store.get_push_rules_for_user(user_id)
enabled_map = yield store.get_push_rules_enabled_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state( our_member_event = yield store.get_current_state(
@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
) )
defer.returnValue(PushRuleEvaluator( defer.returnValue(PushRuleEvaluator(
user_id, profile_tag, rawrules, enabled_map, user_id, rawrules, enabled_map,
room_id, our_member_event, store room_id, our_member_event, store
)) ))
@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count):
class PushRuleEvaluator: class PushRuleEvaluator:
DEFAULT_ACTIONS = [] DEFAULT_ACTIONS = []
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, def __init__(self, user_id, raw_rules, enabled_map, room_id,
our_member_event, store): our_member_event, store):
self.user_id = user_id self.user_id = user_id
self.profile_tag = profile_tag
self.room_id = room_id self.room_id = room_id
self.our_member_event = our_member_event self.our_member_event = our_member_event
self.store = store self.store = store
@ -152,7 +151,7 @@ class PushRuleEvaluator:
matches = True matches = True
for c in conditions: for c in conditions:
matches = evaluator.matches( matches = evaluator.matches(
c, self.user_id, my_display_name, self.profile_tag c, self.user_id, my_display_name
) )
if not matches: if not matches:
break break
@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object):
# Maps strings of e.g. 'content.body' -> event["content"]["body"] # Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
def matches(self, condition, user_id, display_name, profile_tag): def matches(self, condition, user_id, display_name):
if condition['kind'] == 'event_match': if condition['kind'] == 'event_match':
return self._event_match(condition, user_id) return self._event_match(condition, user_id)
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name': elif condition['kind'] == 'contains_display_name':
return self._contains_display_name(display_name) return self._contains_display_name(display_name)
elif condition['kind'] == 'room_member_count': elif condition['kind'] == 'room_member_count':

View file

@ -29,6 +29,7 @@ class PusherPool:
def __init__(self, _hs): def __init__(self, _hs):
self.hs = _hs self.hs = _hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pushers = {} self.pushers = {}
self.last_pusher_started = -1 self.last_pusher_started = -1
@ -38,8 +39,11 @@ class PusherPool:
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data,
profile_tag=""):
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
@ -47,23 +51,31 @@ class PusherPool:
self._create_pusher({ self._create_pusher({
"user_name": user_id, "user_name": user_id,
"kind": kind, "kind": kind,
"profile_tag": profile_tag,
"app_id": app_id, "app_id": app_id,
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
"pushkey": pushkey, "pushkey": pushkey,
"ts": self.hs.get_clock().time_msec(), "ts": time_now_msec,
"lang": lang, "lang": lang,
"data": data, "data": data,
"last_token": None, "last_token": None,
"last_success": None, "last_success": None,
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self.store.add_pusher(
user_id, access_token, profile_tag, kind, app_id, user_id=user_id,
app_display_name, device_display_name, access_token=access_token,
pushkey, lang, data kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=time_now_msec,
lang=lang,
data=data,
profile_tag=profile_tag,
) )
yield self._refresh_pusher(app_id, pushkey, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
@ -94,30 +106,10 @@ class PusherPool:
) )
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_id=user_id,
access_token=access_token,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
data=data,
)
yield self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
return HttpPusher( return HttpPusher(
self.hs, self.hs,
profile_tag=pusherdict['profile_tag'],
user_id=pusherdict['user_name'], user_id=pusherdict['user_name'],
app_id=pusherdict['app_id'], app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],

View file

@ -17,7 +17,7 @@
""" """
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( if requester.user != user:
target_user=user, auth_user=requester.user) allowed = yield self.handlers.presence_handler.is_visible(
observed_user=user, observer_user=requester.user,
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.handlers.presence_handler.get_state(target_user=user)
defer.returnValue((200, state)) defer.returnValue((200, state))
@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if requester.user != user:
raise AuthError(403, "Can only set your own presence state")
state = {} state = {}
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except: except:
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state( yield self.handlers.presence_handler.set_state(user, state)
target_user=user, auth_user=requester.user, state=state)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Cannot get another user's presence list") raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list( presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True) observer_user=user, accepted=True
)
for p in presence:
observed_user = p.pop("observed_user")
p["user_id"] = observed_user.to_string()
defer.returnValue((200, presence)) defer.returnValue((200, presence))

View file

@ -60,7 +60,6 @@ class PushRuleRestServlet(ClientV1RestServlet):
spec['template'], spec['template'],
spec['rule_id'], spec['rule_id'],
content, content,
device=spec['device'] if 'device' in spec else None
) )
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
@ -153,23 +152,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
elif pattern_type == "user_localpart": elif pattern_type == "user_localpart":
c["pattern"] = user.localpart c["pattern"] = user.localpart
if r['priority_class'] > PRIORITY_CLASS_MAP['override']: rulearray = rules['global'][template_name]
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
r = _strip_device_condition(r)
if not profile_tag:
continue
if profile_tag not in rules['device']:
rules['device'][profile_tag] = {}
rules['device'][profile_tag] = (
_add_empty_priority_class_arrays(
rules['device'][profile_tag]
)
)
rulearray = rules['device'][profile_tag][template_name]
else:
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r) template_rule = _rule_to_template(r)
if template_rule: if template_rule:
@ -195,24 +178,6 @@ class PushRuleRestServlet(ClientV1RestServlet):
path = path[1:] path = path[1:]
result = _filter_ruleset_with_path(rules['global'], path) result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result)) defer.returnValue((200, result))
elif path[0] == 'device':
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules['device']))
profile_tag = path[0]
path = path[1:]
if profile_tag not in rules['device']:
ret = {}
ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret))
ruleset = rules['device'][profile_tag]
result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result))
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -252,16 +217,9 @@ def _rule_spec_from_path(path):
scope = path[1] scope = path[1]
path = path[2:] path = path[2:]
if scope not in ['global', 'device']: if scope != 'global':
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0: if len(path) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -278,8 +236,6 @@ def _rule_spec_from_path(path):
'template': template, 'template': template,
'rule_id': rule_id 'rule_id': rule_id
} }
if device:
spec['profile_tag'] = device
path = path[1:] path = path[1:]
@ -289,7 +245,7 @@ def _rule_spec_from_path(path):
return spec return spec
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
if rule_template in ['override', 'underride']: if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj: if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'") raise InvalidRuleException("Missing 'conditions'")
@ -322,12 +278,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
else: else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
if device:
conditions.append({
'kind': 'device',
'profile_tag': device
})
if 'actions' not in req_obj: if 'actions' not in req_obj:
raise InvalidRuleException("No actions found") raise InvalidRuleException("No actions found")
actions = req_obj['actions'] actions = req_obj['actions']
@ -349,17 +299,6 @@ def _add_empty_priority_class_arrays(d):
return d return d
def _profile_tag_from_conditions(conditions):
"""
Given a list of conditions, return the profile tag of the
device rule if there is one
"""
for c in conditions:
if c['kind'] == 'device':
return c['profile_tag']
return None
def _filter_ruleset_with_path(ruleset, path): def _filter_ruleset_with_path(ruleset, path):
if path == []: if path == []:
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
@ -403,19 +342,11 @@ def _priority_class_from_spec(spec):
raise InvalidRuleException("Unknown template: %s" % (spec['template'])) raise InvalidRuleException("Unknown template: %s" % (spec['template']))
pc = PRIORITY_CLASS_MAP[spec['template']] pc = PRIORITY_CLASS_MAP[spec['template']]
if spec['scope'] == 'device':
pc += len(PRIORITY_CLASS_MAP)
return pc return pc
def _priority_class_to_template_name(pc): def _priority_class_to_template_name(pc):
if pc > PRIORITY_CLASS_MAP['override']: return PRIORITY_CLASS_INVERSE_MAP[pc]
# per-device
prio_class_index = pc - len(PRIORITY_CLASS_MAP)
return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else:
return PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule): def _rule_to_template(rule):
@ -445,23 +376,12 @@ def _rule_to_template(rule):
return templaterule return templaterule
def _strip_device_condition(rule):
for i, c in enumerate(rule['conditions']):
if c['kind'] == 'device':
del rule['conditions'][i]
return rule
def _namespaced_rule_id_from_spec(spec): def _namespaced_rule_id_from_spec(spec):
return _namespaced_rule_id(spec, spec['rule_id']) return _namespaced_rule_id(spec, spec['rule_id'])
def _namespaced_rule_id(spec, rule_id): def _namespaced_rule_id(spec, rule_id):
if spec['scope'] == 'global': return "global/%s/%s" % (spec['template'], rule_id)
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
return "%s/%s/%s" % (scope, spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id): def _rule_id_from_namespaced(in_rule_id):

View file

@ -45,7 +45,7 @@ class PusherRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', reqd = ['kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data'] 'device_display_name', 'pushkey', 'lang', 'data']
missing = [] missing = []
for i in reqd: for i in reqd:
@ -73,14 +73,14 @@ class PusherRestServlet(ClientV1RestServlet):
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_id=user.to_string(), user_id=user.to_string(),
access_token=requester.access_token_id, access_token=requester.access_token_id,
profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],
app_display_name=content['app_display_name'], app_display_name=content['app_display_name'],
device_display_name=content['device_display_name'], device_display_name=content['device_display_name'],
pushkey=content['pushkey'], pushkey=content['pushkey'],
lang=content['lang'], lang=content['lang'],
data=content['data'] data=content['data'],
profile_tag=content.get('profile_tag', ""),
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError(400, "Config Error: " + pce.message, raise SynapseError(400, "Config Error: " + pce.message,

View file

@ -150,10 +150,21 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event( event, context = yield msg_handler.create_event(
event_dict, token_id=requester.access_token_id, txn_id=txn_id, event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
) )
if event_type == EventTypes.Member:
yield self.handlers.room_member_handler.send_membership_event(
event,
context,
is_guest=requester.is_guest,
)
else:
yield msg_handler.send_nonmember_event(event, context)
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -171,7 +182,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event( event = yield msg_handler.create_and_send_nonmember_event(
{ {
"type": event_type, "type": event_type,
"content": content, "content": content,
@ -217,46 +228,29 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
allow_guest=True, allow_guest=True,
) )
# the identifier could be a room alias or a room id. Try one then the if RoomID.is_valid(room_identifier):
# other if it fails to parse, without swallowing other valid room_id = room_identifier
# SynapseErrors. remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
identifier = None
is_room_alias = False
try:
identifier = RoomAlias.from_string(room_identifier)
is_room_alias = True
except SynapseError:
identifier = RoomID.from_string(room_identifier)
# TODO: Support for specifying the home server to join with?
if is_room_alias:
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias( room_alias = RoomAlias.from_string(room_identifier)
requester.user, room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
identifier, room_id = room_id.to_string()
) else:
defer.returnValue((200, ret_dict)) raise SynapseError(400, "%s was not legal room ID or room alias" % (
else: # room id room_identifier,
msg_handler = self.handlers.message_handler ))
content = {"membership": Membership.JOIN}
if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": identifier.to_string(),
"sender": requester.user.to_string(),
"state_key": requester.user.to_string(),
},
token_id=requester.access_token_id,
txn_id=txn_id,
is_guest=requester.is_guest,
)
defer.returnValue((200, {"room_id": identifier.to_string()})) yield self.handlers.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
action="join",
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
)
defer.returnValue((200, {"room_id": room_id}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
@ -304,18 +298,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
if event["type"] != EventTypes.Member: if event["type"] != EventTypes.Member:
continue continue
chunk.append(event) chunk.append(event)
# FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user,
auth_user=requester.user,
)
event["content"].update(presence_state)
except:
pass
defer.returnValue((200, { defer.returnValue((200, {
"chunk": chunk "chunk": chunk
@ -451,7 +433,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content["medium"], content["medium"],
content["address"], content["address"],
content["id_server"], content["id_server"],
requester.access_token_id, requester,
txn_id txn_id
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -507,7 +489,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event( event = yield msg_handler.create_and_send_nonmember_event(
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
"content": content, "content": content,
@ -541,6 +523,10 @@ class RoomTypingRestServlet(ClientV1RestServlet):
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
) )
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -552,6 +538,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
typing_handler = self.handlers.typing_notification_handler typing_handler = self.handlers.typing_notification_handler
yield self.presence_handler.bump_presence_active_time(requester.user)
if content["typing"]: if content["typing"]:
yield typing_handler.started_typing( yield typing_handler.started_typing(
target_user=target_user, target_user=target_user,

View file

@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler self.receipts_handler = hs.get_handlers().receipts_handler
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet):
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.presence_handler.bump_presence_active_time(requester.user)
yield self.receipts_handler.received_client_receipt( yield self.receipts_handler.received_client_receipt(
room_id, room_id,
receipt_type, receipt_type,

View file

@ -25,6 +25,7 @@ from synapse.events.utils import (
) )
from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from ._base import client_v2_patterns from ._base import client_v2_patterns
import copy import copy
@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet):
self.sync_handler = hs.get_handlers().sync_handler self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet):
else: else:
since_token = None since_token = None
if set_presence == "online": affect_presence = set_presence != PresenceState.OFFLINE
yield self.event_stream_handler.started_stream(user)
try: if affect_presence:
yield self.presence_handler.set_state(user, {"presence": set_presence})
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,
)
with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout, sync_config, since_token=since_token, timeout=timeout,
full_state=full_state full_state=full_state
) )
finally:
if set_presence == "online":
self.event_stream_handler.stopped_stream(user)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -20,7 +20,7 @@ from .appservice import (
from ._base import Cache from ._base import Cache
from .directory import DirectoryStore from .directory import DirectoryStore
from .events import EventsStore from .events import EventsStore
from .presence import PresenceStore from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore from .profile import ProfileStore
from .registration import RegistrationStore from .registration import RegistrationStore
from .room import RoomStore from .room import RoomStore
@ -47,6 +47,7 @@ from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator from util.id_generators import IdGenerator, StreamIdGenerator
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -110,6 +111,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._account_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id" db_conn, "account_data_max_stream_id", "stream_id"
) )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
@ -119,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
events_max = self._stream_id_gen.get_max_token(None) events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events", db_conn, "events",
entity_column="room_id", entity_column="room_id",
@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max, "MembershipStreamChangeCache", events_max,
) )
account_max = self._account_data_id_gen.get_max_token(None) account_max = self._account_data_id_gen.get_max_token()
self._account_data_stream_cache = StreamChangeCache( self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max, "AccountDataAndTagsChangeCache", account_max,
) )
self.__presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_max_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
prefilled_cache=presence_cache_prefill
)
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
active_on_startup = self.__presence_on_startup
self.__presence_on_startup = None
return active_on_startup
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will # It doesn't really matter how many we get, the StreamChangeCache will
@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore,
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (int(max_value),)) txn.execute(sql, (int(max_value),))
rows = txn.fetchall() rows = txn.fetchall()
txn.close()
cache = { cache = {
row[0]: int(row[1]) row[0]: int(row[1])
@ -174,6 +197,28 @@ class DataStore(RoomMemberStore, RoomStore,
return cache, min_val return cache, min_val
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
"""
sql = (
"SELECT user_id, state, last_active_ts, last_federation_update_ts,"
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
" WHERE state != ?"
)
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.cursor_to_dict(txn)
txn.close()
for row in rows:
row["currently_active"] = bool(row["currently_active"])
return [UserPresenceState(**row) for row in rows]
@defer.inlineCallbacks @defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent): def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())

View file

@ -168,7 +168,7 @@ class AccountDataStore(SQLBaseStore):
"add_room_account_data", add_account_data_txn, next_id "add_room_account_data", add_account_data_txn, next_id
) )
result = yield self._account_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -207,7 +207,7 @@ class AccountDataStore(SQLBaseStore):
"add_user_account_data", add_account_data_txn, next_id "add_user_account_data", add_account_data_txn, next_id
) )
result = yield self._account_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id): def _update_max_stream_id(self, txn, next_id):

View file

@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore):
retcol="event_id", retcol="event_id",
) )
def get_latest_events_in_room(self, room_id): def get_latest_event_ids_and_hashes_in_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_latest_events_in_room", "get_latest_event_ids_and_hashes_in_room",
self._get_latest_events_in_room, self._get_latest_event_ids_and_hashes_in_room,
room_id, room_id,
) )
@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore):
desc="get_latest_event_ids_in_room", desc="get_latest_event_ids_in_room",
) )
def _get_latest_events_in_room(self, txn, room_id): def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
sql = ( sql = (
"SELECT e.event_id, e.depth FROM events as e " "SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f " "INNER JOIN event_forward_extremities as f "

View file

@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore):
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
""" """
:param event: the event set actions for :param event: the event set actions for
:param tuples: list of tuples of (user_id, profile_tag, actions) :param tuples: list of tuples of (user_id, actions)
""" """
values = [] values = []
for uid, profile_tag, actions in tuples: for uid, actions in tuples:
values.append({ values.append({
'room_id': event.room_id, 'room_id': event.room_id,
'event_id': event.event_id, 'event_id': event.event_id,
'user_id': uid, 'user_id': uid,
'profile_tag': profile_tag,
'actions': json.dumps(actions), 'actions': json.dumps(actions),
'stream_ordering': event.internal_metadata.stream_ordering, 'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth, 'topological_ordering': event.depth,
@ -43,7 +42,7 @@ class EventPushActionsStore(SQLBaseStore):
'highlight': 1 if _action_has_highlight(actions) else 0, 'highlight': 1 if _action_has_highlight(actions) else 0,
}) })
for uid, _, __ in tuples: for uid, __ in tuples:
txn.call_after( txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many, self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid) (event.room_id, uid)

View file

@ -131,7 +131,7 @@ class EventsStore(SQLBaseStore):
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
max_persisted_id = yield self._stream_id_gen.get_max_token(self) max_persisted_id = yield self._stream_id_gen.get_max_token()
defer.returnValue((stream_ordering, max_persisted_id)) defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 29 SCHEMA_VERSION = 30
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -14,73 +14,129 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList from synapse.api.constants import PresenceState
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer
class UserPresenceState(namedtuple("UserPresenceState",
("user_id", "state", "last_active_ts",
"last_federation_update_ts", "last_user_sync_ts",
"status_msg", "currently_active"))):
"""Represents the current presence state of the user.
user_id (str)
last_active (int): Time in msec that the user last interacted with server.
last_federation_update (int): Time in msec since either a) we sent a presence
update to other servers or b) we received a presence update, depending
on if is a local user or not.
last_user_sync (int): Time in msec that the user last *completed* a sync
(or event stream).
status_msg (str): User set status message.
"""
def copy_and_replace(self, **kwargs):
return self._replace(**kwargs)
@classmethod
def default(cls, user_id):
"""Returns a default presence state.
"""
return cls(
user_id=user_id,
state=PresenceState.OFFLINE,
last_active_ts=0,
last_federation_update_ts=0,
last_user_sync_ts=0,
status_msg=None,
currently_active=False,
)
class PresenceStore(SQLBaseStore): class PresenceStore(SQLBaseStore):
def create_presence(self, user_localpart): @defer.inlineCallbacks
res = self._simple_insert( def update_presence(self, presence_states):
table="presence", stream_id_manager = yield self._presence_id_gen.get_next(self)
values={"user_id": user_localpart}, with stream_id_manager as stream_id:
desc="create_presence", yield self.runInteraction(
"update_presence",
self._update_presence_txn, stream_id, presence_states,
)
defer.returnValue((stream_id, self._presence_id_gen.get_max_token()))
def _update_presence_txn(self, txn, stream_id, presence_states):
for state in presence_states:
txn.call_after(
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
# Actually insert new rows
self._simple_insert_many_txn(
txn,
table="presence_stream",
values=[
{
"stream_id": stream_id,
"user_id": state.user_id,
"state": state.state,
"last_active_ts": state.last_active_ts,
"last_federation_update_ts": state.last_federation_update_ts,
"last_user_sync_ts": state.last_user_sync_ts,
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
for state in presence_states
],
) )
self.get_presence_state.invalidate((user_localpart,)) # Delete old rows to stop database from getting really big
return res sql = (
"DELETE FROM presence_stream WHERE"
def has_presence_state(self, user_localpart): " stream_id < ?"
return self._simple_select_one( " AND user_id IN (%s)"
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["user_id"],
allow_none=True,
desc="has_presence_state",
) )
@cached(max_entries=2000) batches = (
def get_presence_state(self, user_localpart): presence_states[i:i + 50]
return self._simple_select_one( for i in xrange(0, len(presence_states), 50)
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
desc="get_presence_state",
) )
for states in batches:
@cachedList(get_presence_state.cache, list_name="user_localparts", args = [stream_id]
inlineCallbacks=True) args.extend(s.user_id for s in states)
def get_presence_states(self, user_localparts): txn.execute(
rows = yield self._simple_select_many_batch( sql % (",".join("?" for _ in states),),
table="presence", args
column="user_id", )
iterable=user_localparts,
retcols=("user_id", "state", "status_msg", "mtime",),
desc="get_presence_states",
)
defer.returnValue({
row["user_id"]: {
"state": row["state"],
"status_msg": row["status_msg"],
"mtime": row["mtime"],
}
for row in rows
})
@defer.inlineCallbacks @defer.inlineCallbacks
def set_presence_state(self, user_localpart, new_state): def get_presence_for_users(self, user_ids):
res = yield self._simple_update_one( rows = yield self._simple_select_many_batch(
table="presence", table="presence_stream",
keyvalues={"user_id": user_localpart}, column="user_id",
updatevalues={"state": new_state["state"], iterable=user_ids,
"status_msg": new_state["status_msg"], keyvalues={},
"mtime": self._clock.time_msec()}, retcols=(
desc="set_presence_state", "user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
) )
self.get_presence_state.invalidate((user_localpart,)) for row in rows:
defer.returnValue(res) row["currently_active"] = bool(row["currently_active"])
defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self):
return self._presence_id_gen.get_max_token()
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert( return self._simple_insert(
@ -128,6 +184,7 @@ class PresenceStore(SQLBaseStore):
desc="set_presence_list_accepted", desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate((observer_localpart,)) self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result) defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -154,6 +211,19 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_list_accepted", desc="get_presence_list_accepted",
) )
@cachedInlineCallbacks()
def get_presence_list_observers_accepted(self, observed_userid):
user_localparts = yield self._simple_select_onecol(
table="presence_list",
keyvalues={"observed_user_id": observed_userid, "accepted": True},
retcol="user_id",
desc="get_presence_list_accepted",
)
defer.returnValue([
"@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
])
@defer.inlineCallbacks @defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid): def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one( yield self._simple_delete_one(
@ -163,3 +233,4 @@ class PresenceStore(SQLBaseStore):
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate((observer_localpart,)) self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))

View file

@ -80,9 +80,9 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data): pushkey, pushkey_ts, lang, data, profile_tag=""):
try: try:
next_id = yield self._pushers_id_gen.get_next() next_id = yield self._pushers_id_gen.get_next()
yield self._simple_upsert( yield self._simple_upsert(
@ -95,12 +95,12 @@ class PusherStore(SQLBaseStore):
dict( dict(
access_token=access_token, access_token=access_token,
kind=kind, kind=kind,
profile_tag=profile_tag,
app_display_name=app_display_name, app_display_name=app_display_name,
device_display_name=device_display_name, device_display_name=device_display_name,
ts=pushkey_ts, ts=pushkey_ts,
lang=lang, lang=lang,
data=encode_canonical_json(data), data=encode_canonical_json(data),
profile_tag=profile_tag,
), ),
insertion_values=dict( insertion_values=dict(
id=next_id, id=next_id,

View file

@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs) super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None) "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
) )
@cached(num_args=2) @cached(num_args=2)
@ -222,7 +222,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self) return self._receipts_id_gen.get_max_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id): user_id, event_id, data, stream_id):
@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data room_id, receipt_type, user_id, event_ids, data
) )
max_persisted_id = yield self._stream_id_gen.get_max_token(self) max_persisted_id = yield self._stream_id_gen.get_max_token()
defer.returnValue((stream_id, max_persisted_id)) defer.returnValue((stream_id, max_persisted_id))

View file

@ -0,0 +1,30 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE presence_stream(
stream_id BIGINT,
user_id TEXT,
state TEXT,
last_active_ts BIGINT,
last_federation_update_ts BIGINT,
last_user_sync_ts BIGINT,
status_msg TEXT,
currently_active BOOLEAN
);
CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id);
CREATE INDEX presence_stream_user_id ON presence_stream(user_id);
CREATE INDEX presence_stream_state ON presence_stream(state);

View file

@ -531,7 +531,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'): def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token(self) token = yield self._stream_id_gen.get_max_token()
if direction != 'b': if direction != 'b':
defer.returnValue("s%d" % (token,)) defer.returnValue("s%d" % (token,))
else: else:

View file

@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns: Returns:
A deferred int. A deferred int.
""" """
return self._account_data_id_gen.get_max_token(self) return self._account_data_id_gen.get_max_token()
@cached() @cached()
def get_tags_for_user(self, user_id): def get_tags_for_user(self, user_id):
@ -147,7 +147,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._account_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -169,7 +169,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._account_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token()
defer.returnValue(result) defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id): def _update_revision_txn(self, txn, user_id, room_id, next_id):

View file

@ -130,7 +130,7 @@ class StreamIdGenerator(object):
return manager() return manager()
def get_max_token(self, store): def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
""" """

View file

@ -73,6 +73,14 @@ class DomainSpecificString(
"""Return a string encoding the fields of the structure object.""" """Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod
def is_valid(cls, s):
try:
cls.from_string(s)
return True
except:
return False
__str__ = to_string __str__ = to_string
@classmethod @classmethod

View file

@ -42,7 +42,7 @@ class Clock(object):
def time_msec(self): def time_msec(self):
"""Returns the current system time in miliseconds since epoch.""" """Returns the current system time in miliseconds since epoch."""
return self.time() * 1000 return int(self.time() * 1000)
def looping_call(self, f, msec): def looping_call(self, f, msec):
l = task.LoopingCall(f) l = task.LoopingCall(f)

View file

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class _Entry(object):
__slots__ = ["end_key", "queue"]
def __init__(self, end_key):
self.end_key = end_key
self.queue = []
class WheelTimer(object):
"""Stores arbitrary objects that will be returned after their timers have
expired.
"""
def __init__(self, bucket_size=5000):
"""
Args:
bucket_size (int): Size of buckets in ms. Corresponds roughly to the
accuracy of the timer.
"""
self.bucket_size = bucket_size
self.entries = []
self.current_tick = 0
def insert(self, now, obj, then):
"""Inserts object into timer.
Args:
now (int): Current time in msec
obj (object): Object to be inserted
then (int): When to return the object strictly after.
"""
then_key = int(then / self.bucket_size) + 1
if self.entries:
min_key = self.entries[0].end_key
max_key = self.entries[-1].end_key
if then_key <= max_key:
# The max here is to protect against inserts for times in the past
self.entries[max(min_key, then_key) - min_key].queue.append(obj)
return
next_key = int(now / self.bucket_size) + 1
if self.entries:
last_key = self.entries[-1].end_key
else:
last_key = next_key
# Handle the case when `then` is in the past and `entries` is empty.
then_key = max(last_key, then_key)
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
_Entry(key) for key in xrange(last_key, then_key + 1)
)
self.entries[-1].queue.append(obj)
def fetch(self, now):
"""Fetch any objects that have timed out
Args:
now (ms): Current time in msec
Returns:
list: List of objects that have timed out
"""
now_key = int(now / self.bucket_size)
ret = []
while self.entries and self.entries[0].end_key <= now_key:
ret.extend(self.entries.pop(0).queue)
return ret
def __len__(self):
l = 0
for entry in self.entries:
l += len(entry.queue)
return l

File diff suppressed because it is too large Load diff

View file

@ -1,311 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains tests of the "presence-like" data that is shared between
presence and profiles; namely, the displayname and avatar_url."""
from tests import unittest
from twisted.internet import defer
from mock import Mock, call, ANY, NonCallableMock
from ..utils import MockClock, setup_test_homeserver
from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
OFFLINE = PresenceState.OFFLINE
UNAVAILABLE = PresenceState.UNAVAILABLE
ONLINE = PresenceState.ONLINE
class MockReplication(object):
def __init__(self):
self.edu_handlers = {}
def register_edu_handler(self, edu_type, handler):
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
pass
def received_edu(self, origin, edu_type, content):
self.edu_handlers[edu_type](origin, content)
class PresenceAndProfileHandlers(object):
def __init__(self, hs):
self.presence_handler = PresenceHandler(hs)
self.profile_handler = ProfileHandler(hs)
class PresenceProfilelikeDataTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(
clock=MockClock(),
datastore=Mock(spec=[
"set_presence_state",
"is_presence_visible",
"set_profile_displayname",
"get_rooms_for_user",
]),
handlers=None,
resource_for_federation=Mock(),
http_client=None,
replication_layer=MockReplication(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.handlers = PresenceAndProfileHandlers(hs)
self.datastore = hs.get_datastore()
self.replication = hs.get_replication_layer()
self.replication.send_edu = Mock()
def send_edu(*args, **kwargs):
# print "send_edu: %s, %s" % (args, kwargs)
return defer.succeed((200, "OK"))
self.replication.send_edu.side_effect = send_edu
def get_profile_displayname(user_localpart):
return defer.succeed("Frank")
self.datastore.get_profile_displayname = get_profile_displayname
def is_presence_visible(*args, **kwargs):
return defer.succeed(False)
self.datastore.is_presence_visible = is_presence_visible
def get_profile_avatar_url(user_localpart):
return defer.succeed("http://foo")
self.datastore.get_profile_avatar_url = get_profile_avatar_url
self.presence_list = [
{"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test", "accepted": True},
]
def get_presence_list(user_localpart, accepted=None):
return defer.succeed(self.presence_list)
self.datastore.get_presence_list = get_presence_list
def user_rooms_intersect(userlist):
return defer.succeed(False)
self.datastore.user_rooms_intersect = user_rooms_intersect
self.handlers = hs.get_handlers()
self.mock_update_client = Mock()
def update(*args, **kwargs):
# print "mock_update_client: %s, %s" %(args, kwargs)
return defer.succeed(None)
self.mock_update_client.side_effect = update
self.handlers.presence_handler.push_update_to_clients = (
self.mock_update_client)
hs.handlers.room_member_handler = Mock(spec=[
"get_joined_rooms_for_user",
])
hs.handlers.room_member_handler.get_joined_rooms_for_user = (
lambda u: defer.succeed([]))
# Some local users to test with
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
self.u_clementine = UserID.from_string("@clementine:test")
# Remote user
self.u_potato = UserID.from_string("@potato:remote")
self.mock_get_joined = (
self.datastore.get_rooms_for_user
)
@defer.inlineCallbacks
def test_set_my_state(self):
self.presence_list = [
{"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test", "accepted": True},
]
mocked_set = self.datastore.set_presence_state
mocked_set.return_value = defer.succeed({"state": OFFLINE})
yield self.handlers.presence_handler.set_state(
target_user=self.u_apple, auth_user=self.u_apple,
state={"presence": UNAVAILABLE, "status_msg": "Away"})
mocked_set.assert_called_with("apple",
{"state": UNAVAILABLE, "status_msg": "Away"}
)
@defer.inlineCallbacks
def test_push_local(self):
def get_joined(*args):
return defer.succeed([])
self.mock_get_joined.side_effect = get_joined
self.presence_list = [
{"observed_user_id": "@banana:test", "accepted": True},
{"observed_user_id": "@clementine:test", "accepted": True},
]
self.datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
# TODO(paul): Gut-wrenching
from synapse.handlers.presence import UserPresenceCache
self.handlers.presence_handler._user_cachemap[self.u_apple] = (
UserPresenceCache()
)
self.handlers.presence_handler._user_cachemap[self.u_apple].update(
{"presence": OFFLINE}, serial=0
)
apple_set = self.handlers.presence_handler._local_pushmap.setdefault(
"apple", set())
apple_set.add(self.u_banana)
apple_set.add(self.u_clementine)
yield self.handlers.presence_handler.set_state(self.u_apple,
self.u_apple, {"presence": ONLINE}
)
yield self.handlers.presence_handler.set_state(self.u_banana,
self.u_banana, {"presence": ONLINE}
)
presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=self.u_apple, accepted=True)
self.assertEquals([
{"observed_user": self.u_banana,
"presence": ONLINE,
"last_active_ago": 0,
"displayname": "Frank",
"avatar_url": "http://foo",
"accepted": True},
{"observed_user": self.u_clementine,
"presence": OFFLINE,
"accepted": True}
], presence)
self.mock_update_client.assert_has_calls([
call(
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
room_ids=[]
),
], any_order=True)
self.mock_update_client.reset_mock()
self.datastore.set_profile_displayname.return_value = defer.succeed(
None)
yield self.handlers.profile_handler.set_displayname(self.u_apple,
self.u_apple, "I am an Apple")
self.mock_update_client.assert_has_calls([
call(
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
room_ids=[],
),
], any_order=True)
@defer.inlineCallbacks
def test_push_remote(self):
self.presence_list = [
{"observed_user_id": "@potato:remote", "accepted": True},
]
self.datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
# TODO(paul): Gut-wrenching
from synapse.handlers.presence import UserPresenceCache
self.handlers.presence_handler._user_cachemap[self.u_apple] = (
UserPresenceCache()
)
self.handlers.presence_handler._user_cachemap[self.u_apple].update(
{"presence": OFFLINE}, serial=0
)
apple_set = self.handlers.presence_handler._remote_sendmap.setdefault(
"apple", set())
apple_set.add(self.u_potato.domain)
yield self.handlers.presence_handler.set_state(self.u_apple,
self.u_apple, {"presence": ONLINE}
)
self.replication.send_edu.assert_called_with(
destination="remote",
edu_type="m.presence",
content={
"push": [
{"user_id": "@apple:test",
"presence": "online",
"last_active_ago": 0,
"displayname": "Frank",
"avatar_url": "http://foo"},
],
},
)
@defer.inlineCallbacks
def test_recv_remote(self):
self.presence_list = [
{"observed_user_id": "@banana:test"},
{"observed_user_id": "@clementine:test"},
]
# TODO(paul): Gut-wrenching
potato_set = self.handlers.presence_handler._remote_recvmap.setdefault(
self.u_potato, set()
)
potato_set.add(self.u_apple)
yield self.replication.received_edu(
"remote", "m.presence", {
"push": [
{"user_id": "@potato:remote",
"presence": "online",
"displayname": "Frank",
"avatar_url": "http://foo"},
],
}
)
self.mock_update_client.assert_called_with(
users_to_push=set([self.u_apple]),
room_ids=[],
)
state = yield self.handlers.presence_handler.get_state(self.u_potato,
self.u_apple)
self.assertEquals(
{"presence": ONLINE,
"displayname": "Frank",
"avatar_url": "http://foo"},
state)

View file

@ -70,9 +70,6 @@ class ProfileTestCase(unittest.TestCase):
self.handler = hs.get_handlers().profile_handler self.handler = hs.get_handlers().profile_handler
# TODO(paul): Icky signal declarings.. booo
hs.get_distributor().declare("changed_presencelike_data")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):
yield self.store.set_profile_displayname( yield self.store.set_profile_displayname(

View file

@ -1,412 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests REST events for /presence paths."""
from tests import unittest
from twisted.internet import defer
from mock import Mock
from ....utils import MockHttpResource, setup_test_homeserver
from synapse.api.constants import PresenceState
from synapse.handlers.presence import PresenceHandler
from synapse.rest.client.v1 import presence
from synapse.rest.client.v1 import events
from synapse.types import Requester, UserID
from synapse.util.async import run_on_reactor
from collections import namedtuple
OFFLINE = PresenceState.OFFLINE
UNAVAILABLE = PresenceState.UNAVAILABLE
ONLINE = PresenceState.ONLINE
myid = "@apple:test"
PATH_PREFIX = "/_matrix/client/api/v1"
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events(
self,
user,
from_key,
room_ids=None,
limit=None,
is_guest=None
):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class JustPresenceHandlers(object):
def __init__(self, hs):
self.presence_handler = PresenceHandler(hs)
class PresenceStateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
datastore=Mock(spec=[
"get_presence_state",
"set_presence_state",
"insert_client_ip",
]),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore()
self.datastore.get_app_service_by_token = Mock(return_value=None)
def get_presence_list(*a, **kw):
return defer.succeed([])
self.datastore.get_presence_list = get_presence_list
def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(myid),
"token_id": 1,
"is_guest": False,
}
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
"get_joined_rooms_for_user",
]
)
def get_rooms_for_user(user):
return defer.succeed([])
room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
presence.register_servlets(hs, self.mock_resource)
self.u_apple = UserID.from_string(myid)
@defer.inlineCallbacks
def test_get_my_status(self):
mocked_get = self.datastore.get_presence_state
mocked_get.return_value = defer.succeed(
{"state": ONLINE, "status_msg": "Available"}
)
(code, response) = yield self.mock_resource.trigger("GET",
"/presence/%s/status" % (myid), None)
self.assertEquals(200, code)
self.assertEquals(
{"presence": ONLINE, "status_msg": "Available"},
response
)
mocked_get.assert_called_with("apple")
@defer.inlineCallbacks
def test_set_my_status(self):
mocked_set = self.datastore.set_presence_state
mocked_set.return_value = defer.succeed({"state": OFFLINE})
(code, response) = yield self.mock_resource.trigger("PUT",
"/presence/%s/status" % (myid),
'{"presence": "unavailable", "status_msg": "Away"}')
self.assertEquals(200, code)
mocked_set.assert_called_with("apple",
{"state": UNAVAILABLE, "status_msg": "Away"}
)
class PresenceListTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
datastore=Mock(spec=[
"has_presence_state",
"get_presence_state",
"allow_presence_visible",
"is_presence_visible",
"add_presence_list_pending",
"set_presence_list_accepted",
"del_presence_list",
"get_presence_list",
"insert_client_ip",
]),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore()
self.datastore.get_app_service_by_token = Mock(return_value=None)
def has_presence_state(user_localpart):
return defer.succeed(
user_localpart in ("apple", "banana",)
)
self.datastore.has_presence_state = has_presence_state
def _get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(myid),
"token_id": 1,
"is_guest": False,
}
hs.handlers.room_member_handler = Mock(
spec=[
"get_joined_rooms_for_user",
]
)
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource)
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_get_my_list(self):
self.datastore.get_presence_list.return_value = defer.succeed(
[{"observed_user_id": "@banana:test", "accepted": True}],
)
(code, response) = yield self.mock_resource.trigger("GET",
"/presence/list/%s" % (myid), None)
self.assertEquals(200, code)
self.assertEquals([
{"user_id": "@banana:test", "presence": OFFLINE, "accepted": True},
], response)
self.datastore.get_presence_list.assert_called_with(
"apple", accepted=True
)
@defer.inlineCallbacks
def test_invite(self):
self.datastore.add_presence_list_pending.return_value = (
defer.succeed(())
)
self.datastore.is_presence_visible.return_value = defer.succeed(
True
)
(code, response) = yield self.mock_resource.trigger("POST",
"/presence/list/%s" % (myid),
"""{"invite": ["@banana:test"]}"""
)
self.assertEquals(200, code)
self.datastore.add_presence_list_pending.assert_called_with(
"apple", "@banana:test"
)
self.datastore.set_presence_list_accepted.assert_called_with(
"apple", "@banana:test"
)
@defer.inlineCallbacks
def test_drop(self):
self.datastore.del_presence_list.return_value = (
defer.succeed(())
)
(code, response) = yield self.mock_resource.trigger("POST",
"/presence/list/%s" % (myid),
"""{"drop": ["@banana:test"]}"""
)
self.assertEquals(200, code)
self.datastore.del_presence_list.assert_called_with(
"apple", "@banana:test"
)
class PresenceEventStreamTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
# HIDEOUS HACKERY
# TODO(paul): This should be injected in via the HomeServer DI system
from synapse.streams.events import (
PresenceEventSource, EventSources
)
old_SOURCE_TYPES = EventSources.SOURCE_TYPES
def tearDown():
EventSources.SOURCE_TYPES = old_SOURCE_TYPES
self.tearDown = tearDown
EventSources.SOURCE_TYPES = {
k: NullSource for k in old_SOURCE_TYPES.keys()
}
EventSources.SOURCE_TYPES["presence"] = PresenceEventSource
clock = Mock(spec=[
"call_later",
"cancel_call_later",
"time_msec",
"looping_call",
])
clock.time_msec.return_value = 1000000
hs = yield setup_test_homeserver(
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
datastore=Mock(spec=[
"set_presence_state",
"get_presence_list",
"get_rooms_for_user",
]),
clock=clock,
)
def _get_user_by_req(req=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req
presence.register_servlets(hs, self.mock_resource)
events.register_servlets(hs, self.mock_resource)
hs.handlers.room_member_handler = Mock(spec=[])
self.room_members = []
def get_rooms_for_user(user):
if user in self.room_members:
return ["a-room"]
else:
return []
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
hs.handlers.room_member_handler.get_room_members = (
lambda r: self.room_members if r == "a-room" else []
)
hs.handlers.room_member_handler._filter_events_for_client = (
lambda user_id, events, **kwargs: events
)
self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
self.mock_datastore.get_app_service_by_user_id = Mock(
return_value=defer.succeed(None)
)
self.mock_datastore.get_rooms_for_user = (
lambda u: [
namedtuple("Room", "room_id")(r)
for r in get_rooms_for_user(UserID.from_string(u))
]
)
def get_profile_displayname(user_id):
return defer.succeed("Frank")
self.mock_datastore.get_profile_displayname = get_profile_displayname
def get_profile_avatar_url(user_id):
return defer.succeed(None)
self.mock_datastore.get_profile_avatar_url = get_profile_avatar_url
def user_rooms_intersect(user_list):
room_member_ids = map(lambda u: u.to_string(), self.room_members)
shared = all(map(lambda i: i in room_member_ids, user_list))
return defer.succeed(shared)
self.mock_datastore.user_rooms_intersect = user_rooms_intersect
def get_joined_hosts_for_room(room_id):
return []
self.mock_datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
self.presence = hs.get_handlers().presence_handler
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_shortpoll(self):
self.room_members = [self.u_apple, self.u_banana]
self.mock_datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
self.mock_datastore.get_presence_list.return_value = defer.succeed(
[]
)
(code, response) = yield self.mock_resource.trigger("GET",
"/events?timeout=0", None)
self.assertEquals(200, code)
# We've forced there to be only one data stream so the tokens will
# all be ours
# I'll already get my own presence state change
self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []},
response
)
self.mock_datastore.set_presence_state.return_value = defer.succeed(
{"state": ONLINE}
)
self.mock_datastore.get_presence_list.return_value = defer.succeed([])
yield self.presence.set_state(self.u_banana, self.u_banana,
state={"presence": ONLINE}
)
yield run_on_reactor()
(code, response) = yield self.mock_resource.trigger("GET",
"/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code)
self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [
{"type": "m.presence",
"content": {
"user_id": "@banana:test",
"presence": ONLINE,
"displayname": "Frank",
"last_active_ago": 0,
}},
]}, response)

View file

@ -953,12 +953,6 @@ class RoomInitialSyncTestCase(RestTestCase):
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
# Since I'm getting my own presence I need to exist as far as presence
# is concerned.
hs.get_handlers().presence_handler.registered_user(
UserID.from_string(self.user_id)
)
# create the room # create the room
self.room_id = yield self.create_room_as(self.user_id) self.room_id = yield self.create_room_as(self.user_id)

View file

@ -34,32 +34,6 @@ class PresenceStoreTestCase(unittest.TestCase):
self.u_apple = UserID.from_string("@apple:test") self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test") self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_state(self):
yield self.store.create_presence(
self.u_apple.localpart
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": None, "status_msg": None, "mtime": None}, state
)
yield self.store.set_presence_state(
self.u_apple.localpart, {"state": "online", "status_msg": "Here"}
)
state = yield self.store.get_presence_state(
self.u_apple.localpart
)
self.assertEquals(
{"state": "online", "status_msg": "Here", "mtime": 1000000}, state
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_visibility(self): def test_visibility(self):
self.assertFalse((yield self.store.is_presence_visible( self.assertFalse((yield self.store.is_presence_visible(

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import unittest
from synapse.util.wheel_timer import WheelTimer
class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self):
wheel = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 150)
self.assertListEqual(wheel.fetch(101), [])
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(130), [])
self.assertListEqual(wheel.fetch(149), [])
self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), [])
def test_mutli_insert(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
obj3 = object()
wheel.insert(100, obj1, 150)
wheel.insert(105, obj2, 130)
wheel.insert(106, obj3, 160)
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(135), [obj2])
self.assertListEqual(wheel.fetch(149), [])
self.assertListEqual(wheel.fetch(158), [obj1])
self.assertListEqual(wheel.fetch(160), [])
self.assertListEqual(wheel.fetch(200), [obj3])
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self):
wheel = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj])
def test_insert_past_mutli(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
obj3 = object()
wheel.insert(100, obj1, 150)
wheel.insert(100, obj2, 140)
wheel.insert(100, obj3, 50)
self.assertListEqual(wheel.fetch(110), [obj3])
self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(147), [obj2])
self.assertListEqual(wheel.fetch(200), [obj1])
self.assertListEqual(wheel.fetch(240), [])

View file

@ -224,12 +224,12 @@ class MockClock(object):
def time_msec(self): def time_msec(self):
return self.time() * 1000 return self.time() * 1000
def call_later(self, delay, callback): def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(): def wrapped_callback():
LoggingContext.thread_local.current_context = current_context LoggingContext.thread_local.current_context = current_context
callback() callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False] t = [self.now + delay, wrapped_callback, False]
self.timers.append(t) self.timers.append(t)