Merge branch 'develop' into application-services-txn-reliability

This commit is contained in:
Kegan Dougal 2015-03-06 17:28:49 +00:00
commit 34ce2ca62f
20 changed files with 496 additions and 121 deletions

View file

@ -1,11 +1,40 @@
Changes in synapse vx.x.x (x-x-x) Changes in synapse v0.8.0 (2015-03-06)
================================= ======================================
General:
* Add support for registration fallback. This is a page hosted on the server * Add support for registration fallback. This is a page hosted on the server
which allows a user to register for an account, regardless of what client which allows a user to register for an account, regardless of what client
they are using (e.g. mobile devices). they are using (e.g. mobile devices).
* Added new default push rules and made them configurable by clients:
* Suppress all notice messages.
* Notify when invited to a new room.
* Notify for messages that don't match any rule.
* Notify on incoming call.
* Notify if there were no matching rules.
Federation:
* Added per host server side rate-limiting of incoming federation requests.
* Added a ``/get_missing_events/`` API to federation to reduce number of
``/events/`` requests.
Configuration:
* Added configuration option to disable registration:
``disable_registration``.
* Added configuration option to change soft limit of number of open file
descriptors: ``soft_file_limit``.
* Make ``tls_private_key_path`` optional when running with ``no_tls``.
Application services:
* Application services can now poll on the CS API ``/events`` for their events, * Application services can now poll on the CS API ``/events`` for their events,
by providing their application service ``access_token``. by providing their application service ``access_token``.
* Added exclusive namespace support to application services API.
Changes in synapse v0.7.1 (2015-02-19) Changes in synapse v0.7.1 (2015-02-19)
====================================== ======================================

View file

@ -118,6 +118,7 @@ environment under ``~/.synapse``.
To set up your homeserver, run (in your virtualenv, as before):: To set up your homeserver, run (in your virtualenv, as before)::
$ cd ~/.synapse
$ python -m synapse.app.homeserver \ $ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
@ -179,6 +180,7 @@ installing under virtualenv)::
During setup of homeserver you need to call python2.7 directly again:: During setup of homeserver you need to call python2.7 directly again::
$ cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \ $ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \

View file

@ -1,5 +1,5 @@
Upgrading to vx.xx Upgrading to v0.8.0
================== ===================
Servers which use captchas will need to add their public key to:: Servers which use captchas will need to add their public key to::
@ -12,9 +12,6 @@ Servers which use captchas will need to add their public key to::
This is required in order to support registration fallback (typically used on This is required in order to support registration fallback (typically used on
mobile devices). mobile devices).
The format of stored application services has changed in Synapse. You will need
to run ``PYTHONPATH=. python scripts/upgrade_appservice_db.py <database file path>``
to convert to the new format.
Upgrading to v0.7.0 Upgrading to v0.7.0
=================== ===================

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.7.1-r1" __version__ = "0.8.0"

View file

