0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-19 08:24:25 +01:00

Merge pull request #502 from matrix-org/erikj/push_notif_perf

Unread notification performance.
This commit is contained in:
Erik Johnston 2016-01-19 11:27:27 +00:00
commit 47e7963e50
13 changed files with 549 additions and 953 deletions

View file

@ -117,6 +117,15 @@ class EventBase(object):
def __set__(self, instance, value): def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,)) raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
return self._event_dict[field]
def __contains__(self, field):
return field in self._event_dict
def items(self):
return self._event_dict.items()
class FrozenEvent(EventBase): class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):

View file

@ -53,16 +53,54 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False): def _filter_events_for_clients(self, users, events):
# Assumes that user has at some point joined the room if not is_guest. """ Returns dict of user_id -> list of events that user is allowed to
see.
"""
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
row["event_id"] for rows in forgotten for row in rows
)
def allowed(event, user_id, is_guest):
state = event_id_to_state[event.event_id]
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
def allowed(event, membership, visibility):
if visibility == "world_readable": if visibility == "world_readable":
return True return True
if is_guest: if is_guest:
return False return False
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
@ -78,43 +116,20 @@ class BaseHandler(object):
return True return True
event_id_to_state = yield self.store.get_state_for_events( defer.returnValue({
frozenset(e.event_id for e in events), user_id: [
types=( event
(EventTypes.RoomHistoryVisibility, ""), for event in events
(EventTypes.Member, user_id), if allowed(event, user_id, is_guest)
) ]
) for user_id, is_guest in users
})
events_to_return = [] @defer.inlineCallbacks
for event in events: def _filter_events_for_client(self, user_id, events, is_guest=False):
state = event_id_to_state[event.event_id] # Assumes that user has at some point joined the room if not is_guest.
res = yield self._filter_events_for_clients([(user_id, is_guest)], events)
membership_event = state.get((EventTypes.Member, user_id), None) defer.returnValue(res.get(user_id, []))
if membership_event:
was_forgotten_at_event = yield self.store.was_forgotten_at(
membership_event.state_key,
membership_event.room_id,
membership_event.event_id
)
if was_forgotten_at_event:
membership = None
else:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
defer.returnValue(events_to_return)
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()

View file

@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
# from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from twisted.internet import defer from twisted.internet import defer
@ -244,12 +244,11 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
# Temporarily disable notifications due to performance concerns. if not backfilled and not event.internal_metadata.is_outlier():
# if not backfilled and not event.internal_metadata.is_outlier(): action_generator = ActionGenerator(self.store)
# action_generator = ActionGenerator(self.store) yield action_generator.handle_push_actions_for_event(
# yield action_generator.handle_push_actions_for_event( event, self
# event, self )
# )
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):

View file

@ -841,9 +841,6 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room): def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
# Temporarily disable notifications due to performance concerns.
defer.returnValue([])
last_unread_event_id = self.last_read_event_id_for_room_and_user( last_unread_event_id = self.last_read_event_id_for_room_and_user(
room_id, sync_config.user.to_string(), ephemeral_by_room room_id, sync_config.user.to_string(), ephemeral_by_room
) )

View file

@ -36,9 +36,6 @@ class ActionGenerator:
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event, handler): def handle_push_actions_for_event(self, event, handler):
# Temporarily disable notifications due to performance concerns.
return
if event.type == EventTypes.Redaction and event.redacts is not None: if event.type == EventTypes.Redaction and event.redacts is not None:
yield self.store.remove_push_actions_for_event_id( yield self.store.remove_push_actions_for_event_id(
event.room_id, event.redacts event.room_id, event.redacts

View file

@ -15,27 +15,25 @@
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
def list_with_base_rules(rawrules, user_id): def list_with_base_rules(rawrules):
ruleslist = [] ruleslist = []
# shove the server default rules for each kind onto the end of each # shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
user_id, PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
for r in rawrules: for r in rawrules:
if r['priority_class'] < current_prio_class: if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
@ -43,223 +41,233 @@ def list_with_base_rules(rawrules, user_id):
while current_prio_class > 0: while current_prio_class > 0:
ruleslist.extend(make_base_append_rules( ruleslist.extend(make_base_append_rules(
user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
current_prio_class -= 1 current_prio_class -= 1
if current_prio_class > 0: if current_prio_class > 0:
ruleslist.extend(make_base_prepend_rules( ruleslist.extend(make_base_prepend_rules(
user_id,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class] PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
)) ))
return ruleslist return ruleslist
def make_base_append_rules(user, kind): def make_base_append_rules(kind):
rules = [] rules = []
if kind == 'override': if kind == 'override':
rules = make_base_append_override_rules() rules = BASE_APPEND_OVRRIDE_RULES
elif kind == 'underride': elif kind == 'underride':
rules = make_base_append_underride_rules(user) rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == 'content': elif kind == 'content':
rules = make_base_append_content_rules(user) rules = BASE_APPEND_CONTENT_RULES
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True # Deprecated, left for backwards compat
return rules return rules
def make_base_prepend_rules(user, kind): def make_base_prepend_rules(kind):
rules = [] rules = []
if kind == 'override': if kind == 'override':
rules = make_base_prepend_override_rules() rules = BASE_PREPEND_OVERRIDE_RULES
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True # Deprecated, left for backwards compat
return rules return rules
def make_base_append_content_rules(user): BASE_APPEND_CONTENT_RULES = [
return [ {
{ 'rule_id': 'global/content/.m.rule.contains_user_name',
'rule_id': 'global/content/.m.rule.contains_user_name', 'conditions': [
'conditions': [ {
{ 'kind': 'event_match',
'kind': 'event_match', 'key': 'content.body',
'key': 'content.body', 'pattern_type': 'user_localpart'
'pattern': user.localpart, # Matrix ID match }
} ],
], 'actions': [
'actions': [ 'notify',
'notify', {
{ 'set_tweak': 'sound',
'set_tweak': 'sound', 'value': 'default',
'value': 'default', }, {
}, { 'set_tweak': 'highlight'
'set_tweak': 'highlight' }
} ]
] },
}, ]
]
def make_base_prepend_override_rules(): BASE_PREPEND_OVERRIDE_RULES = [
return [ {
{ 'rule_id': 'global/override/.m.rule.master',
'rule_id': 'global/override/.m.rule.master', 'enabled': False,
'enabled': False, 'conditions': [],
'conditions': [], 'actions': [
'actions': [ "dont_notify"
"dont_notify" ]
] }
} ]
]
def make_base_append_override_rules(): BASE_APPEND_OVRRIDE_RULES = [
return [ {
{ 'rule_id': 'global/override/.m.rule.suppress_notices',
'rule_id': 'global/override/.m.rule.suppress_notices', 'conditions': [
'conditions': [ {
{ 'kind': 'event_match',
'kind': 'event_match', 'key': 'content.msgtype',
'key': 'content.msgtype', 'pattern': 'm.notice',
'pattern': 'm.notice', '_id': '_suppress_notices',
} }
], ],
'actions': [ 'actions': [
'dont_notify', 'dont_notify',
] ]
} }
] ]
def make_base_append_underride_rules(user): BASE_APPEND_UNDERRIDE_RULES = [
return [ {
{ 'rule_id': 'global/underride/.m.rule.call',
'rule_id': 'global/underride/.m.rule.call', 'conditions': [
'conditions': [ {
{ 'kind': 'event_match',
'kind': 'event_match', 'key': 'type',
'key': 'type', 'pattern': 'm.call.invite',
'pattern': 'm.call.invite', '_id': '_call',
} }
], ],
'actions': [ 'actions': [
'notify', 'notify',
{ {
'set_tweak': 'sound', 'set_tweak': 'sound',
'value': 'ring' 'value': 'ring'
}, { }, {
'set_tweak': 'highlight', 'set_tweak': 'highlight',
'value': False 'value': False
} }
] ]
}, },
{ {
'rule_id': 'global/underride/.m.rule.contains_display_name', 'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [ 'conditions': [
{ {
'kind': 'contains_display_name' 'kind': 'contains_display_name'
} }
], ],
'actions': [ 'actions': [
'notify', 'notify',
{ {
'set_tweak': 'sound', 'set_tweak': 'sound',
'value': 'default' 'value': 'default'
}, { }, {
'set_tweak': 'highlight' 'set_tweak': 'highlight'
} }
] ]
}, },
{ {
'rule_id': 'global/underride/.m.rule.room_one_to_one', 'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [
{ {
'kind': 'room_member_count', 'kind': 'room_member_count',
'is': '2' 'is': '2',
} '_id': 'member_count',
], }
'actions': [ ],
'notify', 'actions': [
{ 'notify',
'set_tweak': 'sound', {
'value': 'default' 'set_tweak': 'sound',
}, { 'value': 'default'
'set_tweak': 'highlight', }, {
'value': False 'set_tweak': 'highlight',
} 'value': False
] }
}, ]
{ },
'rule_id': 'global/underride/.m.rule.invite_for_me', {
'conditions': [ 'rule_id': 'global/underride/.m.rule.invite_for_me',
{ 'conditions': [
'kind': 'event_match', {
'key': 'type', 'kind': 'event_match',
'pattern': 'm.room.member', 'key': 'type',
}, 'pattern': 'm.room.member',
{ '_id': '_member',
'kind': 'event_match', },
'key': 'content.membership', {
'pattern': 'invite', 'kind': 'event_match',
}, 'key': 'content.membership',
{ 'pattern': 'invite',
'kind': 'event_match', '_id': '_invite_member',
'key': 'state_key', },
'pattern': user.to_string(), {
}, 'kind': 'event_match',
], 'key': 'state_key',
'actions': [ 'pattern_type': 'user_id'
'notify', },
{ ],
'set_tweak': 'sound', 'actions': [
'value': 'default' 'notify',
}, { {
'set_tweak': 'highlight', 'set_tweak': 'sound',
'value': False 'value': 'default'
} }, {
] 'set_tweak': 'highlight',
}, 'value': False
{ }
'rule_id': 'global/underride/.m.rule.member_event', ]
'conditions': [ },
{ {
'kind': 'event_match', 'rule_id': 'global/underride/.m.rule.member_event',
'key': 'type', 'conditions': [
'pattern': 'm.room.member', {
} 'kind': 'event_match',
], 'key': 'type',
'actions': [ 'pattern': 'm.room.member',
'notify', { '_id': '_member',
'set_tweak': 'highlight', }
'value': False ],
} 'actions': [
] 'notify', {
}, 'set_tweak': 'highlight',
{ 'value': False
'rule_id': 'global/underride/.m.rule.message', }
'enabled': False, ]
'conditions': [ },
{ {
'kind': 'event_match', 'rule_id': 'global/underride/.m.rule.message',
'key': 'type', 'enabled': False,
'pattern': 'm.room.message', 'conditions': [
} {
], 'kind': 'event_match',
'actions': [ 'key': 'type',
'notify', { 'pattern': 'm.room.message',
'set_tweak': 'highlight', '_id': '_message',
'value': False }
} ],
] 'actions': [
} 'notify', {
] 'set_tweak': 'highlight',
'value': False
}
]
}
]
for r in BASE_APPEND_CONTENT_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['content']
r['default'] = True
for r in BASE_PREPEND_OVERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True
for r in BASE_APPEND_OVRRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True
for r in BASE_APPEND_UNDERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['underride']
r['default'] = True

View file

@ -14,16 +14,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson as json import ujson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID
import baserules import baserules
from push_rule_evaluator import PushRuleEvaluator from push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.events.utils import serialize_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,29 +33,26 @@ def decode_rule_json(rule):
return rule return rule
@defer.inlineCallbacks
def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_by_user = {
uid: baserules.list_with_base_rules([
decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, [])
])
for uid in user_ids
}
defer.returnValue(rules_by_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_room_id(room_id, store): def evaluator_for_room_id(room_id, store):
users = yield store.get_users_in_room(room_id) users = yield store.get_users_in_room(room_id)
rules_by_user = yield store.bulk_get_push_rules(users) rules_by_user = yield _get_rules(room_id, users, store)
rules_by_user = {
uid: baserules.list_with_base_rules(
[decode_rule_json(rule_list) for rule_list in rules_by_user[uid]]
if uid in rules_by_user else [],
UserID.from_string(uid),
)
for uid in users
}
member_events = yield store.get_current_state(
room_id=room_id,
event_type='m.room.member',
)
display_names = {}
for ev in member_events:
if ev.content.get("displayname"):
display_names[ev.state_key] = ev.content.get("displayname")
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, display_names, users, store room_id, rules_by_user, users, store
)) ))
@ -69,10 +65,9 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562) (see https://matrix.org/jira/browse/SYN-562)
""" """
def __init__(self, room_id, rules_by_user, display_names, users_in_room, store): def __init__(self, room_id, rules_by_user, users_in_room, store):
self.room_id = room_id self.room_id = room_id
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
self.display_names = display_names
self.users_in_room = users_in_room self.users_in_room = users_in_room
self.store = store self.store = store
@ -80,15 +75,30 @@ class BulkPushRuleEvaluator:
def action_for_event_by_user(self, event, handler): def action_for_event_by_user(self, event, handler):
actions_by_user = {} actions_by_user = {}
for uid, rules in self.rules_by_user.items(): users_dict = yield self.store.are_guests(self.rules_by_user.keys())
display_name = None
if uid in self.display_names:
display_name = self.display_names[uid]
is_guest = yield self.store.is_guest(UserID.from_string(uid)) filtered_by_user = yield handler._filter_events_for_clients(
filtered = yield handler._filter_events_for_client( users_dict.items(), [event]
uid, [event], is_guest=is_guest )
)
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
condition_cache = {}
member_state = yield self.store.get_state_for_event(
event.event_id,
)
display_names = {}
for ev in member_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
for uid, rules in self.rules_by_user.items():
display_name = display_names.get(uid, None)
filtered = filtered_by_user[uid]
if len(filtered) == 0: if len(filtered) == 0:
continue continue
@ -96,29 +106,32 @@ class BulkPushRuleEvaluator:
if 'enabled' in rule and not rule['enabled']: if 'enabled' in rule and not rule['enabled']:
continue continue
# XXX: profile tags matches = _condition_checker(
if BulkPushRuleEvaluator.event_matches_rule( evaluator, rule['conditions'], uid, display_name, condition_cache
event, rule, )
display_name, len(self.users_in_room), None if matches:
):
actions = [x for x in rule['actions'] if x != 'dont_notify'] actions = [x for x in rule['actions'] if x != 'dont_notify']
if len(actions) > 0: if actions:
actions_by_user[uid] = actions actions_by_user[uid] = actions
break break
defer.returnValue(actions_by_user) defer.returnValue(actions_by_user)
@staticmethod
def event_matches_rule(event, rule,
display_name, room_member_count, profile_tag):
matches = True
# passing the clock all the way into here is extremely awkward and push def _condition_checker(evaluator, conditions, uid, display_name, cache):
# rules do not care about any of the relative timestamps, so we just for cond in conditions:
# pass 0 for the current time. _id = cond.get("_id", None)
client_event = serialize_event(event, 0) if _id:
res = cache.get(_id, None)
if res is False:
break
elif res is True:
continue
for cond in rule['conditions']: res = evaluator.matches(cond, uid, display_name, None)
matches &= PushRuleEvaluator._event_fulfills_condition( if _id:
client_event, cond, display_name, room_member_count, profile_tag cache[_id] = res
)
return matches if res is False:
return False
return True

View file

@ -15,17 +15,22 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID
import baserules import baserules
import logging import logging
import simplejson as json import simplejson as json
import re import re
from synapse.types import UserID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
IS_GLOB = re.compile(r'[\?\*\[\]]')
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id) rawrules = yield store.get_push_rules_for_user(user_id)
@ -42,9 +47,34 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
)) ))
def _room_member_count(ev, condition, room_member_count):
if 'is' not in condition:
return False
m = INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
class PushRuleEvaluator: class PushRuleEvaluator:
DEFAULT_ACTIONS = [] DEFAULT_ACTIONS = []
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store): our_member_event, store):
@ -61,8 +91,7 @@ class PushRuleEvaluator:
rule['actions'] = json.loads(raw_rule['actions']) rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule) rules.append(rule)
user = UserID.from_string(self.user_id) self.rules = baserules.list_with_base_rules(rules)
self.rules = baserules.list_with_base_rules(rules, user)
self.enabled_map = enabled_map self.enabled_map = enabled_map
@ -98,28 +127,19 @@ class PushRuleEvaluator:
room_members = yield self.store.get_users_in_room(room_id) room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members) room_member_count = len(room_members)
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
for r in self.rules: for r in self.rules:
if r['rule_id'] in self.enabled_map: enabled = self.enabled_map.get(r['rule_id'], None)
r['enabled'] = self.enabled_map[r['rule_id']] if enabled is not None and not enabled:
elif 'enabled' not in r: continue
r['enabled'] = True
if not r['enabled']: if not r.get("enabled", True):
continue continue
matches = True
conditions = r['conditions'] conditions = r['conditions']
actions = r['actions'] actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count,
profile_tag=self.profile_tag
)
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify') # ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0: if len(actions) == 0:
logger.warn( logger.warn(
@ -127,8 +147,22 @@ class PushRuleEvaluator:
r['rule_id'], self.user_id r['rule_id'], self.user_id
) )
continue continue
matches = True
for c in conditions:
matches = evaluator.matches(
c, self.user_id, my_display_name, self.profile_tag
)
if not matches:
break
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
if matches: if matches:
logger.info( logger.debug(
"%s matches for user %s, event %s", "%s matches for user %s, event %s",
r['rule_id'], self.user_id, ev['event_id'] r['rule_id'], self.user_id, ev['event_id']
) )
@ -139,94 +173,132 @@ class PushRuleEvaluator:
defer.returnValue(actions) defer.returnValue(actions)
logger.info( logger.debug(
"No rules match for user %s, event %s", "No rules match for user %s, event %s",
self.user_id, ev['event_id'] self.user_id, ev['event_id']
) )
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS) defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges. class PushRuleEvaluatorForEvent(object):
r = re.sub(r'\\\[(\\\!|)(.*)\\\]', def __init__(self, event, room_member_count):
lambda x: ('[%s%s]' % (x.group(1) and '^' or '', self._event = event
re.sub(r'\\\-', '-', x.group(2)))), r) self._room_member_count = room_member_count
return r
@staticmethod # Maps strings of e.g. 'content.body' -> event["content"]["body"]
def _event_fulfills_condition(ev, condition, self._value_cache = _flatten_dict(event)
display_name, room_member_count, profile_tag):
def matches(self, condition, user_id, display_name, profile_tag):
if condition['kind'] == 'event_match': if condition['kind'] == 'event_match':
if 'pattern' not in condition: return self._event_match(condition, user_id)
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % PushRuleEvaluator._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
elif condition['kind'] == 'device': elif condition['kind'] == 'device':
if 'profile_tag' not in condition: if 'profile_tag' not in condition:
return True return True
return condition['profile_tag'] == profile_tag return condition['profile_tag'] == profile_tag
elif condition['kind'] == 'contains_display_name': elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different return self._contains_display_name(display_name)
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
elif condition['kind'] == 'room_member_count': elif condition['kind'] == 'room_member_count':
if 'is' not in condition: return _room_member_count(
return False self._event, condition, self._room_member_count
m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is']) )
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
else: else:
return True return True
def _event_match(self, condition, user_id):
pattern = condition.get('pattern', None)
def _value_for_dotted_key(dotted_key, event): if not pattern:
parts = dotted_key.split(".") pattern_type = condition.get('pattern_type', None)
val = event if pattern_type == "user_id":
while len(parts) > 0: pattern = user_id
if parts[0] not in val: elif pattern_type == "user_localpart":
return None pattern = UserID.from_string(user_id).localpart
val = val[parts[0]]
parts = parts[1:] if not pattern:
return val logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
body = self._event["content"].get("body", None)
if not body:
return False
return _glob_matches(pattern, body, word_boundary=True)
else:
haystack = self._get_value(condition['key'])
if haystack is None:
return False
return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name):
if not display_name:
return False
body = self._event["content"].get("body", None)
if not body:
return False
return _glob_matches(display_name, body, word_boundary=True)
def _get_value(self, dotted_key):
return self._value_cache.get(dotted_key, None)
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
Args:
glob (string)
value (string): String to test against glob.
word_boundary (bool): Whether to match against word boundaries or entire
string. Defaults to False.
Returns:
bool
"""
if IS_GLOB.search(glob):
r = re.escape(glob)
r = r.replace(r'\*', '.*?')
r = r.replace(r'\?', '.')
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
'[%s%s]' % (
x.group(1) and '^' or '',
x.group(2).replace(r'\\\-', '-')
)
),
r,
)
if word_boundary:
r = r"\b%s\b" % (r,)
r = re.compile(r, flags=re.IGNORECASE)
return r.search(value)
else:
r = r + "$"
r = re.compile(r, flags=re.IGNORECASE)
return r.match(value)
elif word_boundary:
r = re.escape(glob)
r = r"\b%s\b" % (r,)
r = re.compile(r, flags=re.IGNORECASE)
return r.search(value)
else:
return value.lower() == glob.lower()
def _flatten_dict(d, prefix=[], result={}):
for key, value in d.items():
if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result)
return result

