Merge branch 'develop' into push_badge_counts

This commit is contained in:
David Baker 2016-01-19 18:17:23 +00:00
commit afb7b377f2
35 changed files with 1043 additions and 1246 deletions

View file

@ -258,6 +258,14 @@ During setup of Synapse you need to call python2.7 directly again::
...substituting your host and domain name as appropriate. ...substituting your host and domain name as appropriate.
FreeBSD
-------
Synapse can be installed via FreeBSD Ports or Packages:
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
- Packages: ``pkg install py27-matrix-synapse``
Windows Install Windows Install
--------------- ---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages: Synapse can be installed on Cygwin. It requires the following Cygwin packages:

View file

@ -510,37 +510,14 @@ class Auth(object):
""" """
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
access_token = request.args["access_token"][0] user_id = yield self._get_appservice_user_id(request.args)
if user_id:
# Check for application service tokens with a user_id override
try:
app_service = yield self.store.get_app_service_by_token(
access_token
)
if not app_service:
raise KeyError
user_id = app_service.sender
if "user_id" in request.args:
user_id = request.args["user_id"][0]
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not user_id:
raise KeyError
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue( defer.returnValue(
Requester(UserID.from_string(user_id), "", False) Requester(UserID.from_string(user_id), "", False)
) )
return
except KeyError:
pass # normal users won't have the user_id query parameter set.
access_token = request.args["access_token"][0]
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -573,6 +550,33 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN errcode=Codes.MISSING_TOKEN
) )
@defer.inlineCallbacks
def _get_appservice_user_id(self, request_args):
app_service = yield self.store.get_app_service_by_token(
request_args["access_token"][0]
)
if app_service is None:
defer.returnValue(None)
if "user_id" not in request_args:
defer.returnValue(app_service.sender)
user_id = request_args["user_id"][0]
if app_service.sender == user_id:
defer.returnValue(app_service.sender)
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not (yield self.store.get_user_by_id(user_id)):
raise AuthError(
403,
"Application service has not registered this user"
)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_user_by_access_token(self, token): def _get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.

View file

@ -29,6 +29,7 @@ class Codes(object):
USER_IN_USE = "M_USER_IN_USE" USER_IN_USE = "M_USER_IN_USE"
ROOM_IN_USE = "M_ROOM_IN_USE" ROOM_IN_USE = "M_ROOM_IN_USE"
BAD_PAGINATION = "M_BAD_PAGINATION" BAD_PAGINATION = "M_BAD_PAGINATION"
BAD_STATE = "M_BAD_STATE"
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN" MISSING_TOKEN = "M_MISSING_TOKEN"
@ -42,6 +43,7 @@ class Codes(object):
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE" THREEPID_IN_USE = "THREEPID_IN_USE"
INVALID_USERNAME = "M_INVALID_USERNAME"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View file

@ -88,6 +88,9 @@ import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
def gz_wrap(r): def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@ -495,9 +498,8 @@ class SynapseRequest(Request):
) )
def get_redacted_uri(self): def get_redacted_uri(self):
return re.sub( return ACCESS_TOKEN_RE.sub(
r'(\?.*access_token=)[^&]*(.*)$', r'\1<redacted>\3',
r'\1<redacted>\2',
self.uri self.uri
) )

View file

@ -117,6 +117,15 @@ class EventBase(object):
def __set__(self, instance, value): def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,)) raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
return self._event_dict[field]
def __contains__(self, field):
return field in self._event_dict
def items(self):
return self._event_dict.items()
class FrozenEvent(EventBase): class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):

View file