@ -219,61 +219,64 @@ class SynapseHomeServer(HomeServer):
def get_version_string(): def get_version_string():
null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(__file__))
try: try:
git_branch = subprocess.check_output( null = open(os.devnull, 'w')
['git', 'rev-parse', '--abbrev-ref', 'HEAD'], cwd = os.path.dirname(os.path.abspath(__file__))
stderr=null, try:
cwd=cwd, git_branch = subprocess.check_output(
).strip() ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
git_branch = "b=" + git_branch stderr=null,
except subprocess.CalledProcessError: cwd=cwd,
git_branch = "" ).strip()
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
try: try:
git_tag = subprocess.check_output( git_tag = subprocess.check_output(
['git', 'describe', '--exact-match'], ['git', 'describe', '--exact-match'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip()
git_tag = "t=" + git_tag git_tag = "t=" + git_tag
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_tag = "" git_tag = ""
try: try:
git_commit = subprocess.check_output( git_commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD'], ['git', 'rev-parse', '--short', 'HEAD'],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip() ).strip()
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_commit = "" git_commit = ""
try: try:
dirty_string = "-this_is_a_dirty_checkout" dirty_string = "-this_is_a_dirty_checkout"
is_dirty = subprocess.check_output( is_dirty = subprocess.check_output(
['git', 'describe', '--dirty=' + dirty_string], ['git', 'describe', '--dirty=' + dirty_string],
stderr=null, stderr=null,
cwd=cwd, cwd=cwd,
).strip().endswith(dirty_string) ).strip().endswith(dirty_string)
git_dirty = "dirty" if is_dirty else "" git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
git_dirty = "" git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty: if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join( git_version = ",".join(
s for s in s for s in
(git_branch, git_tag, git_commit, git_dirty,) (git_branch, git_tag, git_commit, git_dirty,)
if s if s
)
return (
"Synapse/%s (%s)" % (
synapse.__version__, git_version,
) )
).encode("ascii")
return (
"Synapse/%s (%s)" % (
synapse.__version__, git_version,
)
).encode("ascii")
except Exception as e:
logger.warn("Failed to check for git repository: %s", e)
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")

View file

@ -30,7 +30,6 @@ class ServerConfig(Config):
self.pid_file = self.abspath(args.pid_file) self.pid_file = self.abspath(args.pid_file)
self.webclient = True self.webclient = True
self.manhole = args.manhole self.manhole = args.manhole
self.no_tls = args.no_tls
self.soft_file_limit = args.soft_file_limit self.soft_file_limit = args.soft_file_limit
if not args.content_addr: if not args.content_addr:
@ -76,8 +75,6 @@ class ServerConfig(Config):
server_group.add_argument("--content-addr", default=None, server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the " help="The host and scheme to use for the "
"content repository") "content repository")
server_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
server_group.add_argument("--soft-file-limit", type=int, default=0, server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Set the soft limit on the number of " help="Set the soft limit on the number of "
"file descriptors synapse can use. " "file descriptors synapse can use. "

View file

@ -28,9 +28,16 @@ class TlsConfig(Config):
self.tls_certificate = self.read_tls_certificate( self.tls_certificate = self.read_tls_certificate(
args.tls_certificate_path args.tls_certificate_path
) )
self.tls_private_key = self.read_tls_private_key(
args.tls_private_key_path self.no_tls = args.no_tls
)
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
args.tls_private_key_path
)
self.tls_dh_params_path = self.check_file( self.tls_dh_params_path = self.check_file(
args.tls_dh_params_path, "tls_dh_params" args.tls_dh_params_path, "tls_dh_params"
) )
@ -45,6 +52,8 @@ class TlsConfig(Config):
help="PEM encoded private key for TLS") help="PEM encoded private key for TLS")
tls_group.add_argument("--tls-dh-params-path", tls_group.add_argument("--tls-dh-params-path",
help="PEM dh parameters for ephemeral keys") help="PEM dh parameters for ephemeral keys")
tls_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
def read_tls_certificate(self, cert_path): def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate") cert_pem = self.read_file(cert_path, "tls_certificate")

View file

@ -38,7 +38,10 @@ class ServerContextFactory(ssl.ContextFactory):
logger.exception("Failed to enable eliptic curve for TLS") logger.exception("Failed to enable eliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate(config.tls_certificate) context.use_certificate(config.tls_certificate)
context.use_privatekey(config.tls_private_key)
if not config.no_tls:
context.use_privatekey(config.tls_private_key)
context.load_tmp_dh(config.tls_dh_params_path) context.load_tmp_dh(config.tls_dh_params_path)
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH") context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")

View file

@ -50,18 +50,27 @@ class Keyring(object):
) )
try: try:
verify_key = yield self.get_server_verify_key(server_name, key_ids) verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError: except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError( raise SynapseError(
502, 502,
"Error downloading keys for %s" % (server_name,), "Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
except: except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError( raise SynapseError(
401, 401,
"No key for %s with id %s" % (server_name, key_ids), "No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
except: except:

View file

@ -19,14 +19,18 @@ from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from .units import Edu from .units import Edu
from synapse.api.errors import CodeMessageException, SynapseError from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util.expiringcache import ExpiringCache from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import itertools
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -440,21 +444,112 @@ class FederationClient(FederationBase):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_missing_events(self, destination, room_id, earliest_events, def get_missing_events(self, destination, room_id, earliest_events_ids,
latest_events, limit, min_depth): latest_events, limit, min_depth):
content = yield self.transport_layer.get_missing_events( """Tries to fetch events we are missing. This is called when we receive
destination, room_id, earliest_events, latest_events, limit, an event without having received all of its ancestors.
min_depth,
)
events = [ Args:
self.event_from_pdu_json(e) destination (str)
for e in content.get("events", []) room_id (str)
] earliest_events_ids (list): List of event ids. Effectively the
events we expected to receive, but haven't. `get_missing_events`
should only return events that didn't happen before these.
latest_events (list): List of events we have received that we don't
have all previous events for.
limit (int): Maximum number of events to return.
min_depth (int): Minimum depth of events tor return.
"""
try:
content = yield self.transport_layer.get_missing_events(
destination=destination,
room_id=room_id,
earliest_events=earliest_events_ids,
latest_events=[e.event_id for e in latest_events],
limit=limit,
min_depth=min_depth,
)
signed_events = yield self._check_sigs_and_hash_and_fetch( events = [
destination, events, outlier=True self.event_from_pdu_json(e)
) for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=True
)
have_gotten_all_from_destination = True
except HttpResponseException as e:
if not e.code == 400:
raise
# We are probably hitting an old server that doesn't support
# get_missing_events
signed_events = []
have_gotten_all_from_destination = False
if len(signed_events) >= limit:
defer.returnValue(signed_events)
servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = set(servers)
servers.discard(self.server_name)
failed_to_fetch = set()
while len(signed_events) < limit:
# Are we missing any?
seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events)
missing_events = {}
for e in itertools.chain(latest_events, signed_events):
if e.depth > min_depth:
missing_events.update({
e_id: e.depth for e_id, _ in e.prev_events
if e_id not in seen_events
and e_id not in failed_to_fetch
})
if not missing_events:
break
have_seen = yield self.store.have_events(missing_events)
for k in have_seen:
missing_events.pop(k, None)
if not missing_events:
break
# Okay, we haven't gotten everything yet. Lets get them.
ordered_missing = sorted(missing_events.items(), key=lambda x: x[0])
if have_gotten_all_from_destination:
servers.discard(destination)
def random_server_list():
srvs = list(servers)
random.shuffle(srvs)
return srvs
deferreds = [
self.get_pdu(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result:
signed_events.append(val)
else:
failed_to_fetch.add(e_id)
defer.returnValue(signed_events) defer.returnValue(signed_events)

View file

@ -413,12 +413,16 @@ class FederationServer(FederationBase):
missing_events = yield self.get_missing_events( missing_events = yield self.get_missing_events(
origin, origin,
pdu.room_id, pdu.room_id,
earliest_events=list(latest), earliest_events_ids=list(latest),
latest_events=[pdu.event_id], latest_events=[pdu],
limit=10, limit=10,
min_depth=min_depth, min_depth=min_depth,
) )
# We want to sort these by depth so we process them and
# tell clients about them in order.
missing_events.sort(key=lambda x: x.depth)
for e in missing_events: for e in missing_events:
yield self._handle_new_pdu( yield self._handle_new_pdu(
origin, origin,

View file

@ -23,6 +23,7 @@ from synapse.events.utils import serialize_event
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -72,6 +73,14 @@ class EventStreamHandler(BaseHandler):
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(auth_user) room_ids = yield rm_handler.get_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.
timeout = max(timeout, 500)
# Add some randomness to this value to try and mitigate against
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
with PreserveLoggingContext(): with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout auth_user, room_ids, pagin_config, timeout

View file

@ -212,10 +212,16 @@ class ProfileHandler(BaseHandler):
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event({ try:
"type": EventTypes.Member, yield msg_handler.create_and_send_event({
"room_id": j.room_id, "type": EventTypes.Member,
"state_key": user.to_string(), "room_id": j.room_id,
"content": content, "state_key": user.to_string(),
"sender": user.to_string() "content": content,
}, ratelimit=False) "sender": user.to_string()
}, ratelimit=False)
except Exception as e:
logger.warn(
"Failed to update join event for room %s - %s",
j.room_id, str(e.message)
)

View file

@ -124,27 +124,29 @@ class JsonResource(HttpServer, resource.Resource):
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path) m = path_entry.pattern.match(request.path)
if m: if not m:
# We found a match! Trigger callback and then return the continue
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
args = [ # We found a match! Trigger callback and then return the
urllib.unquote(u).decode("UTF-8") for u in m.groups() # returned response. We pass both the request and any
] # matched groups from the regex to the callback.
logger.info( args = [
"Received request: %s %s", urllib.unquote(u).decode("UTF-8") for u in m.groups()
request.method, request.path ]
)
code, response = yield path_entry.callback( logger.info(
request, "Received request: %s %s",
*args request.method, request.path
) )
self._send_response(request, code, response) code, response = yield path_entry.callback(
return request,
*args
)
self._send_response(request, code, response)
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
raise UnrecognizedRequestError() raise UnrecognizedRequestError()

View file

@ -32,7 +32,7 @@ class Pusher(object):
INITIAL_BACKOFF = 1000 INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000 MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000
DEFAULT_ACTIONS = ['notify'] DEFAULT_ACTIONS = ['dont-notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@ -72,16 +72,14 @@ class Pusher(object):
# let's assume you probably know about messages you sent yourself # let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify']) defer.returnValue(['dont_notify'])
if ev['type'] == 'm.room.member': rawrules = yield self.store.get_push_rules_for_user(self.user_name)
if ev['state_key'] != self.user_name:
defer.returnValue(['dont_notify'])
rawrules = yield self.store.get_push_rules_for_user_name(self.user_name)
for r in rawrules: for r in rawrules:
r['conditions'] = json.loads(r['conditions']) r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions']) r['actions'] = json.loads(r['actions'])
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
user = UserID.from_string(self.user_name) user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rawrules, user) rules = baserules.list_with_base_rules(rawrules, user)
@ -107,6 +105,8 @@ class Pusher(object):
room_member_count += 1 room_member_count += 1
for r in rules: for r in rules:
if r['rule_id'] in enabled_map and not enabled_map[r['rule_id']]:
continue
matches = True matches = True
conditions = r['conditions'] conditions = r['conditions']
@ -117,7 +117,11 @@ class Pusher(object):
ev, c, display_name=my_display_name, ev, c, display_name=my_display_name,
room_member_count=room_member_count room_member_count=room_member_count
) )
# ignore rules with no actions (we have an explict 'dont_notify' logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0: if len(actions) == 0:
logger.warn( logger.warn(
"Ignoring rule id %s with no actions for user %s" % "Ignoring rule id %s with no actions for user %s" %

View file

@ -32,12 +32,14 @@ def make_base_rules(user, kind):
if kind == 'override': if kind == 'override':
rules = make_base_override_rules() rules = make_base_override_rules()
elif kind == 'underride':
rules = make_base_underride_rules(user)
elif kind == 'content': elif kind == 'content':
rules = make_base_content_rules(user) rules = make_base_content_rules(user)
for r in rules: for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind] r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True r['default'] = True # Deprecated, left for backwards compat
return rules return rules
@ -45,6 +47,7 @@ def make_base_rules(user, kind):
def make_base_content_rules(user): def make_base_content_rules(user):
return [ return [
{ {
'rule_id': 'global/content/.m.rule.contains_user_name',
'conditions': [ 'conditions': [
{ {
'kind': 'event_match', 'kind': 'event_match',
@ -57,6 +60,8 @@ def make_base_content_rules(user):
{ {
'set_tweak': 'sound', 'set_tweak': 'sound',
'value': 'default', 'value': 'default',
}, {
'set_tweak': 'highlight'
} }
] ]
}, },
@ -66,6 +71,20 @@ def make_base_content_rules(user):
def make_base_override_rules(): def make_base_override_rules():
return [ return [
{ {
'rule_id': 'global/underride/.m.rule.suppress_notices',
'conditions': [
{
'kind': 'event_match',
'key': 'content.msgtype',
'pattern': 'm.notice',
}
],
'actions': [
'dont-notify',
]
},
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [ 'conditions': [
{ {
'kind': 'contains_display_name' 'kind': 'contains_display_name'
@ -76,10 +95,13 @@ def make_base_override_rules():
{ {
'set_tweak': 'sound', 'set_tweak': 'sound',
'value': 'default' 'value': 'default'
}, {
'set_tweak': 'highlight'
} }
] ]
}, },
{ {
'rule_id': 'global/override/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [
{ {
'kind': 'room_member_count', 'kind': 'room_member_count',
@ -95,3 +117,86 @@ def make_base_override_rules():
] ]
} }
] ]
def make_base_underride_rules(user):
return [
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern': user.to_string(),
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}
]
},
{
'rule_id': 'global/underride/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
}
],
'actions': [
'notify',
]
},
{
'rule_id': 'global/underride/.m.rule.message',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.message',
}
],
'actions': [
'notify',
]
},
{
'rule_id': 'global/underride/.m.rule.call',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.call.invite',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'ring'
}
]
},
{
'rule_id': 'global/underride/.m.rule.fallback',
'conditions': [
],
'actions': [
'notify',
]
},
]

View file

@ -88,6 +88,7 @@ class HttpPusher(Pusher):
} }
if event['type'] == 'm.room.member': if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership'] d['notification']['membership'] = event['content']['membership']
d['notification']['user_is_target'] = event['state_key'] == self.user_name
if 'content' in event: if 'content' in event:
d['notification']['content'] = event['content'] d['notification']['content'] = event['content']

