forked from MirrorHub/synapse
Merge branch 'develop' into push_badge_counts
This commit is contained in:
commit
afb7b377f2
35 changed files with 1043 additions and 1246 deletions
|
@ -258,6 +258,14 @@ During setup of Synapse you need to call python2.7 directly again::
|
|||
|
||||
...substituting your host and domain name as appropriate.
|
||||
|
||||
FreeBSD
|
||||
-------
|
||||
|
||||
Synapse can be installed via FreeBSD Ports or Packages:
|
||||
|
||||
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
|
||||
- Packages: ``pkg install py27-matrix-synapse``
|
||||
|
||||
Windows Install
|
||||
---------------
|
||||
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
|
||||
|
|
|
@ -510,37 +510,14 @@ class Auth(object):
|
|||
"""
|
||||
# Can optionally look elsewhere in the request (e.g. headers)
|
||||
try:
|
||||
access_token = request.args["access_token"][0]
|
||||
|
||||
# Check for application service tokens with a user_id override
|
||||
try:
|
||||
app_service = yield self.store.get_app_service_by_token(
|
||||
access_token
|
||||
)
|
||||
if not app_service:
|
||||
raise KeyError
|
||||
|
||||
user_id = app_service.sender
|
||||
if "user_id" in request.args:
|
||||
user_id = request.args["user_id"][0]
|
||||
if not app_service.is_interested_in_user(user_id):
|
||||
raise AuthError(
|
||||
403,
|
||||
"Application service cannot masquerade as this user."
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
raise KeyError
|
||||
|
||||
user_id = yield self._get_appservice_user_id(request.args)
|
||||
if user_id:
|
||||
request.authenticated_entity = user_id
|
||||
|
||||
defer.returnValue(
|
||||
Requester(UserID.from_string(user_id), "", False)
|
||||
)
|
||||
return
|
||||
except KeyError:
|
||||
pass # normal users won't have the user_id query parameter set.
|
||||
|
||||
access_token = request.args["access_token"][0]
|
||||
user_info = yield self._get_user_by_access_token(access_token)
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
|
@ -573,6 +550,33 @@ class Auth(object):
|
|||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_appservice_user_id(self, request_args):
|
||||
app_service = yield self.store.get_app_service_by_token(
|
||||
request_args["access_token"][0]
|
||||
)
|
||||
if app_service is None:
|
||||
defer.returnValue(None)
|
||||
|
||||
if "user_id" not in request_args:
|
||||
defer.returnValue(app_service.sender)
|
||||
|
||||
user_id = request_args["user_id"][0]
|
||||
if app_service.sender == user_id:
|
||||
defer.returnValue(app_service.sender)
|
||||
|
||||
if not app_service.is_interested_in_user(user_id):
|
||||
raise AuthError(
|
||||
403,
|
||||
"Application service cannot masquerade as this user."
|
||||
)
|
||||
if not (yield self.store.get_user_by_id(user_id)):
|
||||
raise AuthError(
|
||||
403,
|
||||
"Application service has not registered this user"
|
||||
)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_user_by_access_token(self, token):
|
||||
""" Get a registered user's ID.
|
||||
|
|
|
@ -29,6 +29,7 @@ class Codes(object):
|
|||
USER_IN_USE = "M_USER_IN_USE"
|
||||
ROOM_IN_USE = "M_ROOM_IN_USE"
|
||||
BAD_PAGINATION = "M_BAD_PAGINATION"
|
||||
BAD_STATE = "M_BAD_STATE"
|
||||
UNKNOWN = "M_UNKNOWN"
|
||||
NOT_FOUND = "M_NOT_FOUND"
|
||||
MISSING_TOKEN = "M_MISSING_TOKEN"
|
||||
|
@ -42,6 +43,7 @@ class Codes(object):
|
|||
EXCLUSIVE = "M_EXCLUSIVE"
|
||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
THREEPID_IN_USE = "THREEPID_IN_USE"
|
||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
|
|
@ -88,6 +88,9 @@ import time
|
|||
logger = logging.getLogger("synapse.app.homeserver")
|
||||
|
||||
|
||||
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
|
||||
|
||||
|
||||
def gz_wrap(r):
|
||||
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||
|
||||
|
@ -495,9 +498,8 @@ class SynapseRequest(Request):
|
|||
)
|
||||
|
||||
def get_redacted_uri(self):
|
||||
return re.sub(
|
||||
r'(\?.*access_token=)[^&]*(.*)$',
|
||||
r'\1<redacted>\2',
|
||||
return ACCESS_TOKEN_RE.sub(
|
||||
r'\1<redacted>\3',
|
||||
self.uri
|
||||
)
|
||||
|
||||
|
|
|
@ -117,6 +117,15 @@ class EventBase(object):
|
|||
def __set__(self, instance, value):
|
||||
raise AttributeError("Unrecognized attribute %s" % (instance,))
|
||||
|
||||
def __getitem__(self, field):
|
||||
return self._event_dict[field]
|
||||
|
||||
def __contains__(self, field):
|
||||
return field in self._event_dict
|
||||
|
||||
def items(self):
|
||||
return self._event_dict.items()
|
||||
|
||||
|
||||
class FrozenEvent(EventBase):
|
||||
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
||||
|
|
|
@ -53,16 +53,54 @@ class BaseHandler(object):
|
|||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_client(self, user_id, events, is_guest=False):
|
||||
# Assumes that user has at some point joined the room if not is_guest.
|
||||
def _filter_events_for_clients(self, users, events):
|
||||
""" Returns dict of user_id -> list of events that user is allowed to
|
||||
see.
|
||||
"""
|
||||
event_id_to_state = yield self.store.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.Member, None),
|
||||
)
|
||||
)
|
||||
|
||||
forgotten = yield defer.gatherResults([
|
||||
self.store.who_forgot_in_room(
|
||||
room_id,
|
||||
)
|
||||
for room_id in frozenset(e.room_id for e in events)
|
||||
], consumeErrors=True)
|
||||
|
||||
# Set of membership event_ids that have been forgotten
|
||||
event_id_forgotten = frozenset(
|
||||
row["event_id"] for rows in forgotten for row in rows
|
||||
)
|
||||
|
||||
def allowed(event, user_id, is_guest):
|
||||
state = event_id_to_state[event.event_id]
|
||||
|
||||
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get("history_visibility", "shared")
|
||||
else:
|
||||
visibility = "shared"
|
||||
|
||||
def allowed(event, membership, visibility):
|
||||
if visibility == "world_readable":
|
||||
return True
|
||||
|
||||
if is_guest:
|
||||
return False
|
||||
|
||||
membership_event = state.get((EventTypes.Member, user_id), None)
|
||||
if membership_event:
|
||||
if membership_event.event_id in event_id_forgotten:
|
||||
membership = None
|
||||
else:
|
||||
membership = membership_event.membership
|
||||
else:
|
||||
membership = None
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
|
@ -78,43 +116,20 @@ class BaseHandler(object):
|
|||
|
||||
return True
|
||||
|
||||
event_id_to_state = yield self.store.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.Member, user_id),
|
||||
)
|
||||
)
|
||||
defer.returnValue({
|
||||
user_id: [
|
||||
event
|
||||
for event in events
|
||||
if allowed(event, user_id, is_guest)
|
||||
]
|
||||
for user_id, is_guest in users
|
||||
})
|
||||
|
||||
events_to_return = []
|
||||
for event in events:
|
||||
state = event_id_to_state[event.event_id]
|
||||
|
||||
membership_event = state.get((EventTypes.Member, user_id), None)
|
||||
if membership_event:
|
||||
was_forgotten_at_event = yield self.store.was_forgotten_at(
|
||||
membership_event.state_key,
|
||||
membership_event.room_id,
|
||||
membership_event.event_id
|
||||
)
|
||||
if was_forgotten_at_event:
|
||||
membership = None
|
||||
else:
|
||||
membership = membership_event.membership
|
||||
else:
|
||||
membership = None
|
||||
|
||||
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get("history_visibility", "shared")
|
||||
else:
|
||||
visibility = "shared"
|
||||
|
||||
should_include = allowed(event, membership, visibility)
|
||||
if should_include:
|
||||
events_to_return.append(event)
|
||||
|
||||
defer.returnValue(events_to_return)
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_client(self, user_id, events, is_guest=False):
|
||||
# Assumes that user has at some point joined the room if not is_guest.
|
||||
res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
|
||||
defer.returnValue(res.get(user_id, []))
|
||||
|
||||
def ratelimit(self, user_id):
|
||||
time_now = self.clock.time()
|
||||
|
@ -171,11 +186,9 @@ class BaseHandler(object):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_client_event(self, event, context, extra_destinations=[],
|
||||
extra_users=[], suppress_auth=False):
|
||||
def handle_new_client_event(self, event, context, extra_users=[]):
|
||||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
if not suppress_auth:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
|
||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
||||
|
@ -253,12 +266,12 @@ class BaseHandler(object):
|
|||
event, context=context
|
||||
)
|
||||
|
||||
action_generator = ActionGenerator(self.store)
|
||||
action_generator = ActionGenerator(self.hs)
|
||||
yield action_generator.handle_push_actions_for_event(
|
||||
event, self
|
||||
)
|
||||
|
||||
destinations = set(extra_destinations)
|
||||
destinations = set()
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
|
|
|
@ -245,7 +245,7 @@ class FederationHandler(BaseHandler):
|
|||
yield user_joined_room(self.distributor, user, event.room_id)
|
||||
|
||||
if not backfilled and not event.internal_metadata.is_outlier():
|
||||
action_generator = ActionGenerator(self.store)
|
||||
action_generator = ActionGenerator(self.hs)
|
||||
yield action_generator.handle_push_actions_for_event(
|
||||
event, self
|
||||
)
|
||||
|
@ -1692,7 +1692,7 @@ class FederationHandler(BaseHandler):
|
|||
self.auth.check(event, context.current_state)
|
||||
yield self._validate_keyserver(event, auth_events=context.current_state)
|
||||
member_handler = self.hs.get_handlers().room_member_handler
|
||||
yield member_handler.change_membership(event, context)
|
||||
yield member_handler.send_membership_event(event, context)
|
||||
else:
|
||||
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
|
||||
yield self.replication_layer.forward_third_party_invite(
|
||||
|
@ -1721,7 +1721,7 @@ class FederationHandler(BaseHandler):
|
|||
# TODO: Make sure the signatures actually are correct.
|
||||
event.signatures.update(returned_invite.signatures)
|
||||
member_handler = self.hs.get_handlers().room_member_handler
|
||||
yield member_handler.change_membership(event, context)
|
||||
yield member_handler.send_membership_event(event, context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_display_name_to_third_party_invite(self, event_dict, event, context):
|
||||
|
|
|
@ -174,30 +174,25 @@ class MessageHandler(BaseHandler):
|
|||
defer.returnValue(chunk)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_event(self, event_dict, ratelimit=True,
|
||||
token_id=None, txn_id=None, is_guest=False):
|
||||
""" Given a dict from a client, create and handle a new event.
|
||||
def create_event(self, event_dict, token_id=None, txn_id=None):
|
||||
"""
|
||||
Given a dict from a client, create a new event.
|
||||
|
||||
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
||||
etc.
|
||||
|
||||
Adds display names to Join membership events.
|
||||
|
||||
Persists and notifies local clients and federation.
|
||||
|
||||
Args:
|
||||
event_dict (dict): An entire event
|
||||
|
||||
Returns:
|
||||
Tuple of created event (FrozenEvent), Context
|
||||
"""
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
self.validator.validate_new(builder)
|
||||
|
||||
if ratelimit:
|
||||
self.ratelimit(builder.user_id)
|
||||
# TODO(paul): Why does 'event' not have a 'user' object?
|
||||
user = UserID.from_string(builder.user_id)
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
if membership == Membership.JOIN:
|
||||
|
@ -216,6 +211,25 @@ class MessageHandler(BaseHandler):
|
|||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_event(self, event, context, ratelimit=True, is_guest=False):
|
||||
"""
|
||||
Persists and notifies local clients and federation of an event.
|
||||
|
||||
Args:
|
||||
event (FrozenEvent) the event to send.
|
||||
context (Context) the context of the event.
|
||||
ratelimit (bool): Whether to rate limit this send.
|
||||
is_guest (bool): Whether the sender is a guest.
|
||||
"""
|
||||
user = UserID.from_string(event.sender)
|
||||
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if ratelimit:
|
||||
self.ratelimit(event.sender)
|
||||
|
||||
if event.is_state():
|
||||
prev_state = context.current_state.get((event.type, event.state_key))
|
||||
|
@ -229,7 +243,7 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
if event.type == EventTypes.Member:
|
||||
member_handler = self.hs.get_handlers().room_member_handler
|
||||
yield member_handler.change_membership(event, context, is_guest=is_guest)
|
||||
yield member_handler.send_membership_event(event, context, is_guest=is_guest)
|
||||
else:
|
||||
yield self.handle_new_client_event(
|
||||
event=event,
|
||||
|
@ -241,6 +255,25 @@ class MessageHandler(BaseHandler):
|
|||
with PreserveLoggingContext():
|
||||
presence.bump_presence_active_time(user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_event(self, event_dict, ratelimit=True,
|
||||
token_id=None, txn_id=None, is_guest=False):
|
||||
"""
|
||||
Creates an event, then sends it.
|
||||
|
||||
See self.create_event and self.send_event.
|
||||
"""
|
||||
event, context = yield self.create_event(
|
||||
event_dict,
|
||||
token_id=token_id,
|
||||
txn_id=txn_id
|
||||
)
|
||||
yield self.send_event(
|
||||
event,
|
||||
context,
|
||||
ratelimit=ratelimit,
|
||||
is_guest=is_guest
|
||||
)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -53,7 +53,8 @@ class RegistrationHandler(BaseHandler):
|
|||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not"
|
||||
" require URL encoding."
|
||||
" require URL encoding.",
|
||||
Codes.INVALID_USERNAME
|
||||
)
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
|
|
|
@ -22,7 +22,7 @@ from synapse.types import UserID, RoomAlias, RoomID
|
|||
from synapse.api.constants import (
|
||||
EventTypes, Membership, JoinRules, RoomCreationPreset,
|
||||
)
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
|
||||
from synapse.util import stringutils, unwrapFirstError
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
|
@ -397,7 +397,58 @@ class RoomMemberHandler(BaseHandler):
|
|||
remotedomains.add(member.domain)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def change_membership(self, event, context, do_auth=True, is_guest=False):
|
||||
def update_membership(self, requester, target, room_id, action, txn_id=None):
|
||||
effective_membership_state = action
|
||||
if action in ["kick", "unban"]:
|
||||
effective_membership_state = "leave"
|
||||
elif action == "forget":
|
||||
effective_membership_state = "leave"
|
||||
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
|
||||
content = {"membership": unicode(effective_membership_state)}
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
event, context = yield msg_handler.create_event(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
"state_key": target.to_string(),
|
||||
},
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
old_state = context.current_state.get((EventTypes.Member, event.state_key))
|
||||
old_membership = old_state.content.get("membership") if old_state else None
|
||||
if action == "unban" and old_membership != "ban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot unban user who was not banned (membership=%s)" % old_membership,
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
if old_membership == "ban" and action != "unban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot %s user who was is banned" % (action,),
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
|
||||
yield msg_handler.send_event(
|
||||
event,
|
||||
context,
|
||||
ratelimit=True,
|
||||
is_guest=requester.is_guest
|
||||
)
|
||||
|
||||
if action == "forget":
|
||||
yield self.forget(requester.user, room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_membership_event(self, event, context, is_guest=False):
|
||||
""" Change the membership status of a user in a room.
|
||||
|
||||
Args:
|
||||
|
@ -432,7 +483,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
if not is_guest_access_allowed:
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
yield self._do_join(event, context, do_auth=do_auth)
|
||||
yield self._do_join(event, context)
|
||||
else:
|
||||
if event.membership == Membership.LEAVE:
|
||||
is_host_in_room = yield self.is_host_in_room(room_id, context)
|
||||
|
@ -459,9 +510,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
context=context,
|
||||
do_auth=do_auth,
|
||||
)
|
||||
|
||||
if prev_state and prev_state.membership == Membership.JOIN:
|
||||
|
@ -497,12 +546,12 @@ class RoomMemberHandler(BaseHandler):
|
|||
})
|
||||
event, context = yield self._create_new_client_event(builder)
|
||||
|
||||
yield self._do_join(event, context, room_hosts=hosts, do_auth=True)
|
||||
yield self._do_join(event, context, room_hosts=hosts)
|
||||
|
||||
defer.returnValue({"room_id": room_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_join(self, event, context, room_hosts=None, do_auth=True):
|
||||
def _do_join(self, event, context, room_hosts=None):
|
||||
room_id = event.room_id
|
||||
|
||||
# XXX: We don't do an auth check if we are doing an invite
|
||||
|
@ -536,9 +585,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
context=context,
|
||||
do_auth=do_auth,
|
||||
)
|
||||
|
||||
prev_state = context.current_state.get((event.type, event.state_key))
|
||||
|
@ -603,8 +650,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
defer.returnValue(room_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_local_membership_update(self, event, membership, context,
|
||||
do_auth):
|
||||
def _do_local_membership_update(self, event, context):
|
||||
yield run_on_reactor()
|
||||
|
||||
target_user = UserID.from_string(event.state_key)
|
||||
|
@ -613,7 +659,6 @@ class RoomMemberHandler(BaseHandler):
|
|||
event,
|
||||
context,
|
||||
extra_users=[target_user],
|
||||
suppress_auth=(not do_auth),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -880,28 +925,39 @@ class RoomContextHandler(BaseHandler):
|
|||
(excluding state).
|
||||
|
||||
Returns:
|
||||
dict
|
||||
dict, or None if the event isn't found
|
||||
"""
|
||||
before_limit = math.floor(limit/2.)
|
||||
after_limit = limit - before_limit
|
||||
|
||||
now_token = yield self.hs.get_event_sources().get_current_token()
|
||||
|
||||
def filter_evts(events):
|
||||
return self._filter_events_for_client(
|
||||
user.to_string(),
|
||||
events,
|
||||
is_guest=is_guest)
|
||||
|
||||
event = yield self.store.get_event(event_id, get_prev_content=True,
|
||||
allow_none=True)
|
||||
if not event:
|
||||
defer.returnValue(None)
|
||||
return
|
||||
|
||||
filtered = yield(filter_evts([event]))
|
||||
if not filtered:
|
||||
raise AuthError(
|
||||
403,
|
||||
"You don't have permission to access that event."
|
||||
)
|
||||
|
||||
results = yield self.store.get_events_around(
|
||||
room_id, event_id, before_limit, after_limit
|
||||
)
|
||||
|
||||
results["events_before"] = yield self._filter_events_for_client(
|
||||
user.to_string(),
|
||||
results["events_before"],
|
||||
is_guest=is_guest,
|
||||
)
|
||||
|
||||
results["events_after"] = yield self._filter_events_for_client(
|
||||
user.to_string(),
|
||||
results["events_after"],
|
||||
is_guest=is_guest,
|
||||
)
|
||||
results["events_before"] = yield filter_evts(results["events_before"])
|
||||
results["events_after"] = yield filter_evts(results["events_after"])
|
||||
results["event"] = event
|
||||
|
||||
if results["events_after"]:
|
||||
last_event_id = results["events_after"][-1].event_id
|
||||
|
|
|
@ -55,6 +55,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
|
|||
"ephemeral",
|
||||
"account_data",
|
||||
"unread_notification_count",
|
||||
"unread_highlight_count",
|
||||
])):
|
||||
__slots__ = []
|
||||
|
||||
|
@ -292,9 +293,14 @@ class SyncHandler(BaseHandler):
|
|||
notifs = yield self.unread_notifs_for_room_id(
|
||||
room_id, sync_config, ephemeral_by_room
|
||||
)
|
||||
|
||||
notif_count = None
|
||||
highlight_count = None
|
||||
if notifs is not None:
|
||||
notif_count = len(notifs)
|
||||
highlight_count = len([
|
||||
1 for notif in notifs if _action_has_highlight(notif["actions"])
|
||||
])
|
||||
|
||||
current_state = yield self.get_state_at(room_id, now_token)
|
||||
|
||||
|
@ -307,6 +313,7 @@ class SyncHandler(BaseHandler):
|
|||
room_id, tags_by_room, account_data_by_room
|
||||
),
|
||||
unread_notification_count=notif_count,
|
||||
unread_highlight_count=highlight_count,
|
||||
))
|
||||
|
||||
def account_data_for_user(self, account_data):
|
||||
|
@ -529,9 +536,14 @@ class SyncHandler(BaseHandler):
|
|||
notifs = yield self.unread_notifs_for_room_id(
|
||||
room_id, sync_config, all_ephemeral_by_room
|
||||
)
|
||||
|
||||
notif_count = None
|
||||
highlight_count = None
|
||||
if notifs is not None:
|
||||
notif_count = len(notifs)
|
||||
highlight_count = len([
|
||||
1 for notif in notifs if _action_has_highlight(notif["actions"])
|
||||
])
|
||||
|
||||
just_joined = yield self.check_joined_room(sync_config, state)
|
||||
if just_joined:
|
||||
|
@ -553,7 +565,8 @@ class SyncHandler(BaseHandler):
|
|||
account_data=self.account_data_for_room(
|
||||
room_id, tags_by_room, account_data_by_room
|
||||
),
|
||||
unread_notification_count=notif_count
|
||||
unread_notification_count=notif_count,
|
||||
unread_highlight_count=highlight_count,
|
||||
)
|
||||
logger.debug("Result for room %s: %r", room_id, room_sync)
|
||||
|
||||
|
@ -575,7 +588,8 @@ class SyncHandler(BaseHandler):
|
|||
for room_id in joined_room_ids:
|
||||
room_sync = yield self.incremental_sync_with_gap_for_room(
|
||||
room_id, sync_config, since_token, now_token,
|
||||
ephemeral_by_room, tags_by_room, account_data_by_room
|
||||
ephemeral_by_room, tags_by_room, account_data_by_room,
|
||||
all_ephemeral_by_room=all_ephemeral_by_room,
|
||||
)
|
||||
if room_sync:
|
||||
joined.append(room_sync)
|
||||
|
@ -655,7 +669,8 @@ class SyncHandler(BaseHandler):
|
|||
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
|
||||
since_token, now_token,
|
||||
ephemeral_by_room, tags_by_room,
|
||||
account_data_by_room):
|
||||
account_data_by_room,
|
||||
all_ephemeral_by_room):
|
||||
""" Get the incremental delta needed to bring the client up to date for
|
||||
the room. Gives the client the most recent events and the changes to
|
||||
state.
|
||||
|
@ -671,7 +686,7 @@ class SyncHandler(BaseHandler):
|
|||
room_id, sync_config, now_token, since_token,
|
||||
)
|
||||
|
||||
logging.debug("Recents %r", batch)
|
||||
logger.debug("Recents %r", batch)
|
||||
|
||||
current_state = yield self.get_state_at(room_id, now_token)
|
||||
|
||||
|
@ -690,11 +705,16 @@ class SyncHandler(BaseHandler):
|
|||
state = yield self.get_state_at(room_id, now_token)
|
||||
|
||||
notifs = yield self.unread_notifs_for_room_id(
|
||||
room_id, sync_config, ephemeral_by_room
|
||||
room_id, sync_config, all_ephemeral_by_room
|
||||
)
|
||||
|
||||
notif_count = None
|
||||
highlight_count = None
|
||||
if notifs is not None:
|
||||
notif_count = len(notifs)
|
||||
highlight_count = len([
|
||||
1 for notif in notifs if _action_has_highlight(notif["actions"])
|
||||
])
|
||||
|
||||
room_sync = JoinedSyncResult(
|
||||
room_id=room_id,
|
||||
|
@ -705,6 +725,7 @@ class SyncHandler(BaseHandler):
|
|||
room_id, tags_by_room, account_data_by_room
|
||||
),
|
||||
unread_notification_count=notif_count,
|
||||
unread_highlight_count=highlight_count,
|
||||
)
|
||||
|
||||
logger.debug("Room sync: %r", room_sync)
|
||||
|
@ -734,7 +755,7 @@ class SyncHandler(BaseHandler):
|
|||
leave_event.room_id, sync_config, leave_token, since_token,
|
||||
)
|
||||
|
||||
logging.debug("Recents %r", batch)
|
||||
logger.debug("Recents %r", batch)
|
||||
|
||||
state_events_at_leave = yield self.store.get_state_for_event(
|
||||
leave_event.event_id
|
||||
|
@ -850,8 +871,19 @@ class SyncHandler(BaseHandler):
|
|||
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
|
||||
room_id, sync_config.user.to_string(), last_unread_event_id
|
||||
)
|
||||
else:
|
||||
defer.returnValue(notifs)
|
||||
|
||||
# There is no new information in this period, so your notification
|
||||
# count is whatever it was last time.
|
||||
defer.returnValue(None)
|
||||
defer.returnValue(notifs)
|
||||
|
||||
|
||||
def _action_has_highlight(actions):
|
||||
for action in actions:
|
||||
try:
|
||||
if action.get("set_tweak", None) == "highlight":
|
||||
return action.get("value", True)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
|
|
@ -37,7 +37,7 @@ class Pusher(object):
|
|||
MAX_BACKOFF = 60 * 60 * 1000
|
||||
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, _hs, profile_tag, user_name, app_id,
|
||||
def __init__(self, _hs, profile_tag, user_id, app_id,
|
||||
app_display_name, device_display_name, pushkey, pushkey_ts,
|
||||
data, last_token, last_success, failing_since):
|
||||
self.hs = _hs
|
||||
|
@ -45,7 +45,7 @@ class Pusher(object):
|
|||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.profile_tag = profile_tag
|
||||
self.user_name = user_name
|
||||
self.user_id = user_id
|
||||
self.app_id = app_id
|
||||
self.app_display_name = app_display_name
|
||||
self.device_display_name = device_display_name
|
||||
|
@ -95,14 +95,14 @@ class Pusher(object):
|
|||
# we fail to dispatch the push)
|
||||
config = PaginationConfig(from_token=None, limit='1')
|
||||
chunk = yield self.evStreamHandler.get_stream(
|
||||
self.user_name, config, timeout=0, affect_presence=False
|
||||
self.user_id, config, timeout=0, affect_presence=False
|
||||
)
|
||||
self.last_token = chunk['end']
|
||||
self.store.update_pusher_last_token(
|
||||
self.app_id, self.pushkey, self.user_name, self.last_token
|
||||
self.app_id, self.pushkey, self.user_id, self.last_token
|
||||
)
|
||||
logger.info("Pusher %s for user %s starting from token %s",
|
||||
self.pushkey, self.user_name, self.last_token)
|
||||
self.pushkey, self.user_id, self.last_token)
|
||||
|
||||
wait = 0
|
||||
while self.alive:
|
||||
|
@ -127,7 +127,7 @@ class Pusher(object):
|
|||
config = PaginationConfig(from_token=from_tok, limit='1')
|
||||
timeout = (300 + random.randint(-60, 60)) * 1000
|
||||
chunk = yield self.evStreamHandler.get_stream(
|
||||
self.user_name, config, timeout=timeout, affect_presence=False
|
||||
self.user_id, config, timeout=timeout, affect_presence=False
|
||||
)
|
||||
|
||||
# limiting to 1 may get 1 event plus 1 presence event, so
|
||||
|
@ -144,7 +144,7 @@ class Pusher(object):
|
|||
if read_receipt:
|
||||
for receipt_part in read_receipt['content'].values():
|
||||
if 'm.read' in receipt_part:
|
||||
if self.user_name in receipt_part['m.read'].keys():
|
||||
if self.user_id in receipt_part['m.read'].keys():
|
||||
have_updated_badge = True
|
||||
|
||||
if not single_event:
|
||||
|
@ -154,7 +154,7 @@ class Pusher(object):
|
|||
yield self.store.update_pusher_last_token(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.user_id,
|
||||
self.last_token
|
||||
)
|
||||
return
|
||||
|
@ -165,8 +165,8 @@ class Pusher(object):
|
|||
processed = False
|
||||
|
||||
rule_evaluator = yield \
|
||||
push_rule_evaluator.evaluator_for_user_name_and_profile_tag(
|
||||
self.user_name, self.profile_tag, single_event['room_id'], self.store
|
||||
push_rule_evaluator.evaluator_for_user_id_and_profile_tag(
|
||||
self.user_id, self.profile_tag, single_event['room_id'], self.store
|
||||
)
|
||||
|
||||
actions = yield rule_evaluator.actions_for_event(single_event)
|
||||
|
@ -192,7 +192,7 @@ class Pusher(object):
|
|||
pk
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pusher(
|
||||
self.app_id, pk, self.user_name
|
||||
self.app_id, pk, self.user_id
|
||||
)
|
||||
else:
|
||||
if have_updated_badge:
|
||||
|
@ -208,7 +208,7 @@ class Pusher(object):
|
|||
yield self.store.update_pusher_last_token_and_success(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.user_id,
|
||||
self.last_token,
|
||||
self.clock.time_msec()
|
||||
)
|
||||
|
@ -217,7 +217,7 @@ class Pusher(object):
|
|||
yield self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.user_id,
|
||||
self.failing_since)
|
||||
else:
|
||||
if not self.failing_since:
|
||||
|
@ -225,7 +225,7 @@ class Pusher(object):
|
|||
yield self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.user_id,
|
||||
self.failing_since
|
||||
)
|
||||
|
||||
|
@ -237,13 +237,13 @@ class Pusher(object):
|
|||
# of old notifications.
|
||||
logger.warn("Giving up on a notification to user %s, "
|
||||
"pushkey %s",
|
||||
self.user_name, self.pushkey)
|
||||
self.user_id, self.pushkey)
|
||||
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
||||
self.last_token = chunk['end']
|
||||
yield self.store.update_pusher_last_token(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.user_id,
|
||||
self.last_token
|
||||
)
|
||||
|
||||
|
@ -251,14 +251,14 @@ class Pusher(object):
|
|||
yield self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.user_id,
|
||||
self.failing_since
|
||||
)
|
||||
else:
|
||||
logger.warn("Failed to dispatch push for user %s "
|
||||
"(failing for %dms)."
|
||||
"Trying again in %dms",
|
||||
self.user_name,
|
||||
self.user_id,
|
||||
self.clock.time_msec() - self.failing_since,
|
||||
self.backoff_delay)
|
||||
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
|
||||
|
@ -299,11 +299,11 @@ class Pusher(object):
|
|||
membership_list = (Membership.INVITE, Membership.JOIN)
|
||||
|
||||
room_list = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
user_id=self.user_name,
|
||||
user_id=self.user_id,
|
||||
membership_list=membership_list
|
||||
)
|
||||
|
||||
user_is_guest = yield self.store.is_guest(UserID.from_string(self.user_name))
|
||||
user_is_guest = yield self.store.is_guest(self.user_id)
|
||||
|
||||
# XXX: importing inside method to break circular dependency.
|
||||
# should sort out the mess by moving all this logic out of
|
||||
|
@ -311,7 +311,7 @@ class Pusher(object):
|
|||
# handler to somewhere more amenable to re-use.
|
||||
from synapse.handlers.sync import SyncConfig
|
||||
sync_config = SyncConfig(
|
||||
user=UserID.from_string(self.user_name),
|
||||
user=UserID.from_string(self.user_id),
|
||||
filter=FilterCollection({}),
|
||||
is_guest=user_is_guest,
|
||||
)
|
||||
|
@ -328,13 +328,13 @@ class Pusher(object):
|
|||
badge += 1
|
||||
else:
|
||||
last_unread_event_id = sync_handler.last_read_event_id_for_room_and_user(
|
||||
r.room_id, self.user_name, ephemeral_by_room
|
||||
r.room_id, self.user_id, ephemeral_by_room
|
||||
)
|
||||
|
||||
if last_unread_event_id:
|
||||
notifs = yield (
|
||||
self.store.get_unread_event_push_actions_by_room_for_user(
|
||||
r.room_id, self.user_name, last_unread_event_id
|
||||
r.room_id, self.user_id, last_unread_event_id
|
||||
)
|
||||
)
|
||||
badge += len(notifs)
|
||||
|
|
|
@ -25,8 +25,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ActionGenerator:
|
||||
def __init__(self, store):
|
||||
self.store = store
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
# really we want to get all user ids and all profile tags too,
|
||||
# since we want the actions for each profile tag for every user and
|
||||
# also actions for a client with no profile tag for each user.
|
||||
|
@ -42,7 +43,7 @@ class ActionGenerator:
|
|||
)
|
||||
|
||||
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
|
||||
event.room_id, self.store
|
||||
event.room_id, self.hs, self.store
|
||||
)
|
||||
|
||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
|
||||
|
|
|
@ -15,27 +15,25 @@
|
|||
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
||||
|
||||
|
||||
def list_with_base_rules(rawrules, user_name):
|
||||
def list_with_base_rules(rawrules):
|
||||
ruleslist = []
|
||||
|
||||
# shove the server default rules for each kind onto the end of each
|
||||
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
|
||||
|
||||
ruleslist.extend(make_base_prepend_rules(
|
||||
user_name, PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||
))
|
||||
|
||||
for r in rawrules:
|
||||
if r['priority_class'] < current_prio_class:
|
||||
while r['priority_class'] < current_prio_class:
|
||||
ruleslist.extend(make_base_append_rules(
|
||||
user_name,
|
||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||
))
|
||||
current_prio_class -= 1
|
||||
if current_prio_class > 0:
|
||||
ruleslist.extend(make_base_prepend_rules(
|
||||
user_name,
|
||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||
))
|
||||
|
||||
|
@ -43,58 +41,47 @@ def list_with_base_rules(rawrules, user_name):
|
|||
|
||||
while current_prio_class > 0:
|
||||
ruleslist.extend(make_base_append_rules(
|
||||
user_name,
|
||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||
))
|
||||
current_prio_class -= 1
|
||||
if current_prio_class > 0:
|
||||
ruleslist.extend(make_base_prepend_rules(
|
||||
user_name,
|
||||
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
|
||||
))
|
||||
|
||||
return ruleslist
|
||||
|
||||
|
||||
def make_base_append_rules(user, kind):
|
||||
def make_base_append_rules(kind):
|
||||
rules = []
|
||||
|
||||
if kind == 'override':
|
||||
rules = make_base_append_override_rules()
|
||||
rules = BASE_APPEND_OVRRIDE_RULES
|
||||
elif kind == 'underride':
|
||||
rules = make_base_append_underride_rules(user)
|
||||
rules = BASE_APPEND_UNDERRIDE_RULES
|
||||
elif kind == 'content':
|
||||
rules = make_base_append_content_rules(user)
|
||||
|
||||
for r in rules:
|
||||
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
|
||||
r['default'] = True # Deprecated, left for backwards compat
|
||||
rules = BASE_APPEND_CONTENT_RULES
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
def make_base_prepend_rules(user, kind):
|
||||
def make_base_prepend_rules(kind):
|
||||
rules = []
|
||||
|
||||
if kind == 'override':
|
||||
rules = make_base_prepend_override_rules()
|
||||
|
||||
for r in rules:
|
||||
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
|
||||
r['default'] = True # Deprecated, left for backwards compat
|
||||
rules = BASE_PREPEND_OVERRIDE_RULES
|
||||
|
||||
return rules
|
||||
|
||||
|
||||
def make_base_append_content_rules(user):
|
||||
return [
|
||||
BASE_APPEND_CONTENT_RULES = [
|
||||
{
|
||||
'rule_id': 'global/content/.m.rule.contains_user_name',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.body',
|
||||
'pattern': user.localpart, # Matrix ID match
|
||||
'pattern_type': 'user_localpart'
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
|
@ -110,8 +97,7 @@ def make_base_append_content_rules(user):
|
|||
]
|
||||
|
||||
|
||||
def make_base_prepend_override_rules():
|
||||
return [
|
||||
BASE_PREPEND_OVERRIDE_RULES = [
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.master',
|
||||
'enabled': False,
|
||||
|
@ -123,8 +109,7 @@ def make_base_prepend_override_rules():
|
|||
]
|
||||
|
||||
|
||||
def make_base_append_override_rules():
|
||||
return [
|
||||
BASE_APPEND_OVRRIDE_RULES = [
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.suppress_notices',
|
||||
'conditions': [
|
||||
|
@ -132,6 +117,7 @@ def make_base_append_override_rules():
|
|||
'kind': 'event_match',
|
||||
'key': 'content.msgtype',
|
||||
'pattern': 'm.notice',
|
||||
'_id': '_suppress_notices',
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
|
@ -141,8 +127,7 @@ def make_base_append_override_rules():
|
|||
]
|
||||
|
||||
|
||||
def make_base_append_underride_rules(user):
|
||||
return [
|
||||
BASE_APPEND_UNDERRIDE_RULES = [
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.call',
|
||||
'conditions': [
|
||||
|
@ -150,6 +135,7 @@ def make_base_append_underride_rules(user):
|
|||
'kind': 'event_match',
|
||||
'key': 'type',
|
||||
'pattern': 'm.call.invite',
|
||||
'_id': '_call',
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
|
@ -185,7 +171,8 @@ def make_base_append_underride_rules(user):
|
|||
'conditions': [
|
||||
{
|
||||
'kind': 'room_member_count',
|
||||
'is': '2'
|
||||
'is': '2',
|
||||
'_id': 'member_count',
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
|
@ -206,16 +193,18 @@ def make_base_append_underride_rules(user):
|
|||
'kind': 'event_match',
|
||||
'key': 'type',
|
||||
'pattern': 'm.room.member',
|
||||
'_id': '_member',
|
||||
},
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.membership',
|
||||
'pattern': 'invite',
|
||||
'_id': '_invite_member',
|
||||
},
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'state_key',
|
||||
'pattern': user.to_string(),
|
||||
'pattern_type': 'user_id'
|
||||
},
|
||||
],
|
||||
'actions': [
|
||||
|
@ -236,6 +225,7 @@ def make_base_append_underride_rules(user):
|
|||
'kind': 'event_match',
|
||||
'key': 'type',
|
||||
'pattern': 'm.room.member',
|
||||
'_id': '_member',
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
|
@ -247,12 +237,12 @@ def make_base_append_underride_rules(user):
|
|||
},
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.message',
|
||||
'enabled': False,
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'type',
|
||||
'pattern': 'm.room.message',
|
||||
'_id': '_message',
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
|
@ -263,3 +253,20 @@ def make_base_append_underride_rules(user):
|
|||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
for r in BASE_APPEND_CONTENT_RULES:
|
||||
r['priority_class'] = PRIORITY_CLASS_MAP['content']
|
||||
r['default'] = True
|
||||
|
||||
for r in BASE_PREPEND_OVERRIDE_RULES:
|
||||
r['priority_class'] = PRIORITY_CLASS_MAP['override']
|
||||
r['default'] = True
|
||||
|
||||
for r in BASE_APPEND_OVRRIDE_RULES:
|
||||
r['priority_class'] = PRIORITY_CLASS_MAP['override']
|
||||
r['default'] = True
|
||||
|
||||
for r in BASE_APPEND_UNDERRIDE_RULES:
|
||||
r['priority_class'] = PRIORITY_CLASS_MAP['underride']
|
||||
r['default'] = True
|
||||
|
|
|
@ -14,16 +14,15 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import simplejson as json
|
||||
import ujson as json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
|
||||
import baserules
|
||||
from push_rule_evaluator import PushRuleEvaluator
|
||||
from push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
from synapse.events.utils import serialize_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -35,28 +34,30 @@ def decode_rule_json(rule):
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def evaluator_for_room_id(room_id, store):
|
||||
users = yield store.get_users_in_room(room_id)
|
||||
rules_by_user = yield store.bulk_get_push_rules(users)
|
||||
def _get_rules(room_id, user_ids, store):
|
||||
rules_by_user = yield store.bulk_get_push_rules(user_ids)
|
||||
|
||||
rules_by_user = {
|
||||
uid: baserules.list_with_base_rules(
|
||||
[decode_rule_json(rule_list) for rule_list in rules_by_user[uid]]
|
||||
if uid in rules_by_user else [],
|
||||
UserID.from_string(uid),
|
||||
)
|
||||
for uid in users
|
||||
uid: baserules.list_with_base_rules([
|
||||
decode_rule_json(rule_list)
|
||||
for rule_list in rules_by_user.get(uid, [])
|
||||
])
|
||||
for uid in user_ids
|
||||
}
|
||||
member_events = yield store.get_current_state(
|
||||
room_id=room_id,
|
||||
event_type='m.room.member',
|
||||
)
|
||||
display_names = {}
|
||||
for ev in member_events:
|
||||
if ev.content.get("displayname"):
|
||||
display_names[ev.state_key] = ev.content.get("displayname")
|
||||
defer.returnValue(rules_by_user)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def evaluator_for_room_id(room_id, hs, store):
|
||||
results = yield store.get_receipts_for_room(room_id, "m.read")
|
||||
user_ids = [
|
||||
row["user_id"] for row in results
|
||||
if hs.is_mine_id(row["user_id"])
|
||||
]
|
||||
rules_by_user = yield _get_rules(room_id, user_ids, store)
|
||||
|
||||
defer.returnValue(BulkPushRuleEvaluator(
|
||||
room_id, rules_by_user, display_names, users, store
|
||||
room_id, rules_by_user, user_ids, store
|
||||
))
|
||||
|
||||
|
||||
|
@ -69,10 +70,9 @@ class BulkPushRuleEvaluator:
|
|||
the same logic to run the actual rules, but could be optimised further
|
||||
(see https://matrix.org/jira/browse/SYN-562)
|
||||
"""
|
||||
def __init__(self, room_id, rules_by_user, display_names, users_in_room, store):
|
||||
def __init__(self, room_id, rules_by_user, users_in_room, store):
|
||||
self.room_id = room_id
|
||||
self.rules_by_user = rules_by_user
|
||||
self.display_names = display_names
|
||||
self.users_in_room = users_in_room
|
||||
self.store = store
|
||||
|
||||
|
@ -80,15 +80,30 @@ class BulkPushRuleEvaluator:
|
|||
def action_for_event_by_user(self, event, handler):
|
||||
actions_by_user = {}
|
||||
|
||||
for uid, rules in self.rules_by_user.items():
|
||||
display_name = None
|
||||
if uid in self.display_names:
|
||||
display_name = self.display_names[uid]
|
||||
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
|
||||
|
||||
is_guest = yield self.store.is_guest(UserID.from_string(uid))
|
||||
filtered = yield handler._filter_events_for_client(
|
||||
uid, [event], is_guest=is_guest
|
||||
filtered_by_user = yield handler._filter_events_for_clients(
|
||||
users_dict.items(), [event]
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
|
||||
|
||||
condition_cache = {}
|
||||
|
||||
member_state = yield self.store.get_state_for_event(
|
||||
event.event_id,
|
||||
)
|
||||
|
||||
display_names = {}
|
||||
for ev in member_state.values():
|
||||
nm = ev.content.get("displayname", None)
|
||||
if nm and ev.type == EventTypes.Member:
|
||||
display_names[ev.state_key] = nm
|
||||
|
||||
for uid, rules in self.rules_by_user.items():
|
||||
display_name = display_names.get(uid, None)
|
||||
|
||||
filtered = filtered_by_user[uid]
|
||||
if len(filtered) == 0:
|
||||
continue
|
||||
|
||||
|
@ -96,29 +111,32 @@ class BulkPushRuleEvaluator:
|
|||
if 'enabled' in rule and not rule['enabled']:
|
||||
continue
|
||||
|
||||
# XXX: profile tags
|
||||
if BulkPushRuleEvaluator.event_matches_rule(
|
||||
event, rule,
|
||||
display_name, len(self.users_in_room), None
|
||||
):
|
||||
matches = _condition_checker(
|
||||
evaluator, rule['conditions'], uid, display_name, condition_cache
|
||||
)
|
||||
if matches:
|
||||
actions = [x for x in rule['actions'] if x != 'dont_notify']
|
||||
if len(actions) > 0:
|
||||
if actions:
|
||||
actions_by_user[uid] = actions
|
||||
break
|
||||
defer.returnValue(actions_by_user)
|
||||
|
||||
@staticmethod
|
||||
def event_matches_rule(event, rule,
|
||||
display_name, room_member_count, profile_tag):
|
||||
matches = True
|
||||
|
||||
# passing the clock all the way into here is extremely awkward and push
|
||||
# rules do not care about any of the relative timestamps, so we just
|
||||
# pass 0 for the current time.
|
||||
client_event = serialize_event(event, 0)
|
||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||
for cond in conditions:
|
||||
_id = cond.get("_id", None)
|
||||
if _id:
|
||||
res = cache.get(_id, None)
|
||||
if res is False:
|
||||
return False
|
||||
elif res is True:
|
||||
continue
|
||||
|
||||
for cond in rule['conditions']:
|
||||
matches &= PushRuleEvaluator._event_fulfills_condition(
|
||||
client_event, cond, display_name, room_member_count, profile_tag
|
||||
)
|
||||
return matches
|
||||
res = evaluator.matches(cond, uid, display_name, None)
|
||||
if _id:
|
||||
cache[_id] = bool(res)
|
||||
|
||||
if not res:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
|
@ -23,13 +23,13 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class HttpPusher(Pusher):
|
||||
def __init__(self, _hs, profile_tag, user_name, app_id,
|
||||
def __init__(self, _hs, profile_tag, user_id, app_id,
|
||||
app_display_name, device_display_name, pushkey, pushkey_ts,
|
||||
data, last_token, last_success, failing_since):
|
||||
super(HttpPusher, self).__init__(
|
||||
_hs,
|
||||
profile_tag,
|
||||
user_name,
|
||||
user_id,
|
||||
app_id,
|
||||
app_display_name,
|
||||
device_display_name,
|
||||
|
@ -87,7 +87,7 @@ class HttpPusher(Pusher):
|
|||
}
|
||||
if event['type'] == 'm.room.member':
|
||||
d['notification']['membership'] = event['content']['membership']
|
||||
d['notification']['user_is_target'] = event['state_key'] == self.user_name
|
||||
d['notification']['user_is_target'] = event['state_key'] == self.user_id
|
||||
if 'content' in event:
|
||||
d['notification']['content'] = event['content']
|
||||
|
||||
|
@ -117,7 +117,7 @@ class HttpPusher(Pusher):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def send_badge(self, badge):
|
||||
logger.info("Sending updated badge count %d to %r", badge, self.user_name)
|
||||
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
|
||||
d = {
|
||||
'notification': {
|
||||
'id': '',
|
||||
|
|
|
@ -15,188 +15,43 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
|
||||
import baserules
|
||||
|
||||
import logging
|
||||
import simplejson as json
|
||||
import re
|
||||
|
||||
from synapse.types import UserID
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
|
||||
IS_GLOB = re.compile(r'[\?\*\[\]]')
|
||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def evaluator_for_user_name_and_profile_tag(user_name, profile_tag, room_id, store):
|
||||
rawrules = yield store.get_push_rules_for_user(user_name)
|
||||
enabled_map = yield store.get_push_rules_enabled_for_user(user_name)
|
||||
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
|
||||
rawrules = yield store.get_push_rules_for_user(user_id)
|
||||
enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
|
||||
our_member_event = yield store.get_current_state(
|
||||
room_id=room_id,
|
||||
event_type='m.room.member',
|
||||
state_key=user_name,
|
||||
state_key=user_id,
|
||||
)
|
||||
|
||||
defer.returnValue(PushRuleEvaluator(
|
||||
user_name, profile_tag, rawrules, enabled_map,
|
||||
user_id, profile_tag, rawrules, enabled_map,
|
||||
room_id, our_member_event, store
|
||||
))
|
||||
|
||||
|
||||
class PushRuleEvaluator:
|
||||
DEFAULT_ACTIONS = []
|
||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||
|
||||
def __init__(self, user_name, profile_tag, raw_rules, enabled_map, room_id,
|
||||
our_member_event, store):
|
||||
self.user_name = user_name
|
||||
self.profile_tag = profile_tag
|
||||
self.room_id = room_id
|
||||
self.our_member_event = our_member_event
|
||||
self.store = store
|
||||
|
||||
rules = []
|
||||
for raw_rule in raw_rules:
|
||||
rule = dict(raw_rule)
|
||||
rule['conditions'] = json.loads(raw_rule['conditions'])
|
||||
rule['actions'] = json.loads(raw_rule['actions'])
|
||||
rules.append(rule)
|
||||
|
||||
user = UserID.from_string(self.user_name)
|
||||
self.rules = baserules.list_with_base_rules(rules, user)
|
||||
|
||||
self.enabled_map = enabled_map
|
||||
|
||||
@staticmethod
|
||||
def tweaks_for_actions(actions):
|
||||
tweaks = {}
|
||||
for a in actions:
|
||||
if not isinstance(a, dict):
|
||||
continue
|
||||
if 'set_tweak' in a and 'value' in a:
|
||||
tweaks[a['set_tweak']] = a['value']
|
||||
return tweaks
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def actions_for_event(self, ev):
|
||||
"""
|
||||
This should take into account notification settings that the user
|
||||
has configured both globally and per-room when we have the ability
|
||||
to do such things.
|
||||
"""
|
||||
if ev['user_id'] == self.user_name:
|
||||
# let's assume you probably know about messages you sent yourself
|
||||
defer.returnValue([])
|
||||
|
||||
room_id = ev['room_id']
|
||||
|
||||
# get *our* member event for display name matching
|
||||
my_display_name = None
|
||||
|
||||
if self.our_member_event:
|
||||
my_display_name = self.our_member_event[0].content.get("displayname")
|
||||
|
||||
room_members = yield self.store.get_users_in_room(room_id)
|
||||
room_member_count = len(room_members)
|
||||
|
||||
for r in self.rules:
|
||||
if r['rule_id'] in self.enabled_map:
|
||||
r['enabled'] = self.enabled_map[r['rule_id']]
|
||||
elif 'enabled' not in r:
|
||||
r['enabled'] = True
|
||||
if not r['enabled']:
|
||||
continue
|
||||
matches = True
|
||||
|
||||
conditions = r['conditions']
|
||||
actions = r['actions']
|
||||
|
||||
for c in conditions:
|
||||
matches &= self._event_fulfills_condition(
|
||||
ev, c, display_name=my_display_name,
|
||||
room_member_count=room_member_count,
|
||||
profile_tag=self.profile_tag
|
||||
)
|
||||
logger.debug(
|
||||
"Rule %s %s",
|
||||
r['rule_id'], "matches" if matches else "doesn't match"
|
||||
)
|
||||
# ignore rules with no actions (we have an explict 'dont_notify')
|
||||
if len(actions) == 0:
|
||||
logger.warn(
|
||||
"Ignoring rule id %s with no actions for user %s",
|
||||
r['rule_id'], self.user_name
|
||||
)
|
||||
continue
|
||||
if matches:
|
||||
logger.info(
|
||||
"%s matches for user %s, event %s",
|
||||
r['rule_id'], self.user_name, ev['event_id']
|
||||
)
|
||||
|
||||
# filter out dont_notify as we treat an empty actions list
|
||||
# as dont_notify, and this doesn't take up a row in our database
|
||||
actions = [x for x in actions if x != 'dont_notify']
|
||||
|
||||
defer.returnValue(actions)
|
||||
|
||||
logger.info(
|
||||
"No rules match for user %s, event %s",
|
||||
self.user_name, ev['event_id']
|
||||
)
|
||||
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
|
||||
|
||||
@staticmethod
|
||||
def _glob_to_regexp(glob):
|
||||
r = re.escape(glob)
|
||||
r = re.sub(r'\\\*', r'.*?', r)
|
||||
r = re.sub(r'\\\?', r'.', r)
|
||||
|
||||
# handle [abc], [a-z] and [!a-z] style ranges.
|
||||
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
|
||||
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
|
||||
re.sub(r'\\\-', '-', x.group(2)))), r)
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def _event_fulfills_condition(ev, condition,
|
||||
display_name, room_member_count, profile_tag):
|
||||
if condition['kind'] == 'event_match':
|
||||
if 'pattern' not in condition:
|
||||
logger.warn("event_match condition with no pattern")
|
||||
return False
|
||||
# XXX: optimisation: cache our pattern regexps
|
||||
if condition['key'] == 'content.body':
|
||||
r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
|
||||
else:
|
||||
r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
|
||||
val = _value_for_dotted_key(condition['key'], ev)
|
||||
if val is None:
|
||||
return False
|
||||
return re.search(r, val, flags=re.IGNORECASE) is not None
|
||||
|
||||
elif condition['kind'] == 'device':
|
||||
if 'profile_tag' not in condition:
|
||||
return True
|
||||
return condition['profile_tag'] == profile_tag
|
||||
|
||||
elif condition['kind'] == 'contains_display_name':
|
||||
# This is special because display names can be different
|
||||
# between rooms and so you can't really hard code it in a rule.
|
||||
# Optimisation: we should cache these names and update them from
|
||||
# the event stream.
|
||||
if 'content' not in ev or 'body' not in ev['content']:
|
||||
return False
|
||||
if not display_name:
|
||||
return False
|
||||
return re.search(
|
||||
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
|
||||
flags=re.IGNORECASE
|
||||
) is not None
|
||||
|
||||
elif condition['kind'] == 'room_member_count':
|
||||
def _room_member_count(ev, condition, room_member_count):
|
||||
if 'is' not in condition:
|
||||
return False
|
||||
m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is'])
|
||||
m = INEQUALITY_EXPR.match(condition['is'])
|
||||
if not m:
|
||||
return False
|
||||
ineq = m.group(1)
|
||||
|
@ -217,16 +72,251 @@ class PushRuleEvaluator:
|
|||
return room_member_count <= rhs
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class PushRuleEvaluator:
|
||||
DEFAULT_ACTIONS = []
|
||||
|
||||
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
|
||||
our_member_event, store):
|
||||
self.user_id = user_id
|
||||
self.profile_tag = profile_tag
|
||||
self.room_id = room_id
|
||||
self.our_member_event = our_member_event
|
||||
self.store = store
|
||||
|
||||
rules = []
|
||||
for raw_rule in raw_rules:
|
||||
rule = dict(raw_rule)
|
||||
rule['conditions'] = json.loads(raw_rule['conditions'])
|
||||
rule['actions'] = json.loads(raw_rule['actions'])
|
||||
rules.append(rule)
|
||||
|
||||
self.rules = baserules.list_with_base_rules(rules)
|
||||
|
||||
self.enabled_map = enabled_map
|
||||
|
||||
@staticmethod
|
||||
def tweaks_for_actions(actions):
|
||||
tweaks = {}
|
||||
for a in actions:
|
||||
if not isinstance(a, dict):
|
||||
continue
|
||||
if 'set_tweak' in a and 'value' in a:
|
||||
tweaks[a['set_tweak']] = a['value']
|
||||
return tweaks
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def actions_for_event(self, ev):
|
||||
"""
|
||||
This should take into account notification settings that the user
|
||||
has configured both globally and per-room when we have the ability
|
||||
to do such things.
|
||||
"""
|
||||
if ev['user_id'] == self.user_id:
|
||||
# let's assume you probably know about messages you sent yourself
|
||||
defer.returnValue([])
|
||||
|
||||
room_id = ev['room_id']
|
||||
|
||||
# get *our* member event for display name matching
|
||||
my_display_name = None
|
||||
|
||||
if self.our_member_event:
|
||||
my_display_name = self.our_member_event[0].content.get("displayname")
|
||||
|
||||
room_members = yield self.store.get_users_in_room(room_id)
|
||||
room_member_count = len(room_members)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
|
||||
|
||||
for r in self.rules:
|
||||
enabled = self.enabled_map.get(r['rule_id'], None)
|
||||
if enabled is not None and not enabled:
|
||||
continue
|
||||
|
||||
if not r.get("enabled", True):
|
||||
continue
|
||||
|
||||
conditions = r['conditions']
|
||||
actions = r['actions']
|
||||
|
||||
# ignore rules with no actions (we have an explict 'dont_notify')
|
||||
if len(actions) == 0:
|
||||
logger.warn(
|
||||
"Ignoring rule id %s with no actions for user %s",
|
||||
r['rule_id'], self.user_id
|
||||
)
|
||||
continue
|
||||
|
||||
matches = True
|
||||
for c in conditions:
|
||||
matches = evaluator.matches(
|
||||
c, self.user_id, my_display_name, self.profile_tag
|
||||
)
|
||||
if not matches:
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
"Rule %s %s",
|
||||
r['rule_id'], "matches" if matches else "doesn't match"
|
||||
)
|
||||
|
||||
if matches:
|
||||
logger.debug(
|
||||
"%s matches for user %s, event %s",
|
||||
r['rule_id'], self.user_id, ev['event_id']
|
||||
)
|
||||
|
||||
# filter out dont_notify as we treat an empty actions list
|
||||
# as dont_notify, and this doesn't take up a row in our database
|
||||
actions = [x for x in actions if x != 'dont_notify']
|
||||
|
||||
defer.returnValue(actions)
|
||||
|
||||
logger.debug(
|
||||
"No rules match for user %s, event %s",
|
||||
self.user_id, ev['event_id']
|
||||
)
|
||||
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
|
||||
|
||||
|
||||
class PushRuleEvaluatorForEvent(object):
|
||||
def __init__(self, event, room_member_count):
|
||||
self._event = event
|
||||
self._room_member_count = room_member_count
|
||||
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
|
||||
def matches(self, condition, user_id, display_name, profile_tag):
|
||||
if condition['kind'] == 'event_match':
|
||||
return self._event_match(condition, user_id)
|
||||
elif condition['kind'] == 'device':
|
||||
if 'profile_tag' not in condition:
|
||||
return True
|
||||
return condition['profile_tag'] == profile_tag
|
||||
elif condition['kind'] == 'contains_display_name':
|
||||
return self._contains_display_name(display_name)
|
||||
elif condition['kind'] == 'room_member_count':
|
||||
return _room_member_count(
|
||||
self._event, condition, self._room_member_count
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
def _event_match(self, condition, user_id):
|
||||
pattern = condition.get('pattern', None)
|
||||
|
||||
def _value_for_dotted_key(dotted_key, event):
|
||||
parts = dotted_key.split(".")
|
||||
val = event
|
||||
while len(parts) > 0:
|
||||
if parts[0] not in val:
|
||||
return None
|
||||
val = val[parts[0]]
|
||||
parts = parts[1:]
|
||||
return val
|
||||
if not pattern:
|
||||
pattern_type = condition.get('pattern_type', None)
|
||||
if pattern_type == "user_id":
|
||||
pattern = user_id
|
||||
elif pattern_type == "user_localpart":
|
||||
pattern = UserID.from_string(user_id).localpart
|
||||
|
||||
if not pattern:
|
||||
logger.warn("event_match condition with no pattern")
|
||||
return False
|
||||
|
||||
# XXX: optimisation: cache our pattern regexps
|
||||
if condition['key'] == 'content.body':
|
||||
body = self._event["content"].get("body", None)
|
||||
if not body:
|
||||
return False
|
||||
|
||||
return _glob_matches(pattern, body, word_boundary=True)
|
||||
else:
|
||||
haystack = self._get_value(condition['key'])
|
||||
if haystack is None:
|
||||
return False
|
||||
|
||||
return _glob_matches(pattern, haystack)
|
||||
|
||||
def _contains_display_name(self, display_name):
|
||||
if not display_name:
|
||||
return False
|
||||
|
||||
body = self._event["content"].get("body", None)
|
||||
if not body:
|
||||
return False
|
||||
|
||||
return _glob_matches(display_name, body, word_boundary=True)
|
||||
|
||||
def _get_value(self, dotted_key):
|
||||
return self._value_cache.get(dotted_key, None)
|
||||
|
||||
|
||||
def _glob_matches(glob, value, word_boundary=False):
|
||||
"""Tests if value matches glob.
|
||||
|
||||
Args:
|
||||
glob (string)
|
||||
value (string): String to test against glob.
|
||||
word_boundary (bool): Whether to match against word boundaries or entire
|
||||
string. Defaults to False.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
try:
|
||||
if IS_GLOB.search(glob):
|
||||
r = re.escape(glob)
|
||||
|
||||
r = r.replace(r'\*', '.*?')
|
||||
r = r.replace(r'\?', '.')
|
||||
|
||||
# handle [abc], [a-z] and [!a-z] style ranges.
|
||||
r = GLOB_REGEX.sub(
|
||||
lambda x: (
|
||||
'[%s%s]' % (
|
||||
x.group(1) and '^' or '',
|
||||
x.group(2).replace(r'\\\-', '-')
|
||||
)
|
||||
),
|
||||
r,
|
||||
)
|
||||
if word_boundary:
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _compile_regex(r)
|
||||
|
||||
return r.search(value)
|
||||
else:
|
||||
r = r + "$"
|
||||
r = _compile_regex(r)
|
||||
|
||||
return r.match(value)
|
||||
elif word_boundary:
|
||||
r = re.escape(glob)
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _compile_regex(r)
|
||||
|
||||
return r.search(value)
|
||||
else:
|
||||
return value.lower() == glob.lower()
|
||||
except re.error:
|
||||
logger.warn("Failed to parse glob to regex: %r", glob)
|
||||
return False
|
||||
|
||||
|
||||
def _flatten_dict(d, prefix=[], result={}):
|
||||
for key, value in d.items():
|
||||
if isinstance(value, basestring):
|
||||
result[".".join(prefix + [key])] = value.lower()
|
||||
elif hasattr(value, "items"):
|
||||
_flatten_dict(value, prefix=(prefix+[key]), result=result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
regex_cache = LruCache(5000)
|
||||
|
||||
|
||||
def _compile_regex(regex_str):
|
||||
r = regex_cache.get(regex_str, None)
|
||||
if r:
|
||||
return r
|
||||
|
||||
r = re.compile(regex_str, flags=re.IGNORECASE)
|
||||
regex_cache[regex_str] = r
|
||||
return r
|
||||
|
|
|
@ -37,14 +37,14 @@ class PusherPool:
|
|||
self._start_pushers(pushers)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
|
||||
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name, pushkey, lang, data):
|
||||
# we try to create the pusher just to validate the config: it
|
||||
# will then get pulled out of the database,
|
||||
# recreated, added and started: this means we have only one
|
||||
# code path adding pushers.
|
||||
self._create_pusher({
|
||||
"user_name": user_name,
|
||||
"user_name": user_id,
|
||||
"kind": kind,
|
||||
"profile_tag": profile_tag,
|
||||
"app_id": app_id,
|
||||
|
@ -59,7 +59,7 @@ class PusherPool:
|
|||
"failing_since": None
|
||||
})
|
||||
yield self._add_pusher_to_store(
|
||||
user_name, access_token, profile_tag, kind, app_id,
|
||||
user_id, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name,
|
||||
pushkey, lang, data
|
||||
)
|
||||
|
@ -94,11 +94,11 @@ class PusherPool:
|
|||
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
|
||||
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
|
||||
app_id, app_display_name, device_display_name,
|
||||
pushkey, lang, data):
|
||||
yield self.store.add_pusher(
|
||||
user_name=user_name,
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
profile_tag=profile_tag,
|
||||
kind=kind,
|
||||
|
@ -110,14 +110,14 @@ class PusherPool:
|
|||
lang=lang,
|
||||
data=data,
|
||||
)
|
||||
self._refresh_pusher(app_id, pushkey, user_name)
|
||||
self._refresh_pusher(app_id, pushkey, user_id)
|
||||
|
||||
def _create_pusher(self, pusherdict):
|
||||
if pusherdict['kind'] == 'http':
|
||||
return HttpPusher(
|
||||
self.hs,
|
||||
profile_tag=pusherdict['profile_tag'],
|
||||
user_name=pusherdict['user_name'],
|
||||
user_id=pusherdict['user_name'],
|
||||
app_id=pusherdict['app_id'],
|
||||
app_display_name=pusherdict['app_display_name'],
|
||||
device_display_name=pusherdict['device_display_name'],
|
||||
|
@ -135,14 +135,14 @@ class PusherPool:
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _refresh_pusher(self, app_id, pushkey, user_name):
|
||||
def _refresh_pusher(self, app_id, pushkey, user_id):
|
||||
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
|
||||
app_id, pushkey
|
||||
)
|
||||
|
||||
p = None
|
||||
for r in resultlist:
|
||||
if r['user_name'] == user_name:
|
||||
if r['user_name'] == user_id:
|
||||
p = r
|
||||
|
||||
if p:
|
||||
|
@ -171,12 +171,12 @@ class PusherPool:
|
|||
logger.info("Started pushers")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pusher(self, app_id, pushkey, user_name):
|
||||
fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
|
||||
def remove_pusher(self, app_id, pushkey, user_id):
|
||||
fullid = "%s:%s:%s" % (app_id, pushkey, user_id)
|
||||
if fullid in self.pushers:
|
||||
logger.info("Stopping pusher %s", fullid)
|
||||
self.pushers[fullid].stop()
|
||||
del self.pushers[fullid]
|
||||
yield self.store.delete_pusher_by_app_id_pushkey_user_name(
|
||||
app_id, pushkey, user_name
|
||||
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id, pushkey, user_id
|
||||
)
|
||||
|
|
|
@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
|
|||
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
||||
)
|
||||
|
||||
import copy
|
||||
import simplejson as json
|
||||
|
||||
|
||||
|
@ -51,7 +52,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
content = _parse_json(request)
|
||||
|
||||
if 'attr' in spec:
|
||||
self.set_rule_attr(requester.user, spec, content)
|
||||
self.set_rule_attr(requester.user.to_string(), spec, content)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
try:
|
||||
|
@ -73,7 +74,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
try:
|
||||
yield self.hs.get_datastore().add_push_rule(
|
||||
user_name=requester.user.to_string(),
|
||||
user_id=requester.user.to_string(),
|
||||
rule_id=_namespaced_rule_id_from_spec(spec),
|
||||
priority_class=priority_class,
|
||||
conditions=conditions,
|
||||
|
@ -126,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
rule["actions"] = json.loads(rawrule["actions"])
|
||||
ruleslist.append(rule)
|
||||
|
||||
ruleslist = baserules.list_with_base_rules(ruleslist, user)
|
||||
# We're going to be mutating this a lot, so do a deep copy
|
||||
ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
|
||||
|
||||
rules = {'global': {}, 'device': {}}
|
||||
|
||||
|
@ -140,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
template_name = _priority_class_to_template_name(r['priority_class'])
|
||||
|
||||
# Remove internal stuff.
|
||||
for c in r["conditions"]:
|
||||
c.pop("_id", None)
|
||||
|
||||
pattern_type = c.pop("pattern_type", None)
|
||||
if pattern_type == "user_id":
|
||||
c["pattern"] = user.to_string()
|
||||
elif pattern_type == "user_localpart":
|
||||
c["pattern"] = user.localpart
|
||||
|
||||
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
|
||||
# per-device rule
|
||||
profile_tag = _profile_tag_from_conditions(r["conditions"])
|
||||
|
@ -206,7 +218,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
def set_rule_attr(self, user_name, spec, val):
|
||||
def set_rule_attr(self, user_id, spec, val):
|
||||
if spec['attr'] == 'enabled':
|
||||
if isinstance(val, dict) and "enabled" in val:
|
||||
val = val["enabled"]
|
||||
|
@ -217,15 +229,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
self.hs.get_datastore().set_push_rule_enabled(
|
||||
user_name, namespaced_rule_id, val
|
||||
user_id, namespaced_rule_id, val
|
||||
)
|
||||
else:
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
def get_rule_attr(self, user_name, namespaced_rule_id, attr):
|
||||
def get_rule_attr(self, user_id, namespaced_rule_id, attr):
|
||||
if attr == 'enabled':
|
||||
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
|
||||
user_name, namespaced_rule_id
|
||||
user_id, namespaced_rule_id
|
||||
)
|
||||
else:
|
||||
raise UnrecognizedRequestError()
|
||||
|
|
|
@ -41,7 +41,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
|||
and 'kind' in content and
|
||||
content['kind'] is None):
|
||||
yield pusher_pool.remove_pusher(
|
||||
content['app_id'], content['pushkey'], user_name=user.to_string()
|
||||
content['app_id'], content['pushkey'], user_id=user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -71,7 +71,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
|||
|
||||
try:
|
||||
yield pusher_pool.add_pusher(
|
||||
user_name=user.to_string(),
|
||||
user_id=user.to_string(),
|
||||
access_token=requester.access_token_id,
|
||||
profile_tag=content['profile_tag'],
|
||||
kind=content['kind'],
|
||||
|
|
|
@ -414,10 +414,16 @@ class RoomEventContext(ClientV1RestServlet):
|
|||
requester.is_guest,
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise SynapseError(
|
||||
404, "Event not found.", errcode=Codes.NOT_FOUND
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
results["events_before"] = [
|
||||
serialize_event(event, time_now) for event in results["events_before"]
|
||||
]
|
||||
results["event"] = serialize_event(results["event"], time_now)
|
||||
results["events_after"] = [
|
||||
serialize_event(event, time_now) for event in results["events_after"]
|
||||
]
|
||||
|
@ -436,7 +442,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
def register(self, http_server):
|
||||
# /rooms/$roomid/[invite|join|leave]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
||||
"(?P<membership_action>join|invite|leave|ban|kick|forget)")
|
||||
"(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -445,9 +451,6 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
request,
|
||||
allow_guest=True,
|
||||
)
|
||||
user = requester.user
|
||||
|
||||
effective_membership_action = membership_action
|
||||
|
||||
if requester.is_guest and membership_action not in {
|
||||
Membership.JOIN,
|
||||
|
@ -457,13 +460,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
|
||||
content = _parse_json(request)
|
||||
|
||||
# target user is you unless it is an invite
|
||||
state_key = user.to_string()
|
||||
|
||||
if membership_action == "invite" and self._has_3pid_invite_keys(content):
|
||||
yield self.handlers.room_member_handler.do_3pid_invite(
|
||||
room_id,
|
||||
user,
|
||||
requester.user,
|
||||
content["medium"],
|
||||
content["address"],
|
||||
content["id_server"],
|
||||
|
@ -472,42 +472,21 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
defer.returnValue((200, {}))
|
||||
return
|
||||
elif membership_action in ["invite", "ban", "kick"]:
|
||||
if "user_id" in content:
|
||||
state_key = content["user_id"]
|
||||
else:
|
||||
|
||||
target = requester.user
|
||||
if membership_action in ["invite", "ban", "unban", "kick"]:
|
||||
if "user_id" not in content:
|
||||
raise SynapseError(400, "Missing user_id key.")
|
||||
target = UserID.from_string(content["user_id"])
|
||||
|
||||
# make sure it looks like a user ID; it'll throw if it's invalid.
|
||||
UserID.from_string(state_key)
|
||||
|
||||
if membership_action == "kick":
|
||||
effective_membership_action = "leave"
|
||||
elif membership_action == "forget":
|
||||
effective_membership_action = "leave"
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
|
||||
content = {"membership": unicode(effective_membership_action)}
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
yield msg_handler.create_and_send_event(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
"state_key": state_key,
|
||||
},
|
||||
token_id=requester.access_token_id,
|
||||
yield self.handlers.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=target,
|
||||
room_id=room_id,
|
||||
action=membership_action,
|
||||
txn_id=txn_id,
|
||||
is_guest=requester.is_guest,
|
||||
)
|
||||
|
||||
if membership_action == "forget":
|
||||
yield self.handlers.room_member_handler.forget(user, room_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
def _has_3pid_invite_keys(self, content):
|
||||
|
|
|
@ -313,6 +313,7 @@ class SyncRestServlet(RestServlet):
|
|||
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
|
||||
result["ephemeral"] = {"events": ephemeral_events}
|
||||
result["unread_notification_count"] = room.unread_notification_count
|
||||
result["unread_highlight_count"] = room.unread_highlight_count
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
@ -139,6 +139,9 @@ class BaseHomeServer(object):
|
|||
def is_mine(self, domain_specific_string):
|
||||
return domain_specific_string.domain == self.hostname
|
||||
|
||||
def is_mine_id(self, string):
|
||||
return string.split(":", 1)[1] == self.hostname
|
||||
|
||||
# Build magic accessors for every dependency
|
||||
for depname in BaseHomeServer.DEPENDENCIES:
|
||||
BaseHomeServer._make_dependency_method(depname)
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
import logging
|
||||
import urllib
|
||||
import yaml
|
||||
from simplejson import JSONDecodeError
|
||||
import simplejson as json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import UserID
|
||||
from ._base import SQLBaseStore
|
||||
|
@ -144,66 +144,9 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
|
||||
return rooms_for_user_matching_user_id
|
||||
|
||||
def _parse_services_dict(self, results):
|
||||
# SQL results in the form:
|
||||
# [
|
||||
# {
|
||||
# 'regex': "something",
|
||||
# 'url': "something",
|
||||
# 'namespace': enum,
|
||||
# 'as_id': 0,
|
||||
# 'token': "something",
|
||||
# 'hs_token': "otherthing",
|
||||
# 'id': 0
|
||||
# }
|
||||
# ]
|
||||
services = {}
|
||||
for res in results:
|
||||
as_token = res["token"]
|
||||
if as_token is None:
|
||||
continue
|
||||
if as_token not in services:
|
||||
# add the service
|
||||
services[as_token] = {
|
||||
"id": res["id"],
|
||||
"url": res["url"],
|
||||
"token": as_token,
|
||||
"hs_token": res["hs_token"],
|
||||
"sender": res["sender"],
|
||||
"namespaces": {
|
||||
ApplicationService.NS_USERS: [],
|
||||
ApplicationService.NS_ALIASES: [],
|
||||
ApplicationService.NS_ROOMS: []
|
||||
}
|
||||
}
|
||||
# add the namespace regex if one exists
|
||||
ns_int = res["namespace"]
|
||||
if ns_int is None:
|
||||
continue
|
||||
try:
|
||||
services[as_token]["namespaces"][
|
||||
ApplicationService.NS_LIST[ns_int]].append(
|
||||
json.loads(res["regex"])
|
||||
)
|
||||
except IndexError:
|
||||
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
|
||||
except JSONDecodeError:
|
||||
logger.error("Bad regex object '%s'", res["regex"])
|
||||
|
||||
service_list = []
|
||||
for service in services.values():
|
||||
service_list.append(ApplicationService(
|
||||
token=service["token"],
|
||||
url=service["url"],
|
||||
namespaces=service["namespaces"],
|
||||
hs_token=service["hs_token"],
|
||||
sender=service["sender"],
|
||||
id=service["id"]
|
||||
))
|
||||
return service_list
|
||||
|
||||
def _load_appservice(self, as_info):
|
||||
required_string_fields = [
|
||||
# TODO: Add id here when it's stable to release
|
||||
"url", "as_token", "hs_token", "sender_localpart"
|
||||
]
|
||||
for field in required_string_fields:
|
||||
|
@ -245,7 +188,7 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
namespaces=as_info["namespaces"],
|
||||
hs_token=as_info["hs_token"],
|
||||
sender=user_id,
|
||||
id=as_info["as_token"] # the token is the only unique thing here
|
||||
id=as_info["id"] if "id" in as_info else as_info["as_token"],
|
||||
)
|
||||
|
||||
def _populate_appservice_cache(self, config_files):
|
||||
|
@ -256,15 +199,38 @@ class ApplicationServiceStore(SQLBaseStore):
|
|||
)
|
||||
return
|
||||
|
||||
# Dicts of value -> filename
|
||||
seen_as_tokens = {}
|
||||
seen_ids = {}
|
||||
|
||||
for config_file in config_files:
|
||||
try:
|
||||
with open(config_file, 'r') as f:
|
||||
appservice = self._load_appservice(yaml.load(f))
|
||||
if appservice.id in seen_ids:
|
||||
raise ConfigError(
|
||||
"Cannot reuse ID across application services: "
|
||||
"%s (files: %s, %s)" % (
|
||||
appservice.id, config_file, seen_ids[appservice.id],
|
||||
)
|
||||
)
|
||||
seen_ids[appservice.id] = config_file
|
||||
if appservice.token in seen_as_tokens:
|
||||
raise ConfigError(
|
||||
"Cannot reuse as_token across application services: "
|
||||
"%s (files: %s, %s)" % (
|
||||
appservice.token,
|
||||
config_file,
|
||||
seen_as_tokens[appservice.token],
|
||||
)
|
||||
)
|
||||
seen_as_tokens[appservice.token] = config_file
|
||||
logger.info("Loaded application service: %s", appservice)
|
||||
self.services_cache.append(appservice)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load appservice from '%s'", config_file)
|
||||
logger.exception(e)
|
||||
raise
|
||||
|
||||
|
||||
class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||
|
|
|
@ -17,7 +17,7 @@ from ._base import SQLBaseStore
|
|||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
import simplejson as json
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -84,7 +84,8 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
)
|
||||
)
|
||||
return [
|
||||
{"event_id": row[0], "actions": row[1]} for row in txn.fetchall()
|
||||
{"event_id": row[0], "actions": json.loads(row[1])}
|
||||
for row in txn.fetchall()
|
||||
]
|
||||
|
||||
ret = yield self.runInteraction(
|
||||
|
|
|
@ -25,11 +25,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class PushRuleStore(SQLBaseStore):
|
||||
@cachedInlineCallbacks()
|
||||
def get_push_rules_for_user(self, user_name):
|
||||
def get_push_rules_for_user(self, user_id):
|
||||
rows = yield self._simple_select_list(
|
||||
table="push_rules",
|
||||
keyvalues={
|
||||
"user_name": user_name,
|
||||
"user_name": user_id,
|
||||
},
|
||||
retcols=(
|
||||
"user_name", "rule_id", "priority_class", "priority",
|
||||
|
@ -45,11 +45,11 @@ class PushRuleStore(SQLBaseStore):
|
|||
defer.returnValue(rows)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_push_rules_enabled_for_user(self, user_name):
|
||||
def get_push_rules_enabled_for_user(self, user_id):
|
||||
results = yield self._simple_select_list(
|
||||
table="push_rules_enable",
|
||||
keyvalues={
|
||||
'user_name': user_name
|
||||
'user_name': user_id
|
||||
},
|
||||
retcols=(
|
||||
"user_name", "rule_id", "enabled",
|
||||
|
@ -122,7 +122,7 @@ class PushRuleStore(SQLBaseStore):
|
|||
)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
|
||||
def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
|
||||
after = kwargs.pop("after", None)
|
||||
relative_to_rule = kwargs.pop("before", after)
|
||||
|
||||
|
@ -130,7 +130,7 @@ class PushRuleStore(SQLBaseStore):
|
|||
txn,
|
||||
table="push_rules",
|
||||
keyvalues={
|
||||
"user_name": user_name,
|
||||
"user_name": user_id,
|
||||
"rule_id": relative_to_rule,
|
||||
},
|
||||
retcols=["priority_class", "priority"],
|
||||
|
@ -154,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
|
|||
new_rule.pop("before", None)
|
||||
new_rule.pop("after", None)
|
||||
new_rule['priority_class'] = priority_class
|
||||
new_rule['user_name'] = user_name
|
||||
new_rule['user_name'] = user_id
|
||||
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
|
||||
|
||||
# check if the priority before/after is free
|
||||
|
@ -170,7 +170,7 @@ class PushRuleStore(SQLBaseStore):
|
|||
"SELECT COUNT(*) FROM push_rules"
|
||||
" WHERE user_name = ? AND priority_class = ? AND priority = ?"
|
||||
)
|
||||
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
||||
txn.execute(sql, (user_id, priority_class, new_rule_priority))
|
||||
res = txn.fetchall()
|
||||
num_conflicting = res[0][0]
|
||||
|
||||
|
@ -187,14 +187,14 @@ class PushRuleStore(SQLBaseStore):
|
|||
else:
|
||||
sql += ">= ?"
|
||||
|
||||
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
||||
txn.execute(sql, (user_id, priority_class, new_rule_priority))
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||
self.get_push_rules_for_user.invalidate, (user_id,)
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
|
@ -203,14 +203,14 @@ class PushRuleStore(SQLBaseStore):
|
|||
values=new_rule,
|
||||
)
|
||||
|
||||
def _add_push_rule_highest_priority_txn(self, txn, user_name,
|
||||
def _add_push_rule_highest_priority_txn(self, txn, user_id,
|
||||
priority_class, **kwargs):
|
||||
# find the highest priority rule in that class
|
||||
sql = (
|
||||
"SELECT COUNT(*), MAX(priority) FROM push_rules"
|
||||
" WHERE user_name = ? and priority_class = ?"
|
||||
)
|
||||
txn.execute(sql, (user_name, priority_class))
|
||||
txn.execute(sql, (user_id, priority_class))
|
||||
res = txn.fetchall()
|
||||
(how_many, highest_prio) = res[0]
|
||||
|
||||
|
@ -221,15 +221,15 @@ class PushRuleStore(SQLBaseStore):
|
|||
# and insert the new rule
|
||||
new_rule = kwargs
|
||||
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
|
||||
new_rule['user_name'] = user_name
|
||||
new_rule['user_name'] = user_id
|
||||
new_rule['priority_class'] = priority_class
|
||||
new_rule['priority'] = new_prio
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||
self.get_push_rules_for_user.invalidate, (user_id,)
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
|
@ -239,48 +239,48 @@ class PushRuleStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_push_rule(self, user_name, rule_id):
|
||||
def delete_push_rule(self, user_id, rule_id):
|
||||
"""
|
||||
Delete a push rule. Args specify the row to be deleted and can be
|
||||
any of the columns in the push_rule table, but below are the
|
||||
standard ones
|
||||
|
||||
Args:
|
||||
user_name (str): The matrix ID of the push rule owner
|
||||
user_id (str): The matrix ID of the push rule owner
|
||||
rule_id (str): The rule_id of the rule to be deleted
|
||||
"""
|
||||
yield self._simple_delete_one(
|
||||
"push_rules",
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
{'user_name': user_id, 'rule_id': rule_id},
|
||||
desc="delete_push_rule",
|
||||
)
|
||||
|
||||
self.get_push_rules_for_user.invalidate((user_name,))
|
||||
self.get_push_rules_enabled_for_user.invalidate((user_name,))
|
||||
self.get_push_rules_for_user.invalidate((user_id,))
|
||||
self.get_push_rules_enabled_for_user.invalidate((user_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_push_rule_enabled(self, user_name, rule_id, enabled):
|
||||
def set_push_rule_enabled(self, user_id, rule_id, enabled):
|
||||
ret = yield self.runInteraction(
|
||||
"_set_push_rule_enabled_txn",
|
||||
self._set_push_rule_enabled_txn,
|
||||
user_name, rule_id, enabled
|
||||
user_id, rule_id, enabled
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
|
||||
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
|
||||
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
"push_rules_enable",
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
{'user_name': user_id, 'rule_id': rule_id},
|
||||
{'enabled': 1 if enabled else 0},
|
||||
{'id': new_id},
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||
self.get_push_rules_for_user.invalidate, (user_id,)
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ class PusherStore(SQLBaseStore):
|
|||
defer.returnValue(rows)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
|
||||
def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name,
|
||||
pushkey, pushkey_ts, lang, data):
|
||||
try:
|
||||
|
@ -90,7 +90,7 @@ class PusherStore(SQLBaseStore):
|
|||
dict(
|
||||
app_id=app_id,
|
||||
pushkey=pushkey,
|
||||
user_name=user_name,
|
||||
user_name=user_id,
|
||||
),
|
||||
dict(
|
||||
access_token=access_token,
|
||||
|
@ -112,38 +112,38 @@ class PusherStore(SQLBaseStore):
|
|||
raise StoreError(500, "Problem creating pusher.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name):
|
||||
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
||||
yield self._simple_delete_one(
|
||||
"pushers",
|
||||
{"app_id": app_id, "pushkey": pushkey, 'user_name': user_name},
|
||||
desc="delete_pusher_by_app_id_pushkey_user_name",
|
||||
{"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
|
||||
desc="delete_pusher_by_app_id_pushkey_user_id",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_last_token(self, app_id, pushkey, user_name, last_token):
|
||||
def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
|
||||
yield self._simple_update_one(
|
||||
"pushers",
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
|
||||
{'last_token': last_token},
|
||||
desc="update_pusher_last_token",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_last_token_and_success(self, app_id, pushkey, user_name,
|
||||
def update_pusher_last_token_and_success(self, app_id, pushkey, user_id,
|
||||
last_token, last_success):
|
||||
yield self._simple_update_one(
|
||||
"pushers",
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
|
||||
{'last_token': last_token, 'last_success': last_success},
|
||||
desc="update_pusher_last_token_and_success",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_pusher_failing_since(self, app_id, pushkey, user_name,
|
||||
def update_pusher_failing_since(self, app_id, pushkey, user_id,
|
||||
failing_since):
|
||||
yield self._simple_update_one(
|
||||
"pushers",
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
|
||||
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
|
||||
{'failing_since': failing_since},
|
||||
desc="update_pusher_failing_since",
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
||||
from synapse.util.caches import cache_counter, caches_by_name
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -33,6 +33,18 @@ class ReceiptsStore(SQLBaseStore):
|
|||
|
||||
self._receipts_stream_cache = _RoomStreamChangeCache()
|
||||
|
||||
@cached(num_args=2)
|
||||
def get_receipts_for_room(self, room_id, receipt_type):
|
||||
return self._simple_select_list(
|
||||
table="receipts_linearized",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
},
|
||||
retcols=("user_id", "event_id"),
|
||||
desc="get_receipts_for_room",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
"""Get receipts for multiple rooms for sending to clients.
|
||||
|
|
|
@ -18,7 +18,7 @@ from twisted.internet import defer
|
|||
from synapse.api.errors import StoreError, Codes
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||
|
||||
|
||||
class RegistrationStore(SQLBaseStore):
|
||||
|
@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
|
|||
defer.returnValue(res if res else False)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def is_guest(self, user):
|
||||
def is_guest(self, user_id):
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user.to_string()},
|
||||
keyvalues={"name": user_id},
|
||||
retcol="is_guest",
|
||||
allow_none=True,
|
||||
desc="is_guest",
|
||||
|
@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue(res if res else False)
|
||||
|
||||
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
|
||||
inlineCallbacks=True)
|
||||
def are_guests(self, user_ids):
|
||||
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
|
||||
",".join("?" for _ in user_ids),
|
||||
)
|
||||
|
||||
rows = yield self._execute(
|
||||
"are_guests", self.cursor_to_dict, sql, *user_ids
|
||||
)
|
||||
|
||||
result = {user_id: False for user_id in user_ids}
|
||||
|
||||
result.update({
|
||||
row["name"]: bool(row["is_guest"])
|
||||
for row in rows
|
||||
})
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
|
||||
|
|
|
@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||
txn.execute(sql, (user_id, room_id))
|
||||
yield self.runInteraction("forget_membership", f)
|
||||
self.was_forgotten_at.invalidate_all()
|
||||
self.who_forgot_in_room.invalidate_all()
|
||||
self.did_forget.invalidate((user_id, room_id))
|
||||
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
|
@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
|
|||
return rows[0][0]
|
||||
forgot = yield self.runInteraction("did_forget_membership_at", f)
|
||||
defer.returnValue(forgot == 1)
|
||||
|
||||
@cached()
|
||||
def who_forgot_in_room(self, room_id):
|
||||
return self._simple_select_list(
|
||||
table="room_memberships",
|
||||
retcols=("user_id", "event_id"),
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"forgotten": 1,
|
||||
},
|
||||
desc="who_forgot"
|
||||
)
|
||||
|
|
|
@ -22,6 +22,9 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_LIMIT = 1000
|
||||
|
||||
|
||||
class SourcePaginationConfig(object):
|
||||
|
||||
"""A configuration object which stores pagination parameters for a
|
||||
|
@ -32,7 +35,7 @@ class SourcePaginationConfig(object):
|
|||
self.from_key = from_key
|
||||
self.to_key = to_key
|
||||
self.direction = 'f' if direction == 'f' else 'b'
|
||||
self.limit = int(limit) if limit is not None else None
|
||||
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
@ -49,7 +52,7 @@ class PaginationConfig(object):
|
|||
self.from_token = from_token
|
||||
self.to_token = to_token
|
||||
self.direction = 'f' if direction == 'f' else 'b'
|
||||
self.limit = int(limit) if limit is not None else None
|
||||
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request, raise_invalid_params=True,
|
||||
|
|
|
@ -29,6 +29,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
self.service = ApplicationService(
|
||||
id="unique_identifier",
|
||||
url="some_url",
|
||||
token="some_token",
|
||||
namespaces={
|
||||
|
|
|
@ -1,141 +0,0 @@
|
|||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
from tests import unittest
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.handlers.federation import FederationHandler
|
||||
|
||||
from mock import NonCallableMock, ANY, Mock
|
||||
|
||||
from ..utils import setup_test_homeserver
|
||||
|
||||
|
||||
class FederationTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
|
||||
self.state_handler = NonCallableMock(spec_set=[
|
||||
"compute_event_context",
|
||||
])
|
||||
|
||||
self.auth = NonCallableMock(spec_set=[
|
||||
"check",
|
||||
"check_host_in_room",
|
||||
])
|
||||
|
||||
self.hostname = "test"
|
||||
hs = yield setup_test_homeserver(
|
||||
self.hostname,
|
||||
datastore=NonCallableMock(spec_set=[
|
||||
"persist_event",
|
||||
"store_room",
|
||||
"get_room",
|
||||
"get_destination_retry_timings",
|
||||
"set_destination_retry_timings",
|
||||
"have_events",
|
||||
"get_users_in_room",
|
||||
"bulk_get_push_rules",
|
||||
"get_current_state",
|
||||
"set_push_actions_for_event_and_users",
|
||||
"is_guest",
|
||||
"get_state_for_events",
|
||||
]),
|
||||
resource_for_federation=NonCallableMock(),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
|
||||
handlers=NonCallableMock(spec_set=[
|
||||
"room_member_handler",
|
||||
"federation_handler",
|
||||
]),
|
||||
auth=self.auth,
|
||||
state_handler=self.state_handler,
|
||||
keyring=Mock(),
|
||||
)
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.hs = hs
|
||||
|
||||
self.handlers.federation_handler = FederationHandler(self.hs)
|
||||
|
||||
self.datastore.get_state_for_events.return_value = {"$a:b": {}}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_msg(self):
|
||||
pdu = FrozenEvent({
|
||||
"type": EventTypes.Message,
|
||||
"room_id": "foo",
|
||||
"content": {"msgtype": u"fooo"},
|
||||
"origin_server_ts": 0,
|
||||
"event_id": "$a:b",
|
||||
"user_id":"@a:b",
|
||||
"origin": "b",
|
||||
"auth_events": [],
|
||||
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
||||
})
|
||||
|
||||
self.datastore.persist_event.return_value = defer.succeed((1,1))
|
||||
self.datastore.get_room.return_value = defer.succeed(True)
|
||||
self.datastore.get_users_in_room.return_value = ["@a:b"]
|
||||
self.datastore.bulk_get_push_rules.return_value = {}
|
||||
self.datastore.get_current_state.return_value = {}
|
||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||
|
||||
retry_timings_res = {
|
||||
"destination": "",
|
||||
"retry_last_ts": 0,
|
||||
"retry_interval": 0,
|
||||
}
|
||||
self.datastore.get_destination_retry_timings.return_value = (
|
||||
defer.succeed(retry_timings_res)
|
||||
)
|
||||
|
||||
def have_events(event_ids):
|
||||
return defer.succeed({})
|
||||
self.datastore.have_events.side_effect = have_events
|
||||
|
||||
def annotate(ev, old_state=None, outlier=False):
|
||||
context = Mock()
|
||||
context.current_state = {}
|
||||
context.auth_events = {}
|
||||
return defer.succeed(context)
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
yield self.handlers.federation_handler.on_receive_pdu(
|
||||
"fo", pdu, False
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
ANY,
|
||||
is_new_state=True,
|
||||
backfilled=False,
|
||||
current_state=None,
|
||||
context=ANY,
|
||||
)
|
||||
|
||||
self.state_handler.compute_event_context.assert_called_once_with(
|
||||
ANY, old_state=None, outlier=False
|
||||
)
|
||||
|
||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
ANY, 1, 1, extra_users=[]
|
||||
)
|
|
@ -1,418 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
from .. import unittest
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.types import UserID
|
||||
from ..utils import setup_test_homeserver
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
|
||||
class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hostname = "red"
|
||||
hs = yield setup_test_homeserver(
|
||||
self.hostname,
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
datastore=NonCallableMock(spec_set=[
|
||||
"persist_event",
|
||||
"get_room_member",
|
||||
"get_room",
|
||||
"store_room",
|
||||
"get_latest_events_in_room",
|
||||
"add_event_hashes",
|
||||
"get_users_in_room",
|
||||
"bulk_get_push_rules",
|
||||
"get_current_state",
|
||||
"set_push_actions_for_event_and_users",
|
||||
"get_state_for_events",
|
||||
"is_guest",
|
||||
]),
|
||||
resource_for_federation=NonCallableMock(),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
|
||||
handlers=NonCallableMock(spec_set=[
|
||||
"room_member_handler",
|
||||
"profile_handler",
|
||||
"federation_handler",
|
||||
]),
|
||||
auth=NonCallableMock(spec_set=[
|
||||
"check",
|
||||
"add_auth_events",
|
||||
"check_host_in_room",
|
||||
]),
|
||||
state_handler=NonCallableMock(spec_set=[
|
||||
"compute_event_context",
|
||||
"get_current_state",
|
||||
]),
|
||||
)
|
||||
|
||||
self.federation = NonCallableMock(spec_set=[
|
||||
"handle_new_event",
|
||||
"send_invite",
|
||||
"get_state_for_room",
|
||||
])
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.auth = hs.get_auth()
|
||||
self.hs = hs
|
||||
|
||||
self.handlers.federation_handler = self.federation
|
||||
|
||||
self.distributor.declare("collect_presencelike_data")
|
||||
|
||||
self.handlers.room_member_handler = RoomMemberHandler(self.hs)
|
||||
self.handlers.profile_handler = ProfileHandler(self.hs)
|
||||
self.room_member_handler = self.handlers.room_member_handler
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
||||
self.datastore.persist_event.return_value = (1,1)
|
||||
self.datastore.add_event_hashes.return_value = []
|
||||
self.datastore.get_users_in_room.return_value = ["@bob:red"]
|
||||
self.datastore.bulk_get_push_rules.return_value = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invite(self):
|
||||
room_id = "!foo:red"
|
||||
user_id = "@bob:red"
|
||||
target_user_id = "@red:blue"
|
||||
content = {"membership": Membership.INVITE}
|
||||
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": target_user_id,
|
||||
"room_id": room_id,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
self.datastore.get_latest_events_in_room.return_value = (
|
||||
defer.succeed([])
|
||||
)
|
||||
self.datastore.get_current_state.return_value = {}
|
||||
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
|
||||
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
(EventTypes.Member, "@bob:red")
|
||||
]
|
||||
|
||||
return defer.succeed(True)
|
||||
self.auth.add_auth_events.side_effect = add_auth
|
||||
|
||||
def send_invite(domain, event):
|
||||
return defer.succeed(event)
|
||||
|
||||
self.federation.send_invite.side_effect = send_invite
|
||||
|
||||
room_handler = self.room_member_handler
|
||||
event, context = yield room_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield room_handler.change_membership(event, context)
|
||||
|
||||
self.state_handler.compute_event_context.assert_called_once_with(
|
||||
builder
|
||||
)
|
||||
|
||||
self.auth.add_auth_events.assert_called_once_with(
|
||||
builder, context
|
||||
)
|
||||
|
||||
self.federation.send_invite.assert_called_once_with(
|
||||
"blue", event,
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
event, context=context,
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
|
||||
)
|
||||
self.assertFalse(self.datastore.get_room.called)
|
||||
self.assertFalse(self.datastore.store_room.called)
|
||||
self.assertFalse(self.federation.get_state_for_room.called)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_simple_join(self):
|
||||
room_id = "!foo:red"
|
||||
user_id = "@bob:red"
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
join_signal_observer = Mock()
|
||||
self.distributor.observe("user_joined_room", join_signal_observer)
|
||||
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": room_id,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
})
|
||||
|
||||
self.datastore.get_latest_events_in_room.return_value = (
|
||||
defer.succeed([])
|
||||
)
|
||||
self.datastore.get_current_state.return_value = {}
|
||||
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
|
||||
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
membership=Membership.INVITE
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
(EventTypes.Member, "@bob:red")
|
||||
]
|
||||
|
||||
return defer.succeed(True)
|
||||
self.auth.add_auth_events.side_effect = add_auth
|
||||
|
||||
room_handler = self.room_member_handler
|
||||
event, context = yield room_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
# Actual invocation
|
||||
yield room_handler.change_membership(event, context)
|
||||
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, destinations=set()
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
event, context=context
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, 1, 1, extra_users=[user]
|
||||
)
|
||||
|
||||
join_signal_observer.assert_called_with(
|
||||
user=user, room_id=room_id
|
||||
)
|
||||
|
||||
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": room_id,
|
||||
"content": {"membership": membership},
|
||||
})
|
||||
|
||||
return builder.build()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_simple_leave(self):
|
||||
room_id = "!foo:red"
|
||||
user_id = "@bob:red"
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": room_id,
|
||||
"content": {"membership": Membership.LEAVE},
|
||||
})
|
||||
|
||||
self.datastore.get_latest_events_in_room.return_value = (
|
||||
defer.succeed([])
|
||||
)
|
||||
self.datastore.get_current_state.return_value = {}
|
||||
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
|
||||
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
membership=Membership.JOIN
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
(EventTypes.Member, "@bob:red")
|
||||
]
|
||||
|
||||
return defer.succeed(True)
|
||||
self.auth.add_auth_events.side_effect = add_auth
|
||||
|
||||
room_handler = self.room_member_handler
|
||||
event, context = yield room_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
leave_signal_observer = Mock()
|
||||
self.distributor.observe("user_left_room", leave_signal_observer)
|
||||
|
||||
# Actual invocation
|
||||
yield room_handler.change_membership(event, context)
|
||||
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, destinations=set(['red'])
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
event, context=context
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, 1, 1, extra_users=[user]
|
||||
)
|
||||
|
||||
leave_signal_observer.assert_called_with(
|
||||
user=user, room_id=room_id
|
||||
)
|
||||
|
||||
|
||||
class RoomCreationTest(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hostname = "red"
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
self.hostname,
|
||||
datastore=NonCallableMock(spec_set=[
|
||||
"store_room",
|
||||
"snapshot_room",
|
||||
"persist_event",
|
||||
"get_joined_hosts_for_room",
|
||||
]),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
|
||||
handlers=NonCallableMock(spec_set=[
|
||||
"room_creation_handler",
|
||||
"message_handler",
|
||||
]),
|
||||
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
)
|
||||
|
||||
self.federation = NonCallableMock(spec_set=[
|
||||
"handle_new_event",
|
||||
])
|
||||
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
self.handlers.room_creation_handler = RoomCreationHandler(hs)
|
||||
self.room_creation_handler = self.handlers.room_creation_handler
|
||||
|
||||
self.message_handler = self.handlers.message_handler
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_room_creation(self):
|
||||
user_id = "@foo:red"
|
||||
room_id = "!bobs_room:red"
|
||||
config = {"visibility": "private"}
|
||||
|
||||
yield self.room_creation_handler.create_room(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
self.assertTrue(self.message_handler.create_and_send_event.called)
|
||||
|
||||
event_dicts = [
|
||||
e[0][0]
|
||||
for e in self.message_handler.create_and_send_event.call_args_list
|
||||
]
|
||||
|
||||
self.assertTrue(len(event_dicts) > 3)
|
||||
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"type": EventTypes.Create,
|
||||
"sender": user_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
event_dicts[0]
|
||||
)
|
||||
|
||||
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
|
||||
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"room_id": room_id,
|
||||
"state_key": user_id,
|
||||
},
|
||||
event_dicts[1]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
Membership.JOIN,
|
||||
event_dicts[1]["content"]["membership"]
|
||||
)
|
|
@ -12,12 +12,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
from synapse.config._base import ConfigError
|
||||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from tests.utils import setup_test_homeserver
|
||||
from synapse.appservice import ApplicationService, ApplicationServiceState
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.appservice import (
|
||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||
)
|
||||
|
@ -26,7 +27,6 @@ import json
|
|||
import os
|
||||
import yaml
|
||||
from mock import Mock
|
||||
from tests.utils import SQLiteMemoryDbPool, MockClock
|
||||
|
||||
|
||||
class ApplicationServiceStoreTestCase(unittest.TestCase):
|
||||
|
@ -41,9 +41,16 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||
|
||||
self.as_token = "token1"
|
||||
self.as_url = "some_url"
|
||||
self._add_appservice(self.as_token, self.as_url, "some_hs_token", "bob")
|
||||
self._add_appservice("token2", "some_url", "some_hs_token", "bob")
|
||||
self._add_appservice("token3", "some_url", "some_hs_token", "bob")
|
||||
self.as_id = "as1"
|
||||
self._add_appservice(
|
||||
self.as_token,
|
||||
self.as_id,
|
||||
self.as_url,
|
||||
"some_hs_token",
|
||||
"bob"
|
||||
)
|
||||
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
|
||||
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
|
||||
# must be done after inserts
|
||||
self.store = ApplicationServiceStore(hs)
|
||||
|
||||
|
@ -55,9 +62,9 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||
except:
|
||||
pass
|
||||
|
||||
def _add_appservice(self, as_token, url, hs_token, sender):
|
||||
def _add_appservice(self, as_token, id, url, hs_token, sender):
|
||||
as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
|
||||
sender_localpart=sender, namespaces={})
|
||||
id=id, sender_localpart=sender, namespaces={})
|
||||
# use the token as the filename
|
||||
with open(as_token, 'w') as outfile:
|
||||
outfile.write(yaml.dump(as_yaml))
|
||||
|
@ -74,6 +81,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||
self.as_token
|
||||
)
|
||||
self.assertEquals(stored_service.token, self.as_token)
|
||||
self.assertEquals(stored_service.id, self.as_id)
|
||||
self.assertEquals(stored_service.url, self.as_url)
|
||||
self.assertEquals(
|
||||
stored_service.namespaces[ApplicationService.NS_ALIASES],
|
||||
|
@ -110,34 +118,34 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||
{
|
||||
"token": "token1",
|
||||
"url": "https://matrix-as.org",
|
||||
"id": "token1"
|
||||
"id": "id_1"
|
||||
},
|
||||
{
|
||||
"token": "alpha_tok",
|
||||
"url": "https://alpha.com",
|
||||
"id": "alpha_tok"
|
||||
"id": "id_alpha"
|
||||
},
|
||||
{
|
||||
"token": "beta_tok",
|
||||
"url": "https://beta.com",
|
||||
"id": "beta_tok"
|
||||
"id": "id_beta"
|
||||
},
|
||||
{
|
||||
"token": "delta_tok",
|
||||
"url": "https://delta.com",
|
||||
"id": "delta_tok"
|
||||
"token": "gamma_tok",
|
||||
"url": "https://gamma.com",
|
||||
"id": "id_gamma"
|
||||
},
|
||||
]
|
||||
for s in self.as_list:
|
||||
yield self._add_service(s["url"], s["token"])
|
||||
yield self._add_service(s["url"], s["token"], s["id"])
|
||||
|
||||
self.as_yaml_files = []
|
||||
|
||||
self.store = TestTransactionStore(hs)
|
||||
|
||||
def _add_service(self, url, as_token):
|
||||
def _add_service(self, url, as_token, id):
|
||||
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
|
||||
sender_localpart="a_sender", namespaces={})
|
||||
id=id, sender_localpart="a_sender", namespaces={})
|
||||
# use the token as the filename
|
||||
with open(as_token, 'w') as outfile:
|
||||
outfile.write(yaml.dump(as_yaml))
|
||||
|
@ -405,3 +413,64 @@ class TestTransactionStore(ApplicationServiceTransactionStore,
|
|||
|
||||
def __init__(self, hs):
|
||||
super(TestTransactionStore, self).__init__(hs)
|
||||
|
||||
|
||||
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
||||
|
||||
def _write_config(self, suffix, **kwargs):
|
||||
vals = {
|
||||
"id": "id" + suffix,
|
||||
"url": "url" + suffix,
|
||||
"as_token": "as_token" + suffix,
|
||||
"hs_token": "hs_token" + suffix,
|
||||
"sender_localpart": "sender_localpart" + suffix,
|
||||
"namespaces": {},
|
||||
}
|
||||
vals.update(kwargs)
|
||||
|
||||
_, path = tempfile.mkstemp(prefix="as_config")
|
||||
with open(path, "w") as f:
|
||||
f.write(yaml.dump(vals))
|
||||
return path
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_unique_works(self):
|
||||
f1 = self._write_config(suffix="1")
|
||||
f2 = self._write_config(suffix="2")
|
||||
|
||||
config = Mock(app_service_config_files=[f1, f2])
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
|
||||
ApplicationServiceStore(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_duplicate_ids(self):
|
||||
f1 = self._write_config(id="id", suffix="1")
|
||||
f2 = self._write_config(id="id", suffix="2")
|
||||
|
||||
config = Mock(app_service_config_files=[f1, f2])
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
|
||||
e = cm.exception
|
||||
self.assertIn(f1, e.message)
|
||||
self.assertIn(f2, e.message)
|
||||
self.assertIn("id", e.message)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_duplicate_as_tokens(self):
|
||||
f1 = self._write_config(as_token="as_token", suffix="1")
|
||||
f2 = self._write_config(as_token="as_token", suffix="2")
|
||||
|
||||
config = Mock(app_service_config_files=[f1, f2])
|
||||
hs = yield setup_test_homeserver(config=config)
|
||||
|
||||
with self.assertRaises(ConfigError) as cm:
|
||||
ApplicationServiceStore(hs)
|
||||
|
||||
e = cm.exception
|
||||
self.assertIn(f1, e.message)
|
||||
self.assertIn(f2, e.message)
|
||||
self.assertIn("as_token", e.message)
|
||||
|
|
Loading…
Reference in a new issue