@ -53,16 +53,54 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False): def _filter_events_for_clients(self, users, events):
# Assumes that user has at some point joined the room if not is_guest. """ Returns dict of user_id -> list of events that user is allowed to
see.
"""
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
row["event_id"] for rows in forgotten for row in rows
)
def allowed(event, user_id, is_guest):
state = event_id_to_state[event.event_id]
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
def allowed(event, membership, visibility):
if visibility == "world_readable": if visibility == "world_readable":
return True return True
if is_guest: if is_guest:
return False return False
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
@ -78,43 +116,20 @@ class BaseHandler(object):
return True return True
event_id_to_state = yield self.store.get_state_for_events( defer.returnValue({
frozenset(e.event_id for e in events), user_id: [
types=( event
(EventTypes.RoomHistoryVisibility, ""), for event in events
(EventTypes.Member, user_id), if allowed(event, user_id, is_guest)
) ]
) for user_id, is_guest in users
})
events_to_return = [] @defer.inlineCallbacks
for event in events: def _filter_events_for_client(self, user_id, events, is_guest=False):
state = event_id_to_state[event.event_id] # Assumes that user has at some point joined the room if not is_guest.
res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
membership_event = state.get((EventTypes.Member, user_id), None) defer.returnValue(res.get(user_id, []))
if membership_event:
was_forgotten_at_event = yield self.store.was_forgotten_at(
membership_event.state_key,
membership_event.room_id,
membership_event.event_id
)
if was_forgotten_at_event:
membership = None
else:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
defer.returnValue(events_to_return)
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()
@ -171,11 +186,9 @@ class BaseHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[], def handle_new_client_event(self, event, context, extra_users=[]):
extra_users=[], suppress_auth=False):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context.current_state.values())
@ -253,12 +266,12 @@ class BaseHandler(object):
event, context=context event, context=context
) )
action_generator = ActionGenerator(self.store) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, self event, self
) )
destinations = set(extra_destinations) destinations = set()
for k, s in context.current_state.items(): for k, s in context.current_state.items():
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:

View file

@ -245,7 +245,7 @@ class FederationHandler(BaseHandler):
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
if not backfilled and not event.internal_metadata.is_outlier(): if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.store) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, self event, self
) )
@ -1692,7 +1692,7 @@ class FederationHandler(BaseHandler):
self.auth.check(event, context.current_state) self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state) yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context) yield member_handler.send_membership_event(event, context)
else: else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite( yield self.replication_layer.forward_third_party_invite(
@ -1721,7 +1721,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context) yield member_handler.send_membership_event(event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context): def add_display_name_to_third_party_invite(self, event_dict, event, context):

View file

@ -174,30 +174,25 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_event(self, event_dict, token_id=None, txn_id=None):
token_id=None, txn_id=None, is_guest=False): """
""" Given a dict from a client, create and handle a new event. Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events, Creates an FrozenEvent object, filling out auth_events, prev_events,
etc. etc.
Adds display names to Join membership events. Adds display names to Join membership events.
Persists and notifies local clients and federation.
Args: Args:
event_dict (dict): An entire event event_dict (dict): An entire event
Returns:
Tuple of created event (FrozenEvent), Context
""" """
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder) self.validator.validate_new(builder)
if ratelimit:
self.ratelimit(builder.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = UserID.from_string(builder.user_id)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -216,6 +211,25 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event( event, context = yield self._create_new_client_event(
builder=builder, builder=builder,
) )
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_event(self, event, context, ratelimit=True, is_guest=False):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if ratelimit:
self.ratelimit(event.sender)
if event.is_state(): if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -229,7 +243,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context, is_guest=is_guest) yield member_handler.send_membership_event(event, context, is_guest=is_guest)
else: else:
yield self.handle_new_client_event( yield self.handle_new_client_event(
event=event, event=event,
@ -241,6 +255,25 @@ class MessageHandler(BaseHandler):
with PreserveLoggingContext(): with PreserveLoggingContext():
presence.bump_presence_active_time(user) presence.bump_presence_active_time(user)
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None, is_guest=False):
"""
Creates an event, then sends it.
See self.create_event and self.send_event.
"""
event, context = yield self.create_event(
event_dict,
token_id=token_id,
txn_id=txn_id
)
yield self.send_event(
event,
context,
ratelimit=ratelimit,
is_guest=is_guest
)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -53,7 +53,8 @@ class RegistrationHandler(BaseHandler):
raise SynapseError( raise SynapseError(
400, 400,
"User ID must only contain characters which do not" "User ID must only contain characters which do not"
" require URL encoding." " require URL encoding.",
Codes.INVALID_USERNAME
) )
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)

View file

@ -22,7 +22,7 @@ from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, Membership, JoinRules, RoomCreationPreset,
) )
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -397,7 +397,58 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain) remotedomains.add(member.domain)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True, is_guest=False): def update_membership(self, requester, target, room_id, action, txn_id=None):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
elif action == "forget":
effective_membership_state = "leave"
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": unicode(effective_membership_state)}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
yield msg_handler.send_event(
event,
context,
ratelimit=True,
is_guest=requester.is_guest
)
if action == "forget":
yield self.forget(requester.user, room_id)
@defer.inlineCallbacks
def send_membership_event(self, event, context, is_guest=False):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
Args: Args:
@ -432,7 +483,7 @@ class RoomMemberHandler(BaseHandler):
if not is_guest_access_allowed: if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
yield self._do_join(event, context, do_auth=do_auth) yield self._do_join(event, context)
else: else:
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context) is_host_in_room = yield self.is_host_in_room(room_id, context)
@ -459,9 +510,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"],
context=context, context=context,
do_auth=do_auth,
) )
if prev_state and prev_state.membership == Membership.JOIN: if prev_state and prev_state.membership == Membership.JOIN:
@ -497,12 +546,12 @@ class RoomMemberHandler(BaseHandler):
}) })
event, context = yield self._create_new_client_event(builder) event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_hosts=hosts, do_auth=True) yield self._do_join(event, context, room_hosts=hosts)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True): def _do_join(self, event, context, room_hosts=None):
room_id = event.room_id room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite # XXX: We don't do an auth check if we are doing an invite
@ -536,9 +585,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"],
context=context, context=context,
do_auth=do_auth,
) )
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
@ -603,8 +650,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(room_ids) defer.returnValue(room_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_local_membership_update(self, event, membership, context, def _do_local_membership_update(self, event, context):
do_auth):
yield run_on_reactor() yield run_on_reactor()
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
@ -613,7 +659,6 @@ class RoomMemberHandler(BaseHandler):
event, event,
context, context,
extra_users=[target_user], extra_users=[target_user],
suppress_auth=(not do_auth),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -880,28 +925,39 @@ class RoomContextHandler(BaseHandler):
(excluding state). (excluding state).
Returns: Returns:
dict dict, or None if the event isn't found
""" """
before_limit = math.floor(limit/2.) before_limit = math.floor(limit/2.)
after_limit = limit - before_limit after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
def filter_evts(events):
return self._filter_events_for_client(
user.to_string(),
events,
is_guest=is_guest)
event = yield self.store.get_event(event_id, get_prev_content=True,
allow_none=True)
if not event:
defer.returnValue(None)
return
filtered = yield(filter_evts([event]))
if not filtered:
raise AuthError(
403,
"You don't have permission to access that event."
)
results = yield self.store.get_events_around( results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit room_id, event_id, before_limit, after_limit
) )
results["events_before"] = yield self._filter_events_for_client( results["events_before"] = yield filter_evts(results["events_before"])
user.to_string(), results["events_after"] = yield filter_evts(results["events_after"])
results["events_before"], results["event"] = event
is_guest=is_guest,
)
results["events_after"] = yield self._filter_events_for_client(
user.to_string(),
results["events_after"],
is_guest=is_guest,
)
if results["events_after"]: if results["events_after"]:
last_event_id = results["events_after"][-1].event_id last_event_id = results["events_after"][-1].event_id

View file

@ -55,6 +55,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"ephemeral", "ephemeral",
"account_data", "account_data",
"unread_notification_count", "unread_notification_count",
"unread_highlight_count",
])): ])):
__slots__ = [] __slots__ = []
@ -292,9 +293,14 @@ class SyncHandler(BaseHandler):
notifs = yield self.unread_notifs_for_room_id( notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, ephemeral_by_room room_id, sync_config, ephemeral_by_room
) )
notif_count = None notif_count = None
highlight_count = None
if notifs is not None: if notifs is not None:
notif_count = len(notifs) notif_count = len(notifs)
highlight_count = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
current_state = yield self.get_state_at(room_id, now_token) current_state = yield self.get_state_at(room_id, now_token)
@ -307,6 +313,7 @@ class SyncHandler(BaseHandler):
room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
), ),
unread_notification_count=notif_count, unread_notification_count=notif_count,
unread_highlight_count=highlight_count,
)) ))
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
@ -529,9 +536,14 @@ class SyncHandler(BaseHandler):
notifs = yield self.unread_notifs_for_room_id( notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, all_ephemeral_by_room room_id, sync_config, all_ephemeral_by_room
) )
notif_count = None notif_count = None
highlight_count = None
if notifs is not None: if notifs is not None:
notif_count = len(notifs) notif_count = len(notifs)
highlight_count = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
just_joined = yield self.check_joined_room(sync_config, state) just_joined = yield self.check_joined_room(sync_config, state)
if just_joined: if just_joined:
@ -553,7 +565,8 @@ class SyncHandler(BaseHandler):
account_data=self.account_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
), ),
unread_notification_count=notif_count unread_notification_count=notif_count,
unread_highlight_count=highlight_count,
) )
logger.debug("Result for room %s: %r", room_id, room_sync) logger.debug("Result for room %s: %r", room_id, room_sync)
@ -575,7 +588,8 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room, account_data_by_room ephemeral_by_room, tags_by_room, account_data_by_room,
all_ephemeral_by_room=all_ephemeral_by_room,
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
@ -655,7 +669,8 @@ class SyncHandler(BaseHandler):
def incremental_sync_with_gap_for_room(self, room_id, sync_config, def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room, ephemeral_by_room, tags_by_room,
account_data_by_room): account_data_by_room,
all_ephemeral_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to the room. Gives the client the most recent events and the changes to
state. state.
@ -671,7 +686,7 @@ class SyncHandler(BaseHandler):
room_id, sync_config, now_token, since_token, room_id, sync_config, now_token, since_token,
) )
logging.debug("Recents %r", batch) logger.debug("Recents %r", batch)
current_state = yield self.get_state_at(room_id, now_token) current_state = yield self.get_state_at(room_id, now_token)
@ -690,11 +705,16 @@ class SyncHandler(BaseHandler):
state = yield self.get_state_at(room_id, now_token) state = yield self.get_state_at(room_id, now_token)
notifs = yield self.unread_notifs_for_room_id( notifs = yield self.unread_notifs_for_room_id(
room_id, sync_config, ephemeral_by_room room_id, sync_config, all_ephemeral_by_room
) )
notif_count = None notif_count = None
highlight_count = None
if notifs is not None: if notifs is not None:
notif_count = len(notifs) notif_count = len(notifs)
highlight_count = len([
1 for notif in notifs if _action_has_highlight(notif["actions"])
])
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
@ -705,6 +725,7 @@ class SyncHandler(BaseHandler):
room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
), ),
unread_notification_count=notif_count, unread_notification_count=notif_count,
unread_highlight_count=highlight_count,
) )
logger.debug("Room sync: %r", room_sync) logger.debug("Room sync: %r", room_sync)
@ -734,7 +755,7 @@ class SyncHandler(BaseHandler):
leave_event.room_id, sync_config, leave_token, since_token, leave_event.room_id, sync_config, leave_token, since_token,
) )
logging.debug("Recents %r", batch) logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event( state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id leave_event.event_id
@ -850,8 +871,19 @@ class SyncHandler(BaseHandler):
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id room_id, sync_config.user.to_string(), last_unread_event_id
) )
else: defer.returnValue(notifs)
# There is no new information in this period, so your notification # There is no new information in this period, so your notification
# count is whatever it was last time. # count is whatever it was last time.
defer.returnValue(None) defer.returnValue(None)
defer.returnValue(notifs)
def _action_has_highlight(actions):
for action in actions:
try:
if action.get("set_tweak", None) == "highlight":
return action.get("value", True)
except AttributeError:
pass
return False

View file

@ -37,7 +37,7 @@ class Pusher(object):
MAX_BACKOFF = 60 * 60 * 1000 MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000
def __init__(self, _hs, profile_tag, user_name, app_id, def __init__(self, _hs, profile_tag, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
self.hs = _hs self.hs = _hs
@ -45,7 +45,7 @@ class Pusher(object):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.profile_tag = profile_tag self.profile_tag = profile_tag
self.user_name = user_name self.user_id = user_id
self.app_id = app_id self.app_id = app_id
self.app_display_name = app_display_name self.app_display_name = app_display_name
self.device_display_name = device_display_name self.device_display_name = device_display_name
@ -95,14 +95,14 @@ class Pusher(object):
# we fail to dispatch the push) # we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1') config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream( chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0, affect_presence=False self.user_id, config, timeout=0, affect_presence=False
) )
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_name, self.last_token self.app_id, self.pushkey, self.user_id, self.last_token
) )
logger.info("Pusher %s for user %s starting from token %s", logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token) self.pushkey, self.user_id, self.last_token)
wait = 0 wait = 0
while self.alive: while self.alive:
@ -127,7 +127,7 @@ class Pusher(object):
config = PaginationConfig(from_token=from_tok, limit='1') config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000 timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream( chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=timeout, affect_presence=False self.user_id, config, timeout=timeout, affect_presence=False
) )
# limiting to 1 may get 1 event plus 1 presence event, so # limiting to 1 may get 1 event plus 1 presence event, so
@ -144,7 +144,7 @@ class Pusher(object):
if read_receipt: if read_receipt:
for receipt_part in read_receipt['content'].values(): for receipt_part in read_receipt['content'].values():
if 'm.read' in receipt_part: if 'm.read' in receipt_part:
if self.user_name in receipt_part['m.read'].keys(): if self.user_id in receipt_part['m.read'].keys():
have_updated_badge = True have_updated_badge = True
if not single_event: if not single_event:
@ -154,7 +154,7 @@ class Pusher(object):
yield self.store.update_pusher_last_token( yield self.store.update_pusher_last_token(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_id,
self.last_token self.last_token
) )
return return
@ -165,8 +165,8 @@ class Pusher(object):
processed = False processed = False
rule_evaluator = yield \ rule_evaluator = yield \
push_rule_evaluator.evaluator_for_user_name_and_profile_tag( push_rule_evaluator.evaluator_for_user_id_and_profile_tag(
self.user_name, self.profile_tag, single_event['room_id'], self.store self.user_id, self.profile_tag, single_event['room_id'], self.store
) )
actions = yield rule_evaluator.actions_for_event(single_event) actions = yield rule_evaluator.actions_for_event(single_event)
@ -192,7 +192,7 @@ class Pusher(object):
pk pk
) )
yield self.hs.get_pusherpool().remove_pusher( yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_name self.app_id, pk, self.user_id
) )
else: else:
if have_updated_badge: if have_updated_badge:
@ -208,7 +208,7 @@ class Pusher(object):
yield self.store.update_pusher_last_token_and_success( yield self.store.update_pusher_last_token_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_id,
self.last_token, self.last_token,
self.clock.time_msec() self.clock.time_msec()
) )
@ -217,7 +217,7 @@ class Pusher(object):
yield self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_id,
self.failing_since) self.failing_since)
else: else:
if not self.failing_since: if not self.failing_since:
@ -225,7 +225,7 @@ class Pusher(object):
yield self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_id,
self.failing_since self.failing_since
) )
@ -237,13 +237,13 @@ class Pusher(object):
# of old notifications. # of old notifications.
logger.warn("Giving up on a notification to user %s, " logger.warn("Giving up on a notification to user %s, "
"pushkey %s", "pushkey %s",
self.user_name, self.pushkey) self.user_id, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end'] self.last_token = chunk['end']
yield self.store.update_pusher_last_token( yield self.store.update_pusher_last_token(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_id,
self.last_token self.last_token
) )
@ -251,14 +251,14 @@ class Pusher(object):
yield self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_id,
self.failing_since self.failing_since
) )
else: else:
logger.warn("Failed to dispatch push for user %s " logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)." "(failing for %dms)."
"Trying again in %dms", "Trying again in %dms",
self.user_name, self.user_id,
self.clock.time_msec() - self.failing_since, self.clock.time_msec() - self.failing_since,
self.backoff_delay) self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0) yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
@ -299,11 +299,11 @@ class Pusher(object):
membership_list = (Membership.INVITE, Membership.JOIN) membership_list = (Membership.INVITE, Membership.JOIN)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=self.user_name, user_id=self.user_id,
membership_list=membership_list membership_list=membership_list
) )
user_is_guest = yield self.store.is_guest(UserID.from_string(self.user_name)) user_is_guest = yield self.store.is_guest(self.user_id)
# XXX: importing inside method to break circular dependency. # XXX: importing inside method to break circular dependency.
# should sort out the mess by moving all this logic out of # should sort out the mess by moving all this logic out of
@ -311,7 +311,7 @@ class Pusher(object):
# handler to somewhere more amenable to re-use. # handler to somewhere more amenable to re-use.
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
sync_config = SyncConfig( sync_config = SyncConfig(
user=UserID.from_string(self.user_name), user=UserID.from_string(self.user_id),
filter=FilterCollection({}), filter=FilterCollection({}),
is_guest=user_is_guest, is_guest=user_is_guest,
) )
@ -328,13 +328,13 @@ class Pusher(object):
badge += 1 badge += 1
else: else:
last_unread_event_id = sync_handler.last_read_event_id_for_room_and_user( last_unread_event_id = sync_handler.last_read_event_id_for_room_and_user(
r.room_id, self.user_name, ephemeral_by_room r.room_id, self.user_id, ephemeral_by_room
) )
if last_unread_event_id: if last_unread_event_id:
notifs = yield ( notifs = yield (
self.store.get_unread_event_push_actions_by_room_for_user( self.store.get_unread_event_push_actions_by_room_for_user(
r.room_id, self.user_name, last_unread_event_id r.room_id, self.user_id, last_unread_event_id
) )
) )
badge += len(notifs) badge += len(notifs)

View file

@ -25,8 +25,9 @@ logger = logging.getLogger(__name__)
class ActionGenerator: class ActionGenerator:
def __init__(self, store): def __init__(self, hs):
self.store = store self.hs = hs
self.store = hs.get_datastore()
# really we want to get all user ids and all profile tags too, # really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and # since we want the actions for each profile tag for every user and
# also actions for a client with no profile tag for each user. # also actions for a client with no profile tag for each user.
@ -42,7 +43,7 @@ class ActionGenerator:
) )
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id( bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
event.room_id, self.store event.room_id, self.hs, self.store
) )
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler) actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)

View file

@ -15,27 +15,25 @@
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
def list_with_base_rules(rawrules, user_name): def list_with_base_rules(rawrules):
ruleslist = [] ruleslist = []
# shove the server default rules for each kind onto the end of each # shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
user_name, PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
for r in rawrules: for r in rawrules:
if r['priority_class'] < current_prio_class: if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
@ -43,58 +41,47 @@ def list_with_base_rules(rawrules, user_name):
while current_prio_class > 0: while current_prio_class > 0:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
return ruleslist return ruleslist
def make_base_append_rules(user, kind): def make_base_append_rules(kind):
rules = [] rules = []
if kind == 'override': if kind == 'override':
rules = make_base_append_override_rules() rules = BASE_APPEND_OVRRIDE_RULES
elif kind == 'underride': elif kind == 'underride':
rules = make_base_append_underride_rules(user) rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == 'content': elif kind == 'content':
rules = make_base_append_content_rules(user) rules = BASE_APPEND_CONTENT_RULES
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True # Deprecated, left for backwards compat
return rules return rules
def make_base_prepend_rules(user, kind): def make_base_prepend_rules(kind):
rules = [] rules = []
if kind == 'override': if kind == 'override':
rules = make_base_prepend_override_rules() rules = BASE_PREPEND_OVERRIDE_RULES
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True # Deprecated, left for backwards compat
return rules return rules
def make_base_append_content_rules(user): BASE_APPEND_CONTENT_RULES = [
return [
{ {
'rule_id': 'global/content/.m.rule.contains_user_name', 'rule_id': 'global/content/.m.rule.contains_user_name',
'conditions': [ 'conditions': [
{ {
'kind': 'event_match', 'kind': 'event_match',
'key': 'content.body', 'key': 'content.body',
'pattern': user.localpart, # Matrix ID match 'pattern_type': 'user_localpart'
} }
], ],
'actions': [ 'actions': [
@ -110,8 +97,7 @@ def make_base_append_content_rules(user):
] ]
def make_base_prepend_override_rules(): BASE_PREPEND_OVERRIDE_RULES = [
return [
{ {
'rule_id': 'global/override/.m.rule.master', 'rule_id': 'global/override/.m.rule.master',
'enabled': False, 'enabled': False,
@ -123,8 +109,7 @@ def make_base_prepend_override_rules():
] ]
def make_base_append_override_rules(): BASE_APPEND_OVRRIDE_RULES = [
return [
{ {
'rule_id': 'global/override/.m.rule.suppress_notices', 'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [ 'conditions': [
@ -132,6 +117,7 @@ def make_base_append_override_rules():
'kind': 'event_match', 'kind': 'event_match',
'key': 'content.msgtype', 'key': 'content.msgtype',
'pattern': 'm.notice', 'pattern': 'm.notice',
'_id': '_suppress_notices',
} }
], ],
'actions': [ 'actions': [
@ -141,8 +127,7 @@ def make_base_append_override_rules():
] ]
def make_base_append_underride_rules(user): BASE_APPEND_UNDERRIDE_RULES = [
return [
{ {
'rule_id': 'global/underride/.m.rule.call', 'rule_id': 'global/underride/.m.rule.call',
'conditions': [ 'conditions': [
@ -150,6 +135,7 @@ def make_base_append_underride_rules(user):
'kind': 'event_match', 'kind': 'event_match',
'key': 'type', 'key': 'type',
'pattern': 'm.call.invite', 'pattern': 'm.call.invite',
'_id': '_call',
} }
], ],
'actions': [ 'actions': [
@ -185,7 +171,8 @@ def make_base_append_underride_rules(user):
'conditions': [ 'conditions': [
{ {
'kind': 'room_member_count', 'kind': 'room_member_count',
'is': '2' 'is': '2',
'_id': 'member_count',
} }
], ],
'actions': [ 'actions': [
@ -206,16 +193,18 @@ def make_base_append_underride_rules(user):
'kind': 'event_match', 'kind': 'event_match',
'key': 'type', 'key': 'type',
'pattern': 'm.room.member', 'pattern': 'm.room.member',
'_id': '_member',
}, },
{ {
'kind': 'event_match', 'kind': 'event_match',
'key': 'content.membership', 'key': 'content.membership',
'pattern': 'invite', 'pattern': 'invite',
'_id': '_invite_member',
}, },
{ {
'kind': 'event_match', 'kind': 'event_match',
'key': 'state_key', 'key': 'state_key',
'pattern': user.to_string(), 'pattern_type': 'user_id'
}, },
], ],
'actions': [ 'actions': [
@ -236,6 +225,7 @@ def make_base_append_underride_rules(user):
'kind': 'event_match', 'kind': 'event_match',
'key': 'type', 'key': 'type',
'pattern': 'm.room.member', 'pattern': 'm.room.member',
'_id': '_member',
} }
], ],
'actions': [ 'actions': [
@ -247,12 +237,12 @@ def make_base_append_underride_rules(user):
}, },
{ {
'rule_id': 'global/underride/.m.rule.message', 'rule_id': 'global/underride/.m.rule.message',
'enabled': False,
'conditions': [ 'conditions': [
{ {
'kind': 'event_match', 'kind': 'event_match',
'key': 'type', 'key': 'type',
'pattern': 'm.room.message', 'pattern': 'm.room.message',
'_id': '_message',
} }
], ],
'actions': [ 'actions': [
@ -263,3 +253,20 @@ def make_base_append_underride_rules(user):
] ]
} }
] ]
for r in BASE_APPEND_CONTENT_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['content']
r['default'] = True
for r in BASE_PREPEND_OVERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True
for r in BASE_APPEND_OVRRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True
for r in BASE_APPEND_UNDERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['underride']
r['default'] = True

View file

@ -14,16 +14,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson as json import ujson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID
import baserules import baserules
from push_rule_evaluator import PushRuleEvaluator from push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.events.utils import serialize_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,28 +34,30 @@ def decode_rule_json(rule):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_room_id(room_id, store): def _get_rules(room_id, user_ids, store):
users = yield store.get_users_in_room(room_id) rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_by_user = yield store.bulk_get_push_rules(users)
rules_by_user = { rules_by_user = {
uid: baserules.list_with_base_rules( uid: baserules.list_with_base_rules([
[decode_rule_json(rule_list) for rule_list in rules_by_user[uid]] decode_rule_json(rule_list)
if uid in rules_by_user else [], for rule_list in rules_by_user.get(uid, [])
UserID.from_string(uid), ])
) for uid in user_ids
for uid in users
} }
member_events = yield store.get_current_state( defer.returnValue(rules_by_user)
room_id=room_id,
event_type='m.room.member',
) @defer.inlineCallbacks
display_names = {} def evaluator_for_room_id(room_id, hs, store):
for ev in member_events: results = yield store.get_receipts_for_room(room_id, "m.read")
if ev.content.get("displayname"): user_ids = [
display_names[ev.state_key] = ev.content.get("displayname") row["user_id"] for row in results
if hs.is_mine_id(row["user_id"])
]
rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, display_names, users, store room_id, rules_by_user, user_ids, store
)) ))
@ -69,10 +70,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562) (see https://matrix.org/jira/browse/SYN-562)
""" """
def __init__(self, room_id, rules_by_user, display_names, users_in_room, store): def __init__(self, room_id, rules_by_user, users_in_room, store):
self.room_id = room_id self.room_id = room_id
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
self.display_names = display_names
self.users_in_room = users_in_room self.users_in_room = users_in_room
self.store = store self.store = store
@ -80,15 +80,30 @@ class BulkPushRuleEvaluator:
def action_for_event_by_user(self, event, handler): def action_for_event_by_user(self, event, handler):
actions_by_user = {} actions_by_user = {}
for uid, rules in self.rules_by_user.items(): users_dict = yield self.store.are_guests(self.rules_by_user.keys())
display_name = None
if uid in self.display_names:
display_name = self.display_names[uid]
is_guest = yield self.store.is_guest(UserID.from_string(uid)) filtered_by_user = yield handler._filter_events_for_clients(
filtered = yield handler._filter_events_for_client( users_dict.items(), [event]
uid, [event], is_guest=is_guest
) )
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {}
member_state = yield self.store.get_state_for_event(
event.event_id,
)
display_names = {}
for ev in member_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
for uid, rules in self.rules_by_user.items():
display_name = display_names.get(uid, None)
filtered = filtered_by_user[uid]
if len(filtered) == 0: if len(filtered) == 0:
continue continue
@ -96,29 +111,32 @@ class BulkPushRuleEvaluator:
if 'enabled' in rule and not rule['enabled']: if 'enabled' in rule and not rule['enabled']:
continue continue
# XXX: profile tags matches = _condition_checker(
if BulkPushRuleEvaluator.event_matches_rule( evaluator, rule['conditions'], uid, display_name, condition_cache
event, rule, )
display_name, len(self.users_in_room), None if matches:
):
actions = [x for x in rule['actions'] if x != 'dont_notify'] actions = [x for x in rule['actions'] if x != 'dont_notify']
if len(actions) > 0: if actions:
actions_by_user[uid] = actions actions_by_user[uid] = actions
break break
defer.returnValue(actions_by_user) defer.returnValue(actions_by_user)
@staticmethod
def event_matches_rule(event, rule,
display_name, room_member_count, profile_tag):
matches = True
# passing the clock all the way into here is extremely awkward and push def _condition_checker(evaluator, conditions, uid, display_name, cache):
# rules do not care about any of the relative timestamps, so we just for cond in conditions:
# pass 0 for the current time. _id = cond.get("_id", None)
client_event = serialize_event(event, 0) if _id:
res = cache.get(_id, None)
if res is False:
return False
elif res is True:
continue
for cond in rule['conditions']: res = evaluator.matches(cond, uid, display_name, None)
matches &= PushRuleEvaluator._event_fulfills_condition( if _id:
client_event, cond, display_name, room_member_count, profile_tag cache[_id] = bool(res)
)
return matches if not res:
return False
return True

View file

@ -23,13 +23,13 @@ logger = logging.getLogger(__name__)
class HttpPusher(Pusher): class HttpPusher(Pusher):
def __init__(self, _hs, profile_tag, user_name, app_id, def __init__(self, _hs, profile_tag, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since): data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__( super(HttpPusher, self).__init__(
_hs, _hs,
profile_tag, profile_tag,
user_name, user_id,
app_id, app_id,
app_display_name, app_display_name,
device_display_name, device_display_name,
@ -87,7 +87,7 @@ class HttpPusher(Pusher):
} }
if event['type'] == 'm.room.member': if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership'] d['notification']['membership'] = event['content']['membership']
d['notification']['user_is_target'] = event['state_key'] == self.user_name d['notification']['user_is_target'] = event['state_key'] == self.user_id
if 'content' in event: if 'content' in event:
d['notification']['content'] = event['content'] d['notification']['content'] = event['content']
@ -117,7 +117,7 @@ class HttpPusher(Pusher):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_badge(self, badge): def send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_name) logger.info("Sending updated badge count %d to %r", badge, self.user_id)
d = { d = {
'notification': { 'notification': {
'id': '', 'id': '',

View file

@ -15,188 +15,43 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID
import baserules import baserules
import logging import logging
import simplejson as json import simplejson as json
import re import re
from synapse.types import UserID
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
IS_GLOB = re.compile(r'[\?\*\[\]]')
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_user_name_and_profile_tag(user_name, profile_tag, room_id, store): def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_name) rawrules = yield store.get_push_rules_for_user(user_id)
enabled_map = yield store.get_push_rules_enabled_for_user(user_name) enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state( our_member_event = yield store.get_current_state(
room_id=room_id, room_id=room_id,
event_type='m.room.member', event_type='m.room.member',
state_key=user_name, state_key=user_id,
) )
defer.returnValue(PushRuleEvaluator( defer.returnValue(PushRuleEvaluator(
user_name, profile_tag, rawrules, enabled_map, user_id, profile_tag, rawrules, enabled_map,
room_id, our_member_event, store room_id, our_member_event, store
)) ))
class PushRuleEvaluator: def _room_member_count(ev, condition, room_member_count):
DEFAULT_ACTIONS = []
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, user_name, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store):
self.user_name = user_name
self.profile_tag = profile_tag
self.room_id = room_id
self.our_member_event = our_member_event
self.store = store
rules = []
for raw_rule in raw_rules:
rule = dict(raw_rule)
rule['conditions'] = json.loads(raw_rule['conditions'])
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
user = UserID.from_string(self.user_name)
self.rules = baserules.list_with_base_rules(rules, user)
self.enabled_map = enabled_map
@staticmethod
def tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
@defer.inlineCallbacks
def actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
# let's assume you probably know about messages you sent yourself
defer.returnValue([])
room_id = ev['room_id']
# get *our* member event for display name matching
my_display_name = None
if self.our_member_event:
my_display_name = self.our_member_event[0].content.get("displayname")
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
for r in self.rules:
if r['rule_id'] in self.enabled_map:
r['enabled'] = self.enabled_map[r['rule_id']]
elif 'enabled' not in r:
r['enabled'] = True
if not r['enabled']:
continue
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count,
profile_tag=self.profile_tag
)
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
r['rule_id'], self.user_name
)
continue
if matches:
logger.info(
"%s matches for user %s, event %s",
r['rule_id'], self.user_name, ev['event_id']
)
# filter out dont_notify as we treat an empty actions list
# as dont_notify, and this doesn't take up a row in our database
actions = [x for x in actions if x != 'dont_notify']
defer.returnValue(actions)
logger.info(
"No rules match for user %s, event %s",
self.user_name, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges.
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
re.sub(r'\\\-', '-', x.group(2)))), r)
return r
@staticmethod
def _event_fulfills_condition(ev, condition,
display_name, room_member_count, profile_tag):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
elif condition['kind'] == 'room_member_count':
if 'is' not in condition: if 'is' not in condition:
return False return False
m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is']) m = INEQUALITY_EXPR.match(condition['is'])
if not m: if not m:
return False return False
ineq = m.group(1) ineq = m.group(1)
@ -217,16 +72,251 @@ class PushRuleEvaluator:
return room_member_count <= rhs return room_member_count <= rhs
else: else:
return False return False
class PushRuleEvaluator:
DEFAULT_ACTIONS = []
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store):
self.user_id = user_id
self.profile_tag = profile_tag
self.room_id = room_id
self.our_member_event = our_member_event
self.store = store
rules = []
for raw_rule in raw_rules:
rule = dict(raw_rule)
rule['conditions'] = json.loads(raw_rule['conditions'])
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
self.rules = baserules.list_with_base_rules(rules)
self.enabled_map = enabled_map
@staticmethod
def tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
@defer.inlineCallbacks
def actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_id:
# let's assume you probably know about messages you sent yourself
defer.returnValue([])
room_id = ev['room_id']
# get *our* member event for display name matching
my_display_name = None
if self.our_member_event:
my_display_name = self.our_member_event[0].content.get("displayname")
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
for r in self.rules:
enabled = self.enabled_map.get(r['rule_id'], None)
if enabled is not None and not enabled:
continue
if not r.get("enabled", True):
continue
conditions = r['conditions']
actions = r['actions']
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
r['rule_id'], self.user_id
)
continue
matches = True
for c in conditions:
matches = evaluator.matches(
c, self.user_id, my_display_name, self.profile_tag
)
if not matches:
break
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
if matches:
logger.debug(
"%s matches for user %s, event %s",
r['rule_id'], self.user_id, ev['event_id']
)
# filter out dont_notify as we treat an empty actions list
# as dont_notify, and this doesn't take up a row in our database
actions = [x for x in actions if x != 'dont_notify']
defer.returnValue(actions)
logger.debug(
"No rules match for user %s, event %s",
self.user_id, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
class PushRuleEvaluatorForEvent(object):
def __init__(self, event, room_member_count):
self._event = event
self._room_member_count = room_member_count
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
def matches(self, condition, user_id, display_name, profile_tag):
if condition['kind'] == 'event_match':
return self._event_match(condition, user_id)
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name':
return self._contains_display_name(display_name)
elif condition['kind'] == 'room_member_count':
return _room_member_count(
self._event, condition, self._room_member_count
)
else: else:
return True return True
def _event_match(self, condition, user_id):
pattern = condition.get('pattern', None)
def _value_for_dotted_key(dotted_key, event): if not pattern:
parts = dotted_key.split(".") pattern_type = condition.get('pattern_type', None)
val = event if pattern_type == "user_id":
while len(parts) > 0: pattern = user_id
if parts[0] not in val: elif pattern_type == "user_localpart":
return None pattern = UserID.from_string(user_id).localpart
val = val[parts[0]]
parts = parts[1:] if not pattern:
return val logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
body = self._event["content"].get("body", None)
if not body:
return False
return _glob_matches(pattern, body, word_boundary=True)
else:
haystack = self._get_value(condition['key'])
if haystack is None:
return False
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name):
if not display_name:
return False
body = self._event["content"].get("body", None)
if not body:
return False
return _glob_matches(display_name, body, word_boundary=True)
def _get_value(self, dotted_key):
return self._value_cache.get(dotted_key, None)
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
Args:
glob (string)
value (string): String to test against glob.
word_boundary (bool): Whether to match against word boundaries or entire
string. Defaults to False.
Returns:
bool
"""
try:
if IS_GLOB.search(glob):
r = re.escape(glob)
r = r.replace(r'\*', '.*?')
r = r.replace(r'\?', '.')
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
if word_boundary:
r = r"\b%s\b" % (r,)
r = _compile_regex(r)
return r.search(value)
else:
r = r + "$"
r = _compile_regex(r)
return r.match(value)
elif word_boundary:
r = re.escape(glob)
r = r"\b%s\b" % (r,)
r = _compile_regex(r)
return r.search(value)
else:
return value.lower() == glob.lower()
except re.error:
logger.warn("Failed to parse glob to regex: %r", glob)
return False
def _flatten_dict(d, prefix=[], result={}):
for key, value in d.items():
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result)
return result
regex_cache = LruCache(5000)
def _compile_regex(regex_str):
r = regex_cache.get(regex_str, None)
if r:
return r
r = re.compile(regex_str, flags=re.IGNORECASE)
regex_cache[regex_str] = r
return r