View file

@ -50,6 +50,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
if 'attr' in spec:
self.set_rule_attr(user.to_string(), spec, content)
defer.returnValue((200, {}))
try: try:
(conditions, actions) = _rule_tuple_from_request_object( (conditions, actions) = _rule_tuple_from_request_object(
spec['template'], spec['template'],
@ -110,7 +114,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name( rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
user.to_string() user.to_string()
) )
@ -124,6 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
rules['global'] = _add_empty_priority_class_arrays(rules['global']) rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist: for r in ruleslist:
rulearray = None rulearray = None
@ -149,6 +156,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_rule = _rule_to_template(r) template_rule = _rule_to_template(r)
if template_rule: if template_rule:
template_rule['enabled'] = True
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
rulearray.append(template_rule) rulearray.append(template_rule)
path = request.postpath[1:] path = request.postpath[1:]
@ -189,6 +199,25 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def set_rule_attr(self, user_name, spec, val):
if spec['attr'] == 'enabled':
if not isinstance(val, bool):
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled(
user_name, namespaced_rule_id, val
)
else:
raise UnrecognizedRequestError()
def get_rule_attr(self, user_name, namespaced_rule_id, attr):
if attr == 'enabled':
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
user_name, namespaced_rule_id
)
else:
raise UnrecognizedRequestError()
def _rule_spec_from_path(path): def _rule_spec_from_path(path):
if len(path) < 2: if len(path) < 2:
@ -226,6 +255,12 @@ def _rule_spec_from_path(path):
} }
if device: if device:
spec['profile_tag'] = device spec['profile_tag'] = device
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
spec['attr'] = path[0]
return spec return spec
@ -275,7 +310,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
for a in actions: for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']: if a in ['notify', 'dont_notify', 'coalesce']:
pass pass
elif isinstance(a, dict) and 'set_sound' in a: elif isinstance(a, dict) and 'set_tweak' in a:
pass pass
else: else:
raise InvalidRuleException("Unrecognised action") raise InvalidRuleException("Unrecognised action")
@ -319,10 +354,23 @@ def _filter_ruleset_with_path(ruleset, path):
if path[0] == '': if path[0] == '':
return ruleset[template_kind] return ruleset[template_kind]
rule_id = path[0] rule_id = path[0]
the_rule = None
for r in ruleset[template_kind]: for r in ruleset[template_kind]:
if r['rule_id'] == rule_id: if r['rule_id'] == rule_id:
return r the_rule = r
raise NotFoundError if the_rule is None:
raise NotFoundError
path = path[1:]
if len(path) == 0:
return the_rule
attr = path[0]
if attr in the_rule:
return the_rule[attr]
else:
raise UnrecognizedRequestError()
def _priority_class_from_spec(spec): def _priority_class_from_spec(spec):
@ -339,7 +387,7 @@ def _priority_class_from_spec(spec):
def _priority_class_to_template_name(pc): def _priority_class_to_template_name(pc):
if pc > PRIORITY_CLASS_MAP['override']: if pc > PRIORITY_CLASS_MAP['override']:
# per-device # per-device
prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP) prio_class_index = pc - len(PRIORITY_CLASS_MAP)
return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else: else:
return PRIORITY_CLASS_INVERSE_MAP[pc] return PRIORITY_CLASS_INVERSE_MAP[pc]
@ -399,9 +447,6 @@ class InvalidRuleException(Exception):
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content return content
except ValueError: except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View file