View file

@ -27,6 +27,7 @@ from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
) )
import copy
import simplejson as json import simplejson as json
@ -126,7 +127,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
rule["actions"] = json.loads(rawrule["actions"]) rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule) ruleslist.append(rule)
ruleslist = baserules.list_with_base_rules(ruleslist, user) # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}} rules = {'global': {}, 'device': {}}
@ -140,6 +142,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_name = _priority_class_to_template_name(r['priority_class']) template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
if r['priority_class'] > PRIORITY_CLASS_MAP['override']: if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule # per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"]) profile_tag = _profile_tag_from_conditions(r["conditions"])

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
class RegistrationStore(SQLBaseStore): class RegistrationStore(SQLBaseStore):
@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedInlineCallbacks() @cachedInlineCallbacks()
def is_guest(self, user): def is_guest(self, user_id):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user.to_string()}, keyvalues={"name": user_id},
retcol="is_guest", retcol="is_guest",
allow_none=True, allow_none=True,
desc="is_guest", desc="is_guest",
@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
",".join("?" for _ in user_ids),
)
rows = yield self._execute(
"are_guests", self.cursor_to_dict, sql, *user_ids
)
result = {user_id: False for user_id in user_ids}
result.update({
row["name"]: bool(row["is_guest"])
for row in rows
})
defer.returnValue(result)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id" "SELECT users.name, users.is_guest, access_tokens.id as token_id"