View file

@ -37,14 +37,14 @@ class PusherPool:
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, access_token, profile_tag, kind, app_id, def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
# code path adding pushers. # code path adding pushers.
self._create_pusher({ self._create_pusher({
"user_name": user_name, "user_name": user_id,
"kind": kind, "kind": kind,
"profile_tag": profile_tag, "profile_tag": profile_tag,
"app_id": app_id, "app_id": app_id,
@ -59,7 +59,7 @@ class PusherPool:
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self._add_pusher_to_store(
user_name, access_token, profile_tag, kind, app_id, user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, lang, data pushkey, lang, data
) )
@ -94,11 +94,11 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind, def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name, app_id, app_display_name, device_display_name,
pushkey, lang, data): pushkey, lang, data):
yield self.store.add_pusher( yield self.store.add_pusher(
user_name=user_name, user_id=user_id,
access_token=access_token, access_token=access_token,
profile_tag=profile_tag, profile_tag=profile_tag,
kind=kind, kind=kind,
@ -110,14 +110,14 @@ class PusherPool:
lang=lang, lang=lang,
data=data, data=data,
) )
self._refresh_pusher(app_id, pushkey, user_name) self._refresh_pusher(app_id, pushkey, user_id)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
return HttpPusher( return HttpPusher(
self.hs, self.hs,
profile_tag=pusherdict['profile_tag'], profile_tag=pusherdict['profile_tag'],
user_name=pusherdict['user_name'], user_id=pusherdict['user_name'],
app_id=pusherdict['app_id'], app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'], device_display_name=pusherdict['device_display_name'],
@ -135,14 +135,14 @@ class PusherPool:
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _refresh_pusher(self, app_id, pushkey, user_name): def _refresh_pusher(self, app_id, pushkey, user_id):
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey app_id, pushkey
) )
p = None p = None
for r in resultlist: for r in resultlist:
if r['user_name'] == user_name: if r['user_name'] == user_id:
p = r p = r
if p: if p:
@ -171,12 +171,12 @@ class PusherPool:
logger.info("Started pushers") logger.info("Started pushers")
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_name): def remove_pusher(self, app_id, pushkey, user_id):
fullid = "%s:%s:%s" % (app_id, pushkey, user_name) fullid = "%s:%s:%s" % (app_id, pushkey, user_id)
if fullid in self.pushers: if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid) logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop() self.pushers[fullid].stop()
del self.pushers[fullid] del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey_user_name( yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_name app_id, pushkey, user_id
) )