@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_rules_for_user_name(self, user_name): def get_push_rules_for_user(self, user_name):
sql = ( sql = (
"SELECT "+",".join(PushRuleTable.fields)+" " "SELECT "+",".join(PushRuleTable.fields)+" "
"FROM "+PushRuleTable.table_name+" " "FROM "+PushRuleTable.table_name+" "
@ -45,6 +45,28 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(dicts) defer.returnValue(dicts)
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
PushRuleEnableTable.table_name,
{'user_name': user_name},
PushRuleEnableTable.fields
)
defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
)
@defer.inlineCallbacks
def get_push_rule_enabled_by_user_rule_id(self, user_name, rule_id):
results = yield self._simple_select_list(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
['enabled']
)
if not results:
defer.returnValue(True)
defer.returnValue(results[0])
@defer.inlineCallbacks @defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs): def add_push_rule(self, before, after, **kwargs):
vals = copy.copy(kwargs) vals = copy.copy(kwargs)
@ -193,6 +215,20 @@ class PushRuleStore(SQLBaseStore):
{'user_name': user_name, 'rule_id': rule_id} {'user_name': user_name, 'rule_id': rule_id}
) )
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
if enabled:
yield self._simple_delete_one(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id}
)
else:
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
{'enabled': False}
)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):
pass pass
@ -216,3 +252,13 @@ class PushRuleTable(Table):
] ]
EntryType = collections.namedtuple("PushRuleEntry", fields) EntryType = collections.namedtuple("PushRuleEntry", fields)
class PushRuleEnableTable(Table):
table_name = "push_rules_enable"
fields = [
"user_name",
"rule_id",
"enabled"
]

View file

@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS push_rules_enable (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
enabled TINYINT,
UNIQUE(user_name, rule_id)
);
CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name);