View file

@ -287,6 +287,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id)) txn.execute(sql, (user_id, room_id))
yield self.runInteraction("forget_membership", f) yield self.runInteraction("forget_membership", f)
self.was_forgotten_at.invalidate_all() self.was_forgotten_at.invalidate_all()
self.who_forgot_in_room.invalidate_all()
self.did_forget.invalidate((user_id, room_id)) self.did_forget.invalidate((user_id, room_id))
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2)
@ -336,3 +337,15 @@ class RoomMemberStore(SQLBaseStore):
return rows[0][0] return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f) forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1) defer.returnValue(forgot == 1)
@cached()
def who_forgot_in_room(self, room_id):
return self._simple_select_list(
table="room_memberships",
retcols=("user_id", "event_id"),
keyvalues={
"room_id": room_id,
"forgotten": 1,
},
desc="who_forgot"
)

View file

@ -1,141 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from tests import unittest
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.handlers.federation import FederationHandler
from mock import NonCallableMock, ANY, Mock
from ..utils import setup_test_homeserver
class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.state_handler = NonCallableMock(spec_set=[
"compute_event_context",
])
self.auth = NonCallableMock(spec_set=[
"check",
"check_host_in_room",
])
self.hostname = "test"
hs = yield setup_test_homeserver(
self.hostname,
datastore=NonCallableMock(spec_set=[
"persist_event",
"store_room",
"get_room",
"get_destination_retry_timings",
"set_destination_retry_timings",
"have_events",
"get_users_in_room",
"bulk_get_push_rules",
"get_current_state",
"set_push_actions_for_event_and_users",
"is_guest",
"get_state_for_events",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_member_handler",
"federation_handler",
]),
auth=self.auth,
state_handler=self.state_handler,
keyring=Mock(),
)
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.hs = hs
self.handlers.federation_handler = FederationHandler(self.hs)
self.datastore.get_state_for_events.return_value = {"$a:b": {}}
@defer.inlineCallbacks
def test_msg(self):
pdu = FrozenEvent({
"type": EventTypes.Message,
"room_id": "foo",
"content": {"msgtype": u"fooo"},
"origin_server_ts": 0,
"event_id": "$a:b",
"user_id":"@a:b",
"origin": "b",
"auth_events": [],
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
})
self.datastore.persist_event.return_value = defer.succeed((1,1))
self.datastore.get_room.return_value = defer.succeed(True)
self.datastore.get_users_in_room.return_value = ["@a:b"]
self.datastore.bulk_get_push_rules.return_value = {}
self.datastore.get_current_state.return_value = {}
self.auth.check_host_in_room.return_value = defer.succeed(True)
retry_timings_res = {
"destination": "",
"retry_last_ts": 0,
"retry_interval": 0,
}
self.datastore.get_destination_retry_timings.return_value = (
defer.succeed(retry_timings_res)
)
def have_events(event_ids):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None, outlier=False):
context = Mock()
context.current_state = {}
context.auth_events = {}
return defer.succeed(context)
self.state_handler.compute_event_context.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu(
"fo", pdu, False
)
self.datastore.persist_event.assert_called_once_with(
ANY,
is_new_state=True,
backfilled=False,
current_state=None,
context=ANY,
)
self.state_handler.compute_event_context.assert_called_once_with(
ANY, old_state=None, outlier=False
)
self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with(
ANY, 1, 1, extra_users=[]
)