View file

@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
) )
import copy
import simplejson as json import simplejson as json
@ -51,7 +52,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
if 'attr' in spec: if 'attr' in spec:
self.set_rule_attr(requester.user, spec, content) self.set_rule_attr(requester.user.to_string(), spec, content)
defer.returnValue((200, {})) defer.returnValue((200, {}))
try: try:
@ -73,7 +74,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.hs.get_datastore().add_push_rule(
user_name=requester.user.to_string(), user_id=requester.user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
conditions=conditions, conditions=conditions,
@ -126,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
rule["actions"] = json.loads(rawrule["actions"]) rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule) ruleslist.append(rule)
ruleslist = baserules.list_with_base_rules(ruleslist, user) # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}} rules = {'global': {}, 'device': {}}
@ -140,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_name = _priority_class_to_template_name(r['priority_class']) template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
if r['priority_class'] > PRIORITY_CLASS_MAP['override']: if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule # per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"]) profile_tag = _profile_tag_from_conditions(r["conditions"])
@ -206,7 +218,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def set_rule_attr(self, user_name, spec, val): def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled': if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val: if isinstance(val, dict) and "enabled" in val:
val = val["enabled"] val = val["enabled"]
@ -217,15 +229,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Value for 'enabled' must be boolean") raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled( self.hs.get_datastore().set_push_rule_enabled(
user_name, namespaced_rule_id, val user_id, namespaced_rule_id, val
) )
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
def get_rule_attr(self, user_name, namespaced_rule_id, attr): def get_rule_attr(self, user_id, namespaced_rule_id, attr):
if attr == 'enabled': if attr == 'enabled':
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id( return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
user_name, namespaced_rule_id user_id, namespaced_rule_id
) )
else: else:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()

View file

@ -41,7 +41,7 @@ class PusherRestServlet(ClientV1RestServlet):
and 'kind' in content and and 'kind' in content and
content['kind'] is None): content['kind'] is None):
yield pusher_pool.remove_pusher( yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey'], user_name=user.to_string() content['app_id'], content['pushkey'], user_id=user.to_string()
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -71,7 +71,7 @@ class PusherRestServlet(ClientV1RestServlet):
try: try:
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_name=user.to_string(), user_id=user.to_string(),
access_token=requester.access_token_id, access_token=requester.access_token_id,
profile_tag=content['profile_tag'], profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],

View file

@ -414,10 +414,16 @@ class RoomEventContext(ClientV1RestServlet):
requester.is_guest, requester.is_guest,
) )
if not results:
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
results["events_before"] = [ results["events_before"] = [
serialize_event(event, time_now) for event in results["events_before"] serialize_event(event, time_now) for event in results["events_before"]
] ]
results["event"] = serialize_event(results["event"], time_now)
results["events_after"] = [ results["events_after"] = [
serialize_event(event, time_now) for event in results["events_after"] serialize_event(event, time_now) for event in results["events_after"]
] ]
@ -436,7 +442,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|kick|forget)") "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -445,9 +451,6 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
request, request,
allow_guest=True, allow_guest=True,
) )
user = requester.user
effective_membership_action = membership_action
if requester.is_guest and membership_action not in { if requester.is_guest and membership_action not in {
Membership.JOIN, Membership.JOIN,
@ -457,13 +460,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
# target user is you unless it is an invite
state_key = user.to_string()
if membership_action == "invite" and self._has_3pid_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.handlers.room_member_handler.do_3pid_invite(
room_id, room_id,
user, requester.user,
content["medium"], content["medium"],
content["address"], content["address"],
content["id_server"], content["id_server"],
@ -472,42 +472,21 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
return return
elif membership_action in ["invite", "ban", "kick"]:
if "user_id" in content: target = requester.user
state_key = content["user_id"] if membership_action in ["invite", "ban", "unban", "kick"]:
else: if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.") raise SynapseError(400, "Missing user_id key.")
target = UserID.from_string(content["user_id"])
# make sure it looks like a user ID; it'll throw if it's invalid. yield self.handlers.room_member_handler.update_membership(
UserID.from_string(state_key) requester=requester,
target=target,
if membership_action == "kick": room_id=room_id,
effective_membership_action = "leave" action=membership_action,
elif membership_action == "forget":
effective_membership_action = "leave"
msg_handler = self.handlers.message_handler
content = {"membership": unicode(effective_membership_action)}
if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
},
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=requester.is_guest,
) )
if membership_action == "forget":
yield self.handlers.room_member_handler.forget(user, room_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content): def _has_3pid_invite_keys(self, content):