View file

@ -1,418 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .. import unittest
from synapse.api.constants import EventTypes, Membership
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
from ..utils import setup_test_homeserver
from mock import Mock, NonCallableMock
class RoomMemberHandlerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hostname = "red"
hs = yield setup_test_homeserver(
self.hostname,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
datastore=NonCallableMock(spec_set=[
"persist_event",
"get_room_member",
"get_room",
"store_room",
"get_latest_events_in_room",
"add_event_hashes",
"get_users_in_room",
"bulk_get_push_rules",
"get_current_state",
"set_push_actions_for_event_and_users",
"get_state_for_events",
"is_guest",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_member_handler",
"profile_handler",
"federation_handler",
]),
auth=NonCallableMock(spec_set=[
"check",
"add_auth_events",
"check_host_in_room",
]),
state_handler=NonCallableMock(spec_set=[
"compute_event_context",
"get_current_state",
]),
)
self.federation = NonCallableMock(spec_set=[
"handle_new_event",
"send_invite",
"get_state_for_room",
])
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.auth = hs.get_auth()
self.hs = hs
self.handlers.federation_handler = self.federation
self.distributor.declare("collect_presencelike_data")
self.handlers.room_member_handler = RoomMemberHandler(self.hs)
self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1)
self.datastore.add_event_hashes.return_value = []
self.datastore.get_users_in_room.return_value = ["@bob:red"]
self.datastore.bulk_get_push_rules.return_value = {}
@defer.inlineCallbacks
def test_invite(self):
room_id = "!foo:red"
user_id = "@bob:red"
target_user_id = "@red:blue"
content = {"membership": Membership.INVITE}
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": target_user_id,
"room_id": room_id,
"content": content,
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
def send_invite(domain, event):
return defer.succeed(event)
self.federation.send_invite.side_effect = send_invite
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
yield room_handler.send_membership_event(event, context)
self.state_handler.compute_event_context.assert_called_once_with(
builder
)
self.auth.add_auth_events.assert_called_once_with(
builder, context
)
self.federation.send_invite.assert_called_once_with(
"blue", event,
)
self.datastore.persist_event.assert_called_once_with(
event, context=context,
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
)
self.assertFalse(self.datastore.get_room.called)
self.assertFalse(self.datastore.store_room.called)
self.assertFalse(self.federation.get_state_for_room.called)
@defer.inlineCallbacks
def test_simple_join(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = UserID.from_string(user_id)
join_signal_observer = Mock()
self.distributor.observe("user_joined_room", join_signal_observer)
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.JOIN},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.INVITE
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
# Actual invocation
yield room_handler.send_membership_event(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, destinations=set()
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[user]
)
join_signal_observer.assert_called_with(
user=user, room_id=room_id
)
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": membership},
})
return builder.build()
@defer.inlineCallbacks
def test_simple_leave(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = UserID.from_string(user_id)
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.LEAVE},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.datastore.get_current_state.return_value = {}
self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids}
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.JOIN
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
leave_signal_observer = Mock()
self.distributor.observe("user_left_room", leave_signal_observer)
# Actual invocation
yield room_handler.send_membership_event(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, destinations=set(['red'])
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, 1, 1, extra_users=[user]
)
leave_signal_observer.assert_called_with(
user=user, room_id=room_id
)
class RoomCreationTest(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hostname = "red"
hs = yield setup_test_homeserver(
self.hostname,
datastore=NonCallableMock(spec_set=[
"store_room",
"snapshot_room",
"persist_event",
"get_joined_hosts_for_room",
]),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_creation_handler",
"message_handler",
]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
)
self.federation = NonCallableMock(spec_set=[
"handle_new_event",
])
self.handlers = hs.get_handlers()
self.handlers.room_creation_handler = RoomCreationHandler(hs)
self.room_creation_handler = self.handlers.room_creation_handler
self.message_handler = self.handlers.message_handler
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@defer.inlineCallbacks
def test_room_creation(self):
user_id = "@foo:red"
room_id = "!bobs_room:red"
config = {"visibility": "private"}
yield self.room_creation_handler.create_room(
user_id=user_id,
room_id=room_id,
config=config,
)
self.assertTrue(self.message_handler.create_and_send_event.called)
event_dicts = [
e[0][0]
for e in self.message_handler.create_and_send_event.call_args_list
]
self.assertTrue(len(event_dicts) > 3)
self.assertDictContainsSubset(
{
"type": EventTypes.Create,
"sender": user_id,
"room_id": room_id,
},
event_dicts[0]
)
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
self.assertDictContainsSubset(
{
"type": EventTypes.Member,
"sender": user_id,
"room_id": room_id,
"state_key": user_id,
},
event_dicts[1]
)
self.assertEqual(
Membership.JOIN,
event_dicts[1]["content"]["membership"]
)