View file

@ -313,6 +313,7 @@ class SyncRestServlet(RestServlet):
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral) ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notification_count"] = room.unread_notification_count result["unread_notification_count"] = room.unread_notification_count
result["unread_highlight_count"] = room.unread_highlight_count
return result return result

View file

@ -139,6 +139,9 @@ class BaseHomeServer(object):
def is_mine(self, domain_specific_string): def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname return domain_specific_string.domain == self.hostname
def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname
# Build magic accessors for every dependency # Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES: for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname) BaseHomeServer._make_dependency_method(depname)

View file

@ -15,12 +15,12 @@
import logging import logging
import urllib import urllib
import yaml import yaml
from simplejson import JSONDecodeError
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.appservice import ApplicationService, AppServiceTransaction from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config._base import ConfigError
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import UserID from synapse.types import UserID
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -144,66 +144,9 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id return rooms_for_user_matching_user_id
def _parse_services_dict(self, results):
# SQL results in the form:
# [
# {
# 'regex': "something",
# 'url': "something",
# 'namespace': enum,
# 'as_id': 0,
# 'token': "something",
# 'hs_token': "otherthing",
# 'id': 0
# }
# ]
services = {}
for res in results:
as_token = res["token"]
if as_token is None:
continue
if as_token not in services:
# add the service
services[as_token] = {
"id": res["id"],
"url": res["url"],
"token": as_token,
"hs_token": res["hs_token"],
"sender": res["sender"],
"namespaces": {
ApplicationService.NS_USERS: [],
ApplicationService.NS_ALIASES: [],
ApplicationService.NS_ROOMS: []
}
}
# add the namespace regex if one exists
ns_int = res["namespace"]
if ns_int is None:
continue
try:
services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append(
json.loads(res["regex"])
)
except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
except JSONDecodeError:
logger.error("Bad regex object '%s'", res["regex"])
service_list = []
for service in services.values():
service_list.append(ApplicationService(
token=service["token"],
url=service["url"],
namespaces=service["namespaces"],
hs_token=service["hs_token"],
sender=service["sender"],
id=service["id"]
))
return service_list
def _load_appservice(self, as_info): def _load_appservice(self, as_info):
required_string_fields = [ required_string_fields = [
# TODO: Add id here when it's stable to release
"url", "as_token", "hs_token", "sender_localpart" "url", "as_token", "hs_token", "sender_localpart"
] ]
for field in required_string_fields: for field in required_string_fields:
@ -245,7 +188,7 @@ class ApplicationServiceStore(SQLBaseStore):
namespaces=as_info["namespaces"], namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["as_token"] # the token is the only unique thing here id=as_info["id"] if "id" in as_info else as_info["as_token"],
) )
def _populate_appservice_cache(self, config_files): def _populate_appservice_cache(self, config_files):
@ -256,15 +199,38 @@ class ApplicationServiceStore(SQLBaseStore):
) )
return return
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
for config_file in config_files: for config_file in config_files:
try: try:
with open(config_file, 'r') as f: with open(config_file, 'r') as f:
appservice = self._load_appservice(yaml.load(f)) appservice = self._load_appservice(yaml.load(f))
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
)
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice) logger.info("Loaded application service: %s", appservice)
self.services_cache.append(appservice) self.services_cache.append(appservice)
except Exception as e: except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file) logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e) logger.exception(e)
raise
class ApplicationServiceTransactionStore(SQLBaseStore): class ApplicationServiceTransactionStore(SQLBaseStore):

View file

@ -17,7 +17,7 @@ from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
import logging import logging
import simplejson as json import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,7 +84,8 @@ class EventPushActionsStore(SQLBaseStore):
) )
) )
return [ return [
{"event_id": row[0], "actions": row[1]} for row in txn.fetchall() {"event_id": row[0], "actions": json.loads(row[1])}
for row in txn.fetchall()
] ]
ret = yield self.runInteraction( ret = yield self.runInteraction(

View file

@ -25,11 +25,11 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table="push_rules", table="push_rules",
keyvalues={ keyvalues={
"user_name": user_name, "user_name": user_id,
}, },
retcols=( retcols=(
"user_name", "rule_id", "priority_class", "priority", "user_name", "rule_id", "priority_class", "priority",
@ -45,11 +45,11 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table="push_rules_enable", table="push_rules_enable",
keyvalues={ keyvalues={
'user_name': user_name 'user_name': user_id
}, },
retcols=( retcols=(
"user_name", "rule_id", "enabled", "user_name", "rule_id", "enabled",
@ -122,7 +122,7 @@ class PushRuleStore(SQLBaseStore):
) )
defer.returnValue(ret) defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs): def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
after = kwargs.pop("after", None) after = kwargs.pop("after", None)
relative_to_rule = kwargs.pop("before", after) relative_to_rule = kwargs.pop("before", after)
@ -130,7 +130,7 @@ class PushRuleStore(SQLBaseStore):
txn, txn,
table="push_rules", table="push_rules",
keyvalues={ keyvalues={
"user_name": user_name, "user_name": user_id,
"rule_id": relative_to_rule, "rule_id": relative_to_rule,
}, },
retcols=["priority_class", "priority"], retcols=["priority_class", "priority"],
@ -154,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
new_rule.pop("before", None) new_rule.pop("before", None)
new_rule.pop("after", None) new_rule.pop("after", None)
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name new_rule['user_name'] = user_id
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
# check if the priority before/after is free # check if the priority before/after is free
@ -170,7 +170,7 @@ class PushRuleStore(SQLBaseStore):
"SELECT COUNT(*) FROM push_rules" "SELECT COUNT(*) FROM push_rules"
" WHERE user_name = ? AND priority_class = ? AND priority = ?" " WHERE user_name = ? AND priority_class = ? AND priority = ?"
) )
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_id, priority_class, new_rule_priority))
res = txn.fetchall() res = txn.fetchall()
num_conflicting = res[0][0] num_conflicting = res[0][0]
@ -187,14 +187,14 @@ class PushRuleStore(SQLBaseStore):
else: else:
sql += ">= ?" sql += ">= ?"
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_id, priority_class, new_rule_priority))
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, (user_name,) self.get_push_rules_for_user.invalidate, (user_id,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_name,) self.get_push_rules_enabled_for_user.invalidate, (user_id,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -203,14 +203,14 @@ class PushRuleStore(SQLBaseStore):
values=new_rule, values=new_rule,
) )
def _add_push_rule_highest_priority_txn(self, txn, user_name, def _add_push_rule_highest_priority_txn(self, txn, user_id,
priority_class, **kwargs): priority_class, **kwargs):
# find the highest priority rule in that class # find the highest priority rule in that class
sql = ( sql = (
"SELECT COUNT(*), MAX(priority) FROM push_rules" "SELECT COUNT(*), MAX(priority) FROM push_rules"
" WHERE user_name = ? and priority_class = ?" " WHERE user_name = ? and priority_class = ?"
) )
txn.execute(sql, (user_name, priority_class)) txn.execute(sql, (user_id, priority_class))
res = txn.fetchall() res = txn.fetchall()
(how_many, highest_prio) = res[0] (how_many, highest_prio) = res[0]
@ -221,15 +221,15 @@ class PushRuleStore(SQLBaseStore):
# and insert the new rule # and insert the new rule
new_rule = kwargs new_rule = kwargs
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
new_rule['user_name'] = user_name new_rule['user_name'] = user_id
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, (user_name,) self.get_push_rules_for_user.invalidate, (user_id,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_name,) self.get_push_rules_enabled_for_user.invalidate, (user_id,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -239,48 +239,48 @@ class PushRuleStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id): def delete_push_rule(self, user_id, rule_id):
""" """
Delete a push rule. Args specify the row to be deleted and can be Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the any of the columns in the push_rule table, but below are the
standard ones standard ones
Args: Args:
user_name (str): The matrix ID of the push rule owner user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted rule_id (str): The rule_id of the rule to be deleted
""" """
yield self._simple_delete_one( yield self._simple_delete_one(
"push_rules", "push_rules",
{'user_name': user_name, 'rule_id': rule_id}, {'user_name': user_id, 'rule_id': rule_id},
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate((user_name,)) self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_name,)) self.get_push_rules_enabled_for_user.invalidate((user_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_id, rule_id, enabled):
ret = yield self.runInteraction( ret = yield self.runInteraction(
"_set_push_rule_enabled_txn", "_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn, self._set_push_rule_enabled_txn,
user_name, rule_id, enabled user_id, rule_id, enabled
) )
defer.returnValue(ret) defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled): def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
new_id = self._push_rules_enable_id_gen.get_next_txn(txn) new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
"push_rules_enable", "push_rules_enable",
{'user_name': user_name, 'rule_id': rule_id}, {'user_name': user_id, 'rule_id': rule_id},
{'enabled': 1 if enabled else 0}, {'enabled': 1 if enabled else 0},
{'id': new_id}, {'id': new_id},
) )
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, (user_name,) self.get_push_rules_for_user.invalidate, (user_id,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_name,) self.get_push_rules_enabled_for_user.invalidate, (user_id,)
) )

View file

@ -80,7 +80,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, access_token, profile_tag, kind, app_id, def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data): pushkey, pushkey_ts, lang, data):
try: try:
@ -90,7 +90,7 @@ class PusherStore(SQLBaseStore):
dict( dict(
app_id=app_id, app_id=app_id,
pushkey=pushkey, pushkey=pushkey,
user_name=user_name, user_name=user_id,
), ),
dict( dict(
access_token=access_token, access_token=access_token,
@ -112,38 +112,38 @@ class PusherStore(SQLBaseStore):
raise StoreError(500, "Problem creating pusher.") raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name): def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
yield self._simple_delete_one( yield self._simple_delete_one(
"pushers", "pushers",
{"app_id": app_id, "pushkey": pushkey, 'user_name': user_name}, {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
desc="delete_pusher_by_app_id_pushkey_user_name", desc="delete_pusher_by_app_id_pushkey_user_id",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_token(self, app_id, pushkey, user_name, last_token): def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
yield self._simple_update_one( yield self._simple_update_one(
"pushers", "pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token}, {'last_token': last_token},
desc="update_pusher_last_token", desc="update_pusher_last_token",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_token_and_success(self, app_id, pushkey, user_name, def update_pusher_last_token_and_success(self, app_id, pushkey, user_id,
last_token, last_success): last_token, last_success):
yield self._simple_update_one( yield self._simple_update_one(
"pushers", "pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token, 'last_success': last_success}, {'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success", desc="update_pusher_last_token_and_success",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_name, def update_pusher_failing_since(self, app_id, pushkey, user_id,
failing_since): failing_since):
yield self._simple_update_one( yield self._simple_update_one(
"pushers", "pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'failing_since': failing_since}, {'failing_since': failing_since},
desc="update_pusher_failing_since", desc="update_pusher_failing_since",
) )

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches import cache_counter, caches_by_name from synapse.util.caches import cache_counter, caches_by_name
from twisted.internet import defer from twisted.internet import defer
@ -33,6 +33,18 @@ class ReceiptsStore(SQLBaseStore):
self._receipts_stream_cache = _RoomStreamChangeCache() self._receipts_stream_cache = _RoomStreamChangeCache()
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
},
retcols=("user_id", "event_id"),
desc="get_receipts_for_room",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients. """Get receipts for multiple rooms for sending to clients.

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
class RegistrationStore(SQLBaseStore): class RegistrationStore(SQLBaseStore):
@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedInlineCallbacks() @cachedInlineCallbacks()
def is_guest(self, user): def is_guest(self, user_id):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user.to_string()}, keyvalues={"name": user_id},
retcol="is_guest", retcol="is_guest",
allow_none=True, allow_none=True,
desc="is_guest", desc="is_guest",
@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
",".join("?" for _ in user_ids),
)
rows = yield self._execute(
"are_guests", self.cursor_to_dict, sql, *user_ids
)
result = {user_id: False for user_id in user_ids}
result.update({
row["name"]: bool(row["is_guest"])
for row in rows
})
defer.returnValue(result)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id" "SELECT users.name, users.is_guest, access_tokens.id as token_id"

View file

@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f) yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all() self.was_forgotten_at.invalidate_all()
self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id)) self.did_forget.invalidate((user_id, room_id))
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
return rows[0][0] return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f) forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1) defer.returnValue(forgot == 1)
@cached()
def who_forgot_in_room(self, room_id):
return self._simple_select_list(
table="room_memberships",
retcols=("user_id", "event_id"),
keyvalues={
"room_id": room_id,
"forgotten": 1,
},
desc="who_forgot"
)

View file

@ -22,6 +22,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_LIMIT = 1000
class SourcePaginationConfig(object): class SourcePaginationConfig(object):
"""A configuration object which stores pagination parameters for a """A configuration object which stores pagination parameters for a
@ -32,7 +35,7 @@ class SourcePaginationConfig(object):
self.from_key = from_key self.from_key = from_key
self.to_key = to_key self.to_key = to_key
self.direction = 'f' if direction == 'f' else 'b' self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) if limit is not None else None self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
def __repr__(self): def __repr__(self):
return ( return (
@ -49,7 +52,7 @@ class PaginationConfig(object):
self.from_token = from_token self.from_token = from_token
self.to_token = to_token self.to_token = to_token
self.direction = 'f' if direction == 'f' else 'b' self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) if limit is not None else None self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
@classmethod @classmethod
def from_request(cls, request, raise_invalid_params=True, def from_request(cls, request, raise_invalid_params=True,

View file

@ -29,6 +29,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.service = ApplicationService( self.service = ApplicationService(
id="unique_identifier",
url="some_url", url="some_url",
token="some_token", token="some_token",
namespaces={ namespaces={

View file

@ -1,141 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.handlers.federation import FederationHandler
from mock import NonCallableMock, ANY, Mock
from ..utils import setup_test_homeserver
class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.state_handler = NonCallableMock(spec_set=[
"compute_event_context",
])
self.auth = NonCallableMock(spec_set=[
"check",
"check_host_in_room",
])
self.hostname = "test"
hs = yield setup_test_homeserver(
self.hostname,
datastore=NonCallableMock(spec_set=[
"persist_event",
"store_room",
"get_room",
"get_destination_retry_timings",
"set_destination_retry_timings",
"have_events",
"get_users_in_room",
"bulk_get_push_rules",
"get_current_state",
"set_push_actions_for_event_and_users",
"is_guest",
"get_state_for_events",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_member_handler",
"federation_handler",
]),
auth=self.auth,
state_handler=self.state_handler,
keyring=Mock(),
)
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.hs = hs
self.handlers.federation_handler = FederationHandler(self.hs)
self.datastore.get_state_for_events.return_value = {"$a:b": {}}
@defer.inlineCallbacks
def test_msg(self):
pdu = FrozenEvent({
"type": EventTypes.Message,
"room_id": "foo",
"content": {"msgtype": u"fooo"},
"origin_server_ts": 0,
"event_id": "$a:b",
"user_id":"@a:b",
"origin": "b",
"auth_events": [],
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
})
self.datastore.persist_event.return_value = defer.succeed((1,1))
self.datastore.get_room.return_value = defer.succeed(True)
self.datastore.get_users_in_room.return_value = ["@a:b"]
self.datastore.bulk_get_push_rules.return_value = {}
self.datastore.get_current_state.return_value = {}
self.auth.check_host_in_room.return_value = defer.succeed(True)
retry_timings_res = {
"destination": "",
"retry_last_ts": 0,
"retry_interval": 0,
}
self.datastore.get_destination_retry_timings.return_value = (
defer.succeed(retry_timings_res)
)
def have_events(event_ids):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None, outlier=False):
context = Mock()
context.current_state = {}
context.auth_events = {}
return defer.succeed(context)
self.state_handler.compute_event_context.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu(
"fo", pdu, False
)
self.datastore.persist_event.assert_called_once_with(
ANY,
is_new_state=True,
backfilled=False,
current_state=None,
context=ANY,
)
self.state_handler.compute_event_context.assert_called_once_with(
ANY, old_state=None, outlier=False
)
self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with(
ANY, 1, 1, extra_users=[]
)

View file

@ -1,418 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .. import unittest
from synapse.api.constants import EventTypes, Membership
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
from ..utils import setup_test_homeserver
from mock import Mock, NonCallableMock
class RoomMemberHandlerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hostname = "red"
hs = yield setup_test_homeserver(
self.hostname,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
datastore=NonCallableMock(spec_set=[
"persist_event",
"get_room_member",
"get_room",
"store_room",
"get_latest_events_in_room",
"add_event_hashes",
"get_users_in_room",
"bulk_get_push_rules",
"get_current_state",
"set_push_actions_for_event_and_users",
"get_state_for_events",
"is_guest",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_member_handler",
"profile_handler",
"federation_handler",
]),
auth=NonCallableMock(spec_set=[
"check",
"add_auth_events",
"check_host_in_room",
]),
state_handler=NonCallableMock(spec_set=[
"compute_event_context",
"get_current_state",
]),
)
self.federation = NonCallableMock(spec_set=[
"handle_new_event",
"send_invite",
"get_state_for_room",
])
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.auth = hs.get_auth()
self.hs = hs
self.handlers.federation_handler = self.federation
self.distributor.declare("collect_presencelike_data")
self.handlers.room_member_handler = RoomMemberHandler(self.hs)
self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
self.datastore.add_event_hashes.return_value = []
self.datastore.get_users_in_room.return_value = ["@bob:red"]
self.datastore.bulk_get_push_rules.return_value = {}
@defer.inlineCallbacks
def test_invite(self):
room_id = "!foo:red"
user_id = "@bob:red"
target_user_id = "@red:blue"
content = {"membership": Membership.INVITE}
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": target_user_id,
"room_id": room_id,
"content": content,
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
def send_invite(domain, event):
return defer.succeed(event)
self.federation.send_invite.side_effect = send_invite
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
yield room_handler.change_membership(event, context)
self.state_handler.compute_event_context.assert_called_once_with(
builder
)
self.auth.add_auth_events.assert_called_once_with(
builder, context
)
self.federation.send_invite.assert_called_once_with(
"blue", event,
)
self.datastore.persist_event.assert_called_once_with(
event, context=context,
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
)
self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called)
self.assertFalse(self.federation.get_state_for_room.called)
@defer.inlineCallbacks
def test_simple_join(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = UserID.from_string(user_id)
join_signal_observer = Mock()
self.distributor.observe("user_joined_room", join_signal_observer)
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.JOIN},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.INVITE
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
# Actual invocation
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, destinations=set()
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[user]
)
join_signal_observer.assert_called_with(
user=user, room_id=room_id
)
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": membership},
})
return builder.build()
@defer.inlineCallbacks
def test_simple_leave(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = UserID.from_string(user_id)
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.LEAVE},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.JOIN
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
leave_signal_observer = Mock()
self.distributor.observe("user_left_room", leave_signal_observer)
# Actual invocation
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, destinations=set(['red'])
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[user]
)
leave_signal_observer.assert_called_with(
user=user, room_id=room_id
)
class RoomCreationTest(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hostname = "red"
hs = yield setup_test_homeserver(
self.hostname,
datastore=NonCallableMock(spec_set=[
"store_room",
"snapshot_room",
"persist_event",
"get_joined_hosts_for_room",
]),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_creation_handler",
"message_handler",
]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.federation = NonCallableMock(spec_set=[
"handle_new_event",
])
self.handlers = hs.get_handlers()
self.handlers.room_creation_handler = RoomCreationHandler(hs)
self.room_creation_handler = self.handlers.room_creation_handler
self.message_handler = self.handlers.message_handler
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@defer.inlineCallbacks
def test_room_creation(self):
user_id = "@foo:red"
room_id = "!bobs_room:red"
config = {"visibility": "private"}
yield self.room_creation_handler.create_room(
user_id=user_id,
room_id=room_id,
config=config,
)
self.assertTrue(self.message_handler.create_and_send_event.called)
event_dicts = [
e[0][0]
for e in self.message_handler.create_and_send_event.call_args_list
]
self.assertTrue(len(event_dicts) > 3)
self.assertDictContainsSubset(
{
"type": EventTypes.Create,
"sender": user_id,
"room_id": room_id,
},
event_dicts[0]
)
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
self.assertDictContainsSubset(
{
"type": EventTypes.Member,
"sender": user_id,
"room_id": room_id,
"state_key": user_id,
},
event_dicts[1]
)
self.assertEqual(
Membership.JOIN,
event_dicts[1]["content"]["membership"]
)

View file

@ -12,12 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
from synapse.config._base import ConfigError
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.server import HomeServer
from synapse.storage.appservice import ( from synapse.storage.appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore ApplicationServiceStore, ApplicationServiceTransactionStore
) )
@ -26,7 +27,6 @@ import json
import os import os
import yaml import yaml
from mock import Mock from mock import Mock
from tests.utils import SQLiteMemoryDbPool, MockClock
class ApplicationServiceStoreTestCase(unittest.TestCase): class ApplicationServiceStoreTestCase(unittest.TestCase):
@ -41,9 +41,16 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token = "token1" self.as_token = "token1"
self.as_url = "some_url" self.as_url = "some_url"
self._add_appservice(self.as_token, self.as_url, "some_hs_token", "bob") self.as_id = "as1"
self._add_appservice("token2", "some_url", "some_hs_token", "bob") self._add_appservice(
self._add_appservice("token3", "some_url", "some_hs_token", "bob") self.as_token,
self.as_id,
self.as_url,
"some_hs_token",
"bob"
)
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts # must be done after inserts
self.store = ApplicationServiceStore(hs) self.store = ApplicationServiceStore(hs)
@ -55,9 +62,9 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
except: except:
pass pass
def _add_appservice(self, as_token, url, hs_token, sender): def _add_appservice(self, as_token, id, url, hs_token, sender):
as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token, as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
sender_localpart=sender, namespaces={}) id=id, sender_localpart=sender, namespaces={})
# use the token as the filename # use the token as the filename
with open(as_token, 'w') as outfile: with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml)) outfile.write(yaml.dump(as_yaml))
@ -74,6 +81,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token self.as_token
) )
self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.id, self.as_id)
self.assertEquals(stored_service.url, self.as_url) self.assertEquals(stored_service.url, self.as_url)
self.assertEquals( self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES], stored_service.namespaces[ApplicationService.NS_ALIASES],
@ -110,34 +118,34 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
{ {
"token": "token1", "token": "token1",
"url": "https://matrix-as.org", "url": "https://matrix-as.org",
"id": "token1" "id": "id_1"
}, },
{ {
"token": "alpha_tok", "token": "alpha_tok",
"url": "https://alpha.com", "url": "https://alpha.com",
"id": "alpha_tok" "id": "id_alpha"
}, },
{ {
"token": "beta_tok", "token": "beta_tok",
"url": "https://beta.com", "url": "https://beta.com",
"id": "beta_tok" "id": "id_beta"
}, },
{ {
"token": "delta_tok", "token": "gamma_tok",
"url": "https://delta.com", "url": "https://gamma.com",
"id": "delta_tok" "id": "id_gamma"
}, },
] ]
for s in self.as_list: for s in self.as_list:
yield self._add_service(s["url"], s["token"]) yield self._add_service(s["url"], s["token"], s["id"])
self.as_yaml_files = [] self.as_yaml_files = []
self.store = TestTransactionStore(hs) self.store = TestTransactionStore(hs)
def _add_service(self, url, as_token): def _add_service(self, url, as_token, id):
as_yaml = dict(url=url, as_token=as_token, hs_token="something", as_yaml = dict(url=url, as_token=as_token, hs_token="something",
sender_localpart="a_sender", namespaces={}) id=id, sender_localpart="a_sender", namespaces={})
# use the token as the filename # use the token as the filename
with open(as_token, 'w') as outfile: with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml)) outfile.write(yaml.dump(as_yaml))
@ -405,3 +413,64 @@ class TestTransactionStore(ApplicationServiceTransactionStore,
def __init__(self, hs): def __init__(self, hs):
super(TestTransactionStore, self).__init__(hs) super(TestTransactionStore, self).__init__(hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
def _write_config(self, suffix, **kwargs):
vals = {
"id": "id" + suffix,
"url": "url" + suffix,
"as_token": "as_token" + suffix,
"hs_token": "hs_token" + suffix,
"sender_localpart": "sender_localpart" + suffix,
"namespaces": {},
}
vals.update(kwargs)
_, path = tempfile.mkstemp(prefix="as_config")
with open(path, "w") as f:
f.write(yaml.dump(vals))
return path
@defer.inlineCallbacks
def test_unique_works(self):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
ApplicationServiceStore(hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
e = cm.exception
self.assertIn(f1, e.message)
self.assertIn(f2, e.message)
self.assertIn("id", e.message)
@defer.inlineCallbacks
def test_duplicate_as_tokens(self):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(app_service_config_files=[f1, f2])
hs = yield setup_test_homeserver(config=config)
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
e = cm.exception
self.assertIn(f1, e.message)
self.assertIn(f2, e.message)
self.assertIn("as_token", e.message)