Merge branch 'release-v0.27.0' of https://github.com/matrix-org/synapse

This commit is contained in:
Neil Johnson 2018-04-11 10:59:00 +01:00
commit 427e6c4059
51 changed files with 957 additions and 279 deletions

View file

@ -1,3 +1,66 @@
Changes in synapse v0.27.3 (2018-04-11)
======================================
Bug fixes:
* URL quote path segments over federation (#3082)
Changes in synapse v0.27.3-rc2 (2018-04-09)
==========================================
v0.27.3-rc1 used a stale version of the develop branch so the changelog overstates
the functionality. v0.27.3-rc2 is up to date, rc1 should be ignored.
Changes in synapse v0.27.3-rc1 (2018-04-09)
=======================================
Notable changes include API support for joinability of groups. Also new metrics
and phone home stats. Phone home stats include better visibility of system usage
so we can tweak synpase to work better for all users rather than our own experience
with matrix.org. Also, recording 'r30' stat which is the measure we use to track
overal growth of the Matrix ecosystem. It is defined as:-
Counts the number of native 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days"
Features:
* Add joinability for groups (PR #3045)
* Implement group join API (PR #3046)
* Add counter metrics for calculating state delta (PR #3033)
* R30 stats (PR #3041)
* Measure time it takes to calculate state group ID (PR #3043)
* Add basic performance statistics to phone home (PR #3044)
* Add response size metrics (PR #3071)
* phone home cache size configurations (PR #3063)
Changes:
* Add a blurb explaining the main synapse worker (PR #2886) Thanks to @turt2live!
* Replace old style error catching with 'as' keyword (PR #3000) Thanks to @NotAFile!
* Use .iter* to avoid copies in StateHandler (PR #3006)
* Linearize calls to _generate_user_id (PR #3029)
* Remove last usage of ujson (PR #3030)
* Use simplejson throughout (PR #3048)
* Use static JSONEncoders (PR #3049)
* Remove uses of events.content (PR #3060)
* Improve database cache performance (PR #3068)
Bug fixes:
* Add room_id to the response of `rooms/{roomId}/join` (PR #2986) Thanks to @jplatte!
* Fix replication after switch to simplejson (PR #3015)
* Fix replication after switch to simplejson (PR #3015)
* 404 correctly on missing paths via NoResource (PR #3022)
* Fix error when claiming e2e keys from offline servers (PR #3034)
* fix tests/storage/test_user_directory.py (PR #3042)
* use PUT instead of POST for federating groups/m.join_policy (PR #3070) Thanks to @krombel!
* postgres port script: fix state_groups_pkey error (PR #3072)
Changes in synapse v0.27.2 (2018-03-26) Changes in synapse v0.27.2 (2018-03-26)
======================================= =======================================
@ -59,7 +122,7 @@ Features:
Changes: Changes:
* Continue to factor out processing from main process and into worker processes. See updated `docs/workers.rst <docs/metrics-howto.rst>`_ (PR #2892 - #2904, #2913, #2920 - #2926, #2947, #2847, #2854, #2872, #2873, #2874, #2928, #2929, #2934, #2856, #2976 - #2984, #2987 - #2989, #2991 - #2993, #2995, #2784) * Continue to factor out processing from main process and into worker processes. See updated `docs/workers.rst <docs/workers.rst>`_ (PR #2892 - #2904, #2913, #2920 - #2926, #2947, #2847, #2854, #2872, #2873, #2874, #2928, #2929, #2934, #2856, #2976 - #2984, #2987 - #2989, #2991 - #2993, #2995, #2784)
* Ensure state cache is used when persisting events (PR #2864, #2871, #2802, #2835, #2836, #2841, #2842, #2849) * Ensure state cache is used when persisting events (PR #2864, #2871, #2802, #2835, #2836, #2841, #2842, #2849)
* Change the default config to bind on both IPv4 and IPv6 on all platforms (PR #2435) Thanks to @silkeh! * Change the default config to bind on both IPv4 and IPv6 on all platforms (PR #2435) Thanks to @silkeh!
* No longer require a specific version of saml2 (PR #2695) Thanks to @okurz! * No longer require a specific version of saml2 (PR #2695) Thanks to @okurz!

View file

@ -48,6 +48,18 @@ returned by the Client-Server API:
# configured on port 443. # configured on port 443.
curl -kv https://<host.name>/_matrix/client/versions 2>&1 | grep "Server:" curl -kv https://<host.name>/_matrix/client/versions 2>&1 | grep "Server:"
Upgrading to $NEXT_VERSION
====================
This release expands the anonymous usage stats sent if the opt-in
``report_stats`` configuration is set to ``true``. We now capture RSS memory
and cpu use at a very coarse level. This requires administrators to install
the optional ``psutil`` python module.
We would appreciate it if you could assist by ensuring this module is available
and ``report_stats`` is enabled. This will let us see if performance changes to
synapse are having an impact to the general community.
Upgrading to v0.15.0 Upgrading to v0.15.0
==================== ====================

View file

@ -55,7 +55,12 @@ synapse process.)
You then create a set of configs for the various worker processes. These You then create a set of configs for the various worker processes. These
should be worker configuration files, and should be stored in a dedicated should be worker configuration files, and should be stored in a dedicated
subdirectory, to allow synctl to manipulate them. subdirectory, to allow synctl to manipulate them. An additional configuration
for the master synapse process will need to be created because the process will
not be started automatically. That configuration should look like this::
worker_app: synapse.app.homeserver
daemonize: true
Each worker configuration file inherits the configuration of the main homeserver Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker, configuration file. You can then override configuration specific to that worker,
@ -230,9 +235,11 @@ file. For example::
``synapse.app.event_creator`` ``synapse.app.event_creator``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles non-state event creation. It can handle REST endpoints matching:: Handles some event creation. It can handle REST endpoints matching::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
^/_matrix/client/(api/v1|r0|unstable)/join/
It will create events locally and then send them on to the main synapse It will create events locally and then send them on to the main synapse
instance to be persisted and handled. instance to be persisted and handled.

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -250,6 +251,12 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, forward_chunk, def handle_table(self, table, postgres_size, table_size, forward_chunk,
backward_chunk): backward_chunk):
logger.info(
"Table %s: %i/%i (rows %i-%i) already ported",
table, postgres_size, table_size,
backward_chunk+1, forward_chunk-1,
)
if not table_size: if not table_size:
return return
@ -467,31 +474,10 @@ class Porter(object):
self.progress.set_state("Preparing PostgreSQL") self.progress.set_state("Preparing PostgreSQL")
self.setup_db(postgres_config, postgres_engine) self.setup_db(postgres_config, postgres_engine)
# Step 2. Get tables. self.progress.set_state("Creating port tables")
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master",
keyvalues={
"type": "table",
},
retcol="name",
)
postgres_tables = yield self.postgres_store._simple_select_onecol(
table="information_schema.tables",
keyvalues={},
retcol="distinct table_name",
)
tables = set(sqlite_tables) & set(postgres_tables)
self.progress.set_state("Creating tables")
logger.info("Found %d tables", len(tables))
def create_port_table(txn): def create_port_table(txn):
txn.execute( txn.execute(
"CREATE TABLE port_from_sqlite3 (" "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE," " table_name varchar(100) NOT NULL UNIQUE,"
" forward_rowid bigint NOT NULL," " forward_rowid bigint NOT NULL,"
" backward_rowid bigint NOT NULL" " backward_rowid bigint NOT NULL"
@ -517,18 +503,33 @@ class Porter(object):
"alter_table", alter_table "alter_table", alter_table
) )
except Exception as e: except Exception as e:
logger.info("Failed to create port table: %s", e) pass
try:
yield self.postgres_store.runInteraction( yield self.postgres_store.runInteraction(
"create_port_table", create_port_table "create_port_table", create_port_table
) )
except Exception as e:
logger.info("Failed to create port table: %s", e)
self.progress.set_state("Setting up") # Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master",
keyvalues={
"type": "table",
},
retcol="name",
)
# Set up tables. postgres_tables = yield self.postgres_store._simple_select_onecol(
table="information_schema.tables",
keyvalues={},
retcol="distinct table_name",
)
tables = set(sqlite_tables) & set(postgres_tables)
logger.info("Found %d tables", len(tables))
# Step 3. Figure out what still needs copying
self.progress.set_state("Checking on port progress")
setup_res = yield defer.gatherResults( setup_res = yield defer.gatherResults(
[ [
self.setup_table(table) self.setup_table(table)
@ -539,7 +540,8 @@ class Porter(object):
consumeErrors=True, consumeErrors=True,
) )
# Process tables. # Step 4. Do the copying.
self.progress.set_state("Copying to postgres")
yield defer.gatherResults( yield defer.gatherResults(
[ [
self.handle_table(*res) self.handle_table(*res)
@ -548,6 +550,9 @@ class Porter(object):
consumeErrors=True, consumeErrors=True,
) )
# Step 5. Do final post-processing
yield self._setup_state_group_id_seq()
self.progress.done() self.progress.done()
except: except:
global end_error_exec_info global end_error_exec_info
@ -707,6 +712,16 @@ class Porter(object):
defer.returnValue((done, remaining + done)) defer.returnValue((done, remaining + done))
def _setup_state_group_id_seq(self):
def r(txn):
txn.execute("SELECT MAX(id) FROM state_groups")
next_id = txn.fetchone()[0]+1
txn.execute(
"ALTER SEQUENCE state_group_id_seq RESTART WITH %s",
(next_id,),
)
return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
############################################## ##############################################
###### The following is simply UI stuff ###### ###### The following is simply UI stuff ######

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.27.2" __version__ = "0.27.3"

View file

@ -15,9 +15,10 @@
"""Contains exceptions and error codes.""" """Contains exceptions and error codes."""
import json
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -36,7 +36,7 @@ from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import reactor from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.appservice") logger = logging.getLogger("synapse.app.appservice")
@ -64,7 +64,7 @@ class AppserviceServer(HomeServer):
if name == "metrics": if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import reactor from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.client_reader") logger = logging.getLogger("synapse.app.client_reader")
@ -88,7 +88,7 @@ class ClientReaderServer(HomeServer):
"/_matrix/client/api/v1": resource, "/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -52,7 +52,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import reactor from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.event_creator") logger = logging.getLogger("synapse.app.event_creator")
@ -104,7 +104,7 @@ class EventCreatorServer(HomeServer):
"/_matrix/client/api/v1": resource, "/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -41,7 +41,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import reactor from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.federation_reader") logger = logging.getLogger("synapse.app.federation_reader")
@ -77,7 +77,7 @@ class FederationReaderServer(HomeServer):
FEDERATION_PREFIX: TransportLayerServer(self), FEDERATION_PREFIX: TransportLayerServer(self),
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -42,7 +42,7 @@ from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.federation_sender") logger = logging.getLogger("synapse.app.federation_sender")
@ -91,7 +91,7 @@ class FederationSenderServer(HomeServer):
if name == "metrics": if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.frontend_proxy") logger = logging.getLogger("synapse.app.frontend_proxy")
@ -142,7 +142,7 @@ class FrontendProxyServer(HomeServer):
"/_matrix/client/api/v1": resource, "/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -48,6 +48,7 @@ from synapse.server import HomeServer
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
@ -56,7 +57,7 @@ from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, Resource from twisted.web.resource import EncodingResourceWrapper, NoResource
from twisted.web.server import GzipEncoderFactory from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File from twisted.web.static import File
@ -126,7 +127,7 @@ class SynapseHomeServer(HomeServer):
if WEB_CLIENT_PREFIX in resources: if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX) root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else: else:
root_resource = Resource() root_resource = NoResource()
root_resource = create_resource_tree(resources, root_resource) root_resource = create_resource_tree(resources, root_resource)
@ -402,6 +403,10 @@ def run(hs):
stats = {} stats = {}
# Contains the list of processes we will be monitoring
# currently either 0 or 1
stats_process = []
@defer.inlineCallbacks @defer.inlineCallbacks
def phone_stats_home(): def phone_stats_home():
logger.info("Gathering stats for reporting") logger.info("Gathering stats for reporting")
@ -425,8 +430,21 @@ def run(hs):
stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users()
for name, count in r30_results.iteritems():
stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = CACHE_SIZE_FACTOR
stats["event_cache_size"] = hs.config.event_cache_size
if len(stats_process) > 0:
stats["memory_rss"] = 0
stats["cpu_average"] = 0
for process in stats_process:
stats["memory_rss"] += process.memory_info().rss
stats["cpu_average"] += int(process.cpu_percent(interval=None))
logger.info("Reporting stats to matrix.org: %s" % (stats,)) logger.info("Reporting stats to matrix.org: %s" % (stats,))
try: try:
@ -437,10 +455,32 @@ def run(hs):
except Exception as e: except Exception as e:
logger.warn("Error reporting stats: %s", e) logger.warn("Error reporting stats: %s", e)
def performance_stats_init():
try:
import psutil
process = psutil.Process()
# Ensure we can fetch both, and make the initial request for cpu_percent
# so the next request will use this as the initial point.
process.memory_info().rss
process.cpu_percent(interval=None)
logger.info("report_stats can use psutil")
stats_process.append(process)
except (ImportError, AttributeError):
logger.warn(
"report_stats enabled but psutil is not installed or incorrect version."
" Disabling reporting of memory/cpu stats."
" Ensuring psutil is available will help matrix.org track performance"
" changes across releases."
)
if hs.config.report_stats: if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals") logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000) clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
# We need to defer this init for the cases that we daemonize
# otherwise the process ID we get is that of the non-daemon process
clock.call_later(0, performance_stats_init)
# We wait 5 minutes to send the first set of stats as the server can # We wait 5 minutes to send the first set of stats as the server can
# be quite busy the first few minutes # be quite busy the first few minutes
clock.call_later(5 * 60, phone_stats_home) clock.call_later(5 * 60, phone_stats_home)

View file

@ -43,7 +43,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import reactor from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.media_repository") logger = logging.getLogger("synapse.app.media_repository")
@ -84,7 +84,7 @@ class MediaRepositoryServer(HomeServer):
), ),
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -37,7 +37,7 @@ from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.pusher") logger = logging.getLogger("synapse.app.pusher")
@ -94,7 +94,7 @@ class PusherServer(HomeServer):
if name == "metrics": if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -56,7 +56,7 @@ from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.synchrotron") logger = logging.getLogger("synapse.app.synchrotron")
@ -269,7 +269,7 @@ class SynchrotronServer(HomeServer):
"/_matrix/client/api/v1": resource, "/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -38,7 +38,7 @@ def pid_running(pid):
try: try:
os.kill(pid, 0) os.kill(pid, 0)
return True return True
except OSError, err: except OSError as err:
if err.errno == errno.EPERM: if err.errno == errno.EPERM:
return True return True
return False return False
@ -98,7 +98,7 @@ def stop(pidfile, app):
try: try:
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
write("stopped %s" % (app,), colour=GREEN) write("stopped %s" % (app,), colour=GREEN)
except OSError, err: except OSError as err:
if err.errno == errno.ESRCH: if err.errno == errno.ESRCH:
write("%s not running" % (app,), colour=YELLOW) write("%s not running" % (app,), colour=YELLOW)
elif err.errno == errno.EPERM: elif err.errno == errno.EPERM:
@ -252,6 +252,7 @@ def main():
for running_pid in running_pids: for running_pid in running_pids:
while pid_running(running_pid): while pid_running(running_pid):
time.sleep(0.2) time.sleep(0.2)
write("All processes exited; now restarting...")
if action == "start" or action == "restart": if action == "start" or action == "restart":
if start_stop_synapse: if start_stop_synapse:

View file

@ -43,7 +43,7 @@ from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from twisted.internet import reactor from twisted.internet import reactor
from twisted.web.resource import Resource from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.user_dir") logger = logging.getLogger("synapse.app.user_dir")
@ -116,7 +116,7 @@ class UserDirectoryServer(HomeServer):
"/_matrix/client/api/v1": resource, "/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp( _base.listen_tcp(
bind_addresses, bind_addresses,

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -20,6 +21,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import logging import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,7 +51,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s", logger.debug("get_room_state dest=%s, room=%s",
destination, room_id) destination, room_id)
path = PREFIX + "/state/%s/" % room_id path = _create_path(PREFIX, "/state/%s/", room_id)
return self.client.get_json( return self.client.get_json(
destination, path=path, args={"event_id": event_id}, destination, path=path, args={"event_id": event_id},
) )
@ -71,7 +73,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s", logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id) destination, room_id)
path = PREFIX + "/state_ids/%s/" % room_id path = _create_path(PREFIX, "/state_ids/%s/", room_id)
return self.client.get_json( return self.client.get_json(
destination, path=path, args={"event_id": event_id}, destination, path=path, args={"event_id": event_id},
) )
@ -93,7 +95,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s", logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id) destination, event_id)
path = PREFIX + "/event/%s/" % (event_id, ) path = _create_path(PREFIX, "/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout) return self.client.get_json(destination, path=path, timeout=timeout)
@log_function @log_function
@ -119,7 +121,7 @@ class TransportLayerClient(object):
# TODO: raise? # TODO: raise?
return return
path = PREFIX + "/backfill/%s/" % (room_id,) path = _create_path(PREFIX, "/backfill/%s/", room_id)
args = { args = {
"v": event_tuples, "v": event_tuples,
@ -157,9 +159,11 @@ class TransportLayerClient(object):
# generated by the json_data_callback. # generated by the json_data_callback.
json_data = transaction.get_dict() json_data = transaction.get_dict()
path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
response = yield self.client.put_json( response = yield self.client.put_json(
transaction.destination, transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id, path=path,
data=json_data, data=json_data,
json_data_callback=json_data_callback, json_data_callback=json_data_callback,
long_retries=True, long_retries=True,
@ -177,7 +181,7 @@ class TransportLayerClient(object):
@log_function @log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail, def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False): ignore_backoff=False):
path = PREFIX + "/query/%s" % query_type path = _create_path(PREFIX, "/query/%s", query_type)
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
@ -222,7 +226,7 @@ class TransportLayerClient(object):
"make_membership_event called with membership='%s', must be one of %s" % "make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships)) (membership, ",".join(valid_memberships))
) )
path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id) path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False ignore_backoff = False
retry_on_dns_fail = False retry_on_dns_fail = False
@ -248,7 +252,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_join(self, destination, room_id, event_id, content): def send_join(self, destination, room_id, event_id, content):
path = PREFIX + "/send_join/%s/%s" % (room_id, event_id) path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = yield self.client.put_json(
destination=destination, destination=destination,
@ -261,7 +265,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_leave(self, destination, room_id, event_id, content): def send_leave(self, destination, room_id, event_id, content):
path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id) path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = yield self.client.put_json(
destination=destination, destination=destination,
@ -280,7 +284,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_invite(self, destination, room_id, event_id, content): def send_invite(self, destination, room_id, event_id, content):
path = PREFIX + "/invite/%s/%s" % (room_id, event_id) path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json( response = yield self.client.put_json(
destination=destination, destination=destination,
@ -322,7 +326,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite(self, destination, room_id, event_dict): def exchange_third_party_invite(self, destination, room_id, event_dict):
path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,) path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json( response = yield self.client.put_json(
destination=destination, destination=destination,
@ -335,7 +339,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_event_auth(self, destination, room_id, event_id): def get_event_auth(self, destination, room_id, event_id):
path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id) path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
@ -347,7 +351,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_query_auth(self, destination, room_id, event_id, content): def send_query_auth(self, destination, room_id, event_id, content):
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json( content = yield self.client.post_json(
destination=destination, destination=destination,
@ -409,7 +413,7 @@ class TransportLayerClient(object):
Returns: Returns:
A dict containg the device keys. A dict containg the device keys.
""" """
path = PREFIX + "/user/devices/" + user_id path = _create_path(PREFIX, "/user/devices/%s", user_id)
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
@ -459,7 +463,7 @@ class TransportLayerClient(object):
@log_function @log_function
def get_missing_events(self, destination, room_id, earliest_events, def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout): latest_events, limit, min_depth, timeout):
path = PREFIX + "/get_missing_events/%s" % (room_id,) path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
content = yield self.client.post_json( content = yield self.client.post_json(
destination=destination, destination=destination,
@ -479,7 +483,7 @@ class TransportLayerClient(object):
def get_group_profile(self, destination, group_id, requester_user_id): def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile """Get a group profile
""" """
path = PREFIX + "/groups/%s/profile" % (group_id,) path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -498,7 +502,7 @@ class TransportLayerClient(object):
requester_user_id (str) requester_user_id (str)
content (dict): The new profile of the group content (dict): The new profile of the group
""" """
path = PREFIX + "/groups/%s/profile" % (group_id,) path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -512,7 +516,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id): def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary """Get a group summary
""" """
path = PREFIX + "/groups/%s/summary" % (group_id,) path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -525,7 +529,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id): def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group """Get all rooms in a group
""" """
path = PREFIX + "/groups/%s/rooms" % (group_id,) path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -538,7 +542,7 @@ class TransportLayerClient(object):
content): content):
"""Add a room to a group """Add a room to a group
""" """
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -552,7 +556,10 @@ class TransportLayerClient(object):
config_key, content): config_key, content):
"""Update room in group """Update room in group
""" """
path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,) path = _create_path(
PREFIX, "/groups/%s/room/%s/config/%s",
group_id, room_id, config_key,
)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -565,7 +572,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group """Remove a room from a group
""" """
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
return self.client.delete_json( return self.client.delete_json(
destination=destination, destination=destination,
@ -578,7 +585,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id): def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group """Get users in a group
""" """
path = PREFIX + "/groups/%s/users" % (group_id,) path = _create_path(PREFIX, "/groups/%s/users", group_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -591,7 +598,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id): def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group """Get users that have been invited to a group
""" """
path = PREFIX + "/groups/%s/invited_users" % (group_id,) path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -604,7 +611,23 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content): def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite """Accept a group invite
""" """
path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id) path = _create_path(
PREFIX, "/groups/%s/users/%s/accept_invite",
group_id, user_id,
)
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)
@log_function
def join_group(self, destination, group_id, user_id, content):
"""Attempts to join a group
"""
path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -617,7 +640,7 @@ class TransportLayerClient(object):
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
"""Invite a user to a group """Invite a user to a group
""" """
path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id) path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -633,7 +656,7 @@ class TransportLayerClient(object):
invited. invited.
""" """
path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id) path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -647,7 +670,7 @@ class TransportLayerClient(object):
user_id, content): user_id, content):
"""Remove a user fron a group """Remove a user fron a group
""" """
path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id) path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -664,7 +687,7 @@ class TransportLayerClient(object):
kicked from the group. kicked from the group.
""" """
path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id) path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -679,7 +702,7 @@ class TransportLayerClient(object):
the attestations the attestations
""" """
path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id) path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -694,11 +717,12 @@ class TransportLayerClient(object):
"""Update a room entry in a group summary """Update a room entry in a group summary
""" """
if category_id: if category_id:
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % ( path = _create_path(
PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id, group_id, category_id, room_id,
) )
else: else:
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,) path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -714,11 +738,12 @@ class TransportLayerClient(object):
"""Delete a room entry in a group summary """Delete a room entry in a group summary
""" """
if category_id: if category_id:
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % ( path = _create_path(
PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id, group_id, category_id, room_id,
) )
else: else:
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,) path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.delete_json( return self.client.delete_json(
destination=destination, destination=destination,
@ -731,7 +756,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id): def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group """Get all categories in a group
""" """
path = PREFIX + "/groups/%s/categories" % (group_id,) path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -744,7 +769,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id): def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group """Get category info in a group
""" """
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -758,7 +783,7 @@ class TransportLayerClient(object):
content): content):
"""Update a category in a group """Update a category in a group
""" """
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -773,7 +798,7 @@ class TransportLayerClient(object):
category_id): category_id):
"""Delete a category in a group """Delete a category in a group
""" """
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
return self.client.delete_json( return self.client.delete_json(
destination=destination, destination=destination,
@ -786,7 +811,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id): def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group """Get all roles in a group
""" """
path = PREFIX + "/groups/%s/roles" % (group_id,) path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -799,7 +824,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id): def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info """Get a roles info
""" """
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
return self.client.get_json( return self.client.get_json(
destination=destination, destination=destination,
@ -813,7 +838,7 @@ class TransportLayerClient(object):
content): content):
"""Update a role in a group """Update a role in a group
""" """
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -827,7 +852,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id): def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group """Delete a role in a group
""" """
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
return self.client.delete_json( return self.client.delete_json(
destination=destination, destination=destination,
@ -842,11 +867,12 @@ class TransportLayerClient(object):
"""Update a users entry in a group """Update a users entry in a group
""" """
if role_id: if role_id:
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % ( path = _create_path(
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id, group_id, role_id, user_id,
) )
else: else:
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,) path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.post_json( return self.client.post_json(
destination=destination, destination=destination,
@ -856,17 +882,33 @@ class TransportLayerClient(object):
ignore_backoff=True, ignore_backoff=True,
) )
@log_function
def set_group_join_policy(self, destination, group_id, requester_user_id,
content):
"""Sets the join policy for a group
"""
path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
return self.client.put_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
@log_function @log_function
def delete_group_summary_user(self, destination, group_id, requester_user_id, def delete_group_summary_user(self, destination, group_id, requester_user_id,
user_id, role_id): user_id, role_id):
"""Delete a users entry in a group """Delete a users entry in a group
""" """
if role_id: if role_id:
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % ( path = _create_path(
PREFIX, "/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id, group_id, role_id, user_id,
) )
else: else:
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,) path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.delete_json( return self.client.delete_json(
destination=destination, destination=destination,
@ -889,3 +931,22 @@ class TransportLayerClient(object):
data=content, data=content,
ignore_backoff=True, ignore_backoff=True,
) )
def _create_path(prefix, path, *args):
"""Creates a path from the prefix, path template and args. Ensures that
all args are url encoded.
Example:
_create_path(PREFIX, "/event/%s/", event_id)
Args:
prefix (str)
path (str): String template for the path
args: ([str]): Args to insert into path. Each arg will be url encoded
Returns:
str
"""
return prefix + path % tuple(urllib.quote(arg, "") for arg in args)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -802,6 +803,23 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
defer.returnValue((200, new_content)) defer.returnValue((200, new_content))
class FederationGroupsJoinServlet(BaseFederationServlet):
"""Attempt to join a group
"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
new_content = yield self.handler.join_group(
group_id, user_id, content,
)
defer.returnValue((200, new_content))
class FederationGroupsRemoveUserServlet(BaseFederationServlet): class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group """Leave or kick a user from the group
""" """
@ -1124,6 +1142,24 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
defer.returnValue((200, resp)) defer.returnValue((200, resp))
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
"""Sets whether a group is joinable without an invite or knock
"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.set_group_join_policy(
group_id, requester_user_id, content
)
defer.returnValue((200, new_content))
FEDERATION_SERVLET_CLASSES = ( FEDERATION_SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationPullServlet, FederationPullServlet,
@ -1163,6 +1199,7 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsInvitedUsersServlet, FederationGroupsInvitedUsersServlet,
FederationGroupsInviteServlet, FederationGroupsInviteServlet,
FederationGroupsAcceptInviteServlet, FederationGroupsAcceptInviteServlet,
FederationGroupsJoinServlet,
FederationGroupsRemoveUserServlet, FederationGroupsRemoveUserServlet,
FederationGroupsSummaryRoomsServlet, FederationGroupsSummaryRoomsServlet,
FederationGroupsCategoriesServlet, FederationGroupsCategoriesServlet,
@ -1172,6 +1209,7 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsSummaryUsersServlet, FederationGroupsSummaryUsersServlet,
FederationGroupsAddRoomsServlet, FederationGroupsAddRoomsServlet,
FederationGroupsAddRoomsConfigServlet, FederationGroupsAddRoomsConfigServlet,
FederationGroupsSettingJoinPolicyServlet,
) )

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -205,6 +206,28 @@ class GroupsServerHandler(object):
defer.returnValue({}) defer.returnValue({})
@defer.inlineCallbacks
def set_group_join_policy(self, group_id, requester_user_id, content):
"""Sets the group join policy.
Currently supported policies are:
- "invite": an invite must be received and accepted in order to join.
- "open": anyone can join.
"""
yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
join_policy = _parse_join_policy_from_contents(content)
if join_policy is None:
raise SynapseError(
400, "No value specified for 'm.join_policy'"
)
yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
defer.returnValue({})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_categories(self, group_id, requester_user_id): def get_group_categories(self, group_id, requester_user_id):
"""Get all categories in a group (as seen by user) """Get all categories in a group (as seen by user)
@ -381,9 +404,16 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id) yield self.check_group_is_ours(group_id, requester_user_id)
group_description = yield self.store.get_group(group_id) group = yield self.store.get_group(group_id)
if group:
cols = [
"name", "short_description", "long_description",
"avatar_url", "is_public",
]
group_description = {key: group[key] for key in cols}
group_description["is_openly_joinable"] = group["join_policy"] == "open"
if group_description:
defer.returnValue(group_description) defer.returnValue(group_description)
else: else:
raise SynapseError(404, "Unknown group") raise SynapseError(404, "Unknown group")
@ -654,6 +684,40 @@ class GroupsServerHandler(object):
else: else:
raise SynapseError(502, "Unknown state returned by HS") raise SynapseError(502, "Unknown state returned by HS")
@defer.inlineCallbacks
def _add_user(self, group_id, user_id, content):
"""Add a user to a group based on a content dict.
See accept_invite, join_group.
"""
if not self.hs.is_mine_id(user_id):
local_attestation = self.attestations.create_attestation(
group_id, user_id,
)
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
group_id=group_id,
)
else:
local_attestation = None
remote_attestation = None
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_group(
group_id, user_id,
is_admin=False,
is_public=is_public,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
defer.returnValue(local_attestation)
@defer.inlineCallbacks @defer.inlineCallbacks
def accept_invite(self, group_id, requester_user_id, content): def accept_invite(self, group_id, requester_user_id, content):
"""User tries to accept an invite to the group. """User tries to accept an invite to the group.
@ -670,30 +734,27 @@ class GroupsServerHandler(object):
if not is_invited: if not is_invited:
raise SynapseError(403, "User not invited to group") raise SynapseError(403, "User not invited to group")
if not self.hs.is_mine_id(requester_user_id): local_attestation = yield self._add_user(group_id, requester_user_id, content)
local_attestation = self.attestations.create_attestation(
group_id, requester_user_id,
)
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation( defer.returnValue({
remote_attestation, "state": "join",
user_id=requester_user_id, "attestation": local_attestation,
group_id=group_id, })
)
else:
local_attestation = None
remote_attestation = None
is_public = _parse_visibility_from_contents(content) @defer.inlineCallbacks
def join_group(self, group_id, requester_user_id, content):
"""User tries to join the group.
yield self.store.add_user_to_group( This will error if the group requires an invite/knock to join
group_id, requester_user_id, """
is_admin=False,
is_public=is_public, group_info = yield self.check_group_is_ours(
local_attestation=local_attestation, group_id, requester_user_id, and_exists=True
remote_attestation=remote_attestation,
) )
if group_info['join_policy'] != "open":
raise SynapseError(403, "Group is not publicly joinable")
local_attestation = yield self._add_user(group_id, requester_user_id, content)
defer.returnValue({ defer.returnValue({
"state": "join", "state": "join",
@ -835,6 +896,31 @@ class GroupsServerHandler(object):
}) })
def _parse_join_policy_from_contents(content):
"""Given a content for a request, return the specified join policy or None
"""
join_policy_dict = content.get("m.join_policy")
if join_policy_dict:
return _parse_join_policy_dict(join_policy_dict)
else:
return None
def _parse_join_policy_dict(join_policy_dict):
"""Given a dict for the "m.join_policy" config return the join policy specified
"""
join_policy_type = join_policy_dict.get("type")
if not join_policy_type:
return "invite"
if join_policy_type not in ("invite", "open"):
raise SynapseError(
400, "Synapse only supports 'invite'/'open' join rule"
)
return join_policy_type
def _parse_visibility_from_contents(content): def _parse_visibility_from_contents(content):
"""Given a content for a request parse out whether the entity should be """Given a content for a request parse out whether the entity should be
public or not public or not

View file

@ -155,7 +155,7 @@ class DeviceHandler(BaseHandler):
try: try:
yield self.store.delete_device(user_id, device_id) yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e: except errors.StoreError as e:
if e.code == 404: if e.code == 404:
# no match # no match
pass pass
@ -204,7 +204,7 @@ class DeviceHandler(BaseHandler):
try: try:
yield self.store.delete_devices(user_id, device_ids) yield self.store.delete_devices(user_id, device_ids)
except errors.StoreError, e: except errors.StoreError as e:
if e.code == 404: if e.code == 404:
# no match # no match
pass pass
@ -243,7 +243,7 @@ class DeviceHandler(BaseHandler):
new_display_name=content.get("display_name") new_display_name=content.get("display_name")
) )
yield self.notify_device_update(user_id, [device_id]) yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e: except errors.StoreError as e:
if e.code == 404: if e.code == 404:
raise errors.NotFoundError() raise errors.NotFoundError()
else: else:

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -134,23 +135,8 @@ class E2eKeysHandler(object):
if user_id in destination_query: if user_id in destination_query:
results[user_id] = keys results[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
except FederationDeniedError as e:
failures[destination] = {
"status": 403, "message": "Federation Denied",
}
except Exception as e: except Exception as e:
# include ConnectionRefused and other errors failures[destination] = _exception_to_failure(e)
failures[destination] = {
"status": 503, "message": e.message
}
yield make_deferred_yieldable(defer.gatherResults([ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(do_remote_query)(destination) preserve_fn(do_remote_query)(destination)
@ -252,19 +238,8 @@ class E2eKeysHandler(object):
for user_id, keys in remote_result["one_time_keys"].items(): for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys: if user_id in device_keys:
json_result[user_id] = keys json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
except Exception as e: except Exception as e:
# include ConnectionRefused and other errors failures[destination] = _exception_to_failure(e)
failures[destination] = {
"status": 503, "message": e.message
}
yield make_deferred_yieldable(defer.gatherResults([ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(claim_client_keys)(destination) preserve_fn(claim_client_keys)(destination)
@ -362,6 +337,31 @@ class E2eKeysHandler(object):
) )
def _exception_to_failure(e):
if isinstance(e, CodeMessageException):
return {
"status": e.code, "message": e.message,
}
if isinstance(e, NotRetryingDestination):
return {
"status": 503, "message": "Not ready for retry",
}
if isinstance(e, FederationDeniedError):
return {
"status": 403, "message": "Federation Denied",
}
# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which simplejson then fails to serialize.
return {
"status": 503, "message": str(e.message),
}
def _one_time_keys_match(old_key_json, new_key): def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json) old_key = json.loads(old_key_json)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -90,6 +91,8 @@ class GroupsLocalHandler(object):
get_group_role = _create_rerouter("get_group_role") get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles") get_group_roles = _create_rerouter("get_group_roles")
set_group_join_policy = _create_rerouter("set_group_join_policy")
@defer.inlineCallbacks @defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id): def get_group_summary(self, group_id, requester_user_id):
"""Get the group summary for a group. """Get the group summary for a group.
@ -226,7 +229,45 @@ class GroupsLocalHandler(object):
def join_group(self, group_id, user_id, content): def join_group(self, group_id, user_id, content):
"""Request to join a group """Request to join a group
""" """
raise NotImplementedError() # TODO if self.is_mine_id(group_id):
yield self.groups_server_handler.join_group(
group_id, user_id, content
)
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
res = yield self.transport_client.join_group(
get_domain_from_id(group_id), group_id, user_id, content,
)
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
server_name=get_domain_from_id(group_id),
)
# TODO: Check that the group is public and we're being added publically
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({})
@defer.inlineCallbacks @defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content): def accept_invite(self, group_id, user_id, content):

View file

@ -15,6 +15,11 @@
# limitations under the License. # limitations under the License.
"""Utilities for interacting with Identity Servers""" """Utilities for interacting with Identity Servers"""
import logging
import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
@ -24,9 +29,6 @@ from ._base import BaseHandler
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
import json
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -27,7 +27,7 @@ from synapse.types import (
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn, run_in_background from synapse.util.logcontext import preserve_fn, run_in_background
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from synapse.replication.http.send_event import send_event_to_master from synapse.replication.http.send_event import send_event_to_master
@ -678,7 +678,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db # Ensure that we can round trip before trying to persist in db
try: try:
dump = simplejson.dumps(unfreeze(event.content)) dump = frozendict_json_encoder.encode(event.content)
simplejson.loads(dump) simplejson.loads(dump)
except Exception: except Exception:
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)

View file

@ -24,7 +24,7 @@ from synapse.api.errors import (
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse import types from synapse import types
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor, Linearizer
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler from ._base import BaseHandler
@ -46,6 +46,10 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._generate_user_id_linearizer = Linearizer(
name="_generate_user_id_linearizer",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None): assigned_user_id=None):
@ -344,6 +348,8 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_user_id(self, reseed=False): def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
if reseed or self._next_generated_user_id is None: if reseed or self._next_generated_user_id is None:
self._next_generated_user_id = ( self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart() yield self.store.find_next_generated_user_id_localpart()

View file

@ -286,7 +286,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None, def put_json(self, destination, path, args={}, data={},
json_data_callback=None,
long_retries=False, timeout=None, long_retries=False, timeout=None,
ignore_backoff=False, ignore_backoff=False,
backoff_on_404=False): backoff_on_404=False):
@ -296,6 +297,7 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
args (dict): query params
data (dict): A dict containing the data that will be used as data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to json_data_callback (callable): A callable returning the dict to
@ -342,6 +344,7 @@ class MatrixFederationHttpClient(object):
path, path,
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
query_bytes=encode_query_args(args),
long_retries=long_retries, long_retries=long_retries,
timeout=timeout, timeout=timeout,
ignore_backoff=ignore_backoff, ignore_backoff=ignore_backoff,
@ -373,6 +376,7 @@ class MatrixFederationHttpClient(object):
giving up. None indicates no timeout. giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. try the request anyway.
args (dict): query params
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. will be the decoded JSON body.

View file

@ -113,6 +113,11 @@ response_db_sched_duration = metrics.register_counter(
"response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
) )
# size in bytes of the response written
response_size = metrics.register_counter(
"response_size", labels=["method", "servlet", "tag"]
)
_next_request_id = 0 _next_request_id = 0
@ -426,6 +431,8 @@ class RequestMetrics(object):
context.db_sched_duration_ms / 1000., request.method, self.name, tag context.db_sched_duration_ms / 1000., request.method, self.name, tag
) )
response_size.inc_by(request.sentLength, request.method, self.name, tag)
class RootRedirect(resource.Resource): class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path.""" """Redirects the root '/' path to another path."""
@ -488,6 +495,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Server", version_string) request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
if send_cors: if send_cors:
set_cors_headers(request) set_cors_headers(request)

View file

@ -34,7 +34,6 @@ REQUIREMENTS = {
"bcrypt": ["bcrypt>=3.1.0"], "bcrypt": ["bcrypt>=3.1.0"],
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"],
"blist": ["blist"], "blist": ["blist"],
"pysaml2>=3.0.0": ["saml2>=3.0.0"], "pysaml2>=3.0.0": ["saml2>=3.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],

View file

@ -24,6 +24,8 @@ import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
class Command(object): class Command(object):
"""The base command class. """The base command class.
@ -107,7 +109,7 @@ class RdataCommand(Command):
return " ".join(( return " ".join((
self.stream_name, self.stream_name,
str(self.token) if self.token is not None else "batch", str(self.token) if self.token is not None else "batch",
simplejson.dumps(self.row, namedtuple_as_object=False), _json_encoder.encode(self.row),
)) ))
@ -302,7 +304,7 @@ class InvalidateCacheCommand(Command):
def to_line(self): def to_line(self):
return " ".join(( return " ".join((
self.cache_func, simplejson.dumps(self.keys, namedtuple_as_object=False) self.cache_func, _json_encoder.encode(self.keys),
)) ))
@ -334,7 +336,7 @@ class UserIpCommand(Command):
) )
def to_line(self): def to_line(self):
return self.user_id + " " + simplejson.dumps(( return self.user_id + " " + _json_encoder.encode((
self.access_token, self.ip, self.user_agent, self.device_id, self.access_token, self.ip, self.user_agent, self.device_id,
self.last_seen, self.last_seen,
)) ))

View file

@ -655,7 +655,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content=event_content, content=event_content,
) )
defer.returnValue((200, {})) return_value = {}
if membership_action == "join":
return_value["room_id"] = room_id
defer.returnValue((200, return_value))
def _has_3pid_invite_keys(self, content): def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}: for key in {"id_server", "medium", "address"}:

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -401,6 +402,32 @@ class GroupInvitedUsersServlet(RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy
"""
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
super(GroupSettingJoinPolicyServlet, self).__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.set_group_join_policy(
group_id,
requester_user_id,
content,
)
defer.returnValue((200, result))
class GroupCreateServlet(RestServlet): class GroupCreateServlet(RestServlet):
"""Create a group """Create a group
""" """
@ -738,6 +765,7 @@ def register_servlets(hs, http_server):
GroupInvitedUsersServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server)
GroupUsersServlet(hs).register(http_server) GroupUsersServlet(hs).register(http_server)
GroupRoomServlet(hs).register(http_server) GroupRoomServlet(hs).register(http_server)
GroupSettingJoinPolicyServlet(hs).register(http_server)
GroupCreateServlet(hs).register(http_server) GroupCreateServlet(hs).register(http_server)
GroupAdminRoomsServlet(hs).register(http_server) GroupAdminRoomsServlet(hs).register(http_server)
GroupAdminRoomsConfigServlet(hs).register(http_server) GroupAdminRoomsConfigServlet(hs).register(http_server)

View file

@ -132,7 +132,7 @@ class StateHandler(object):
state_map = yield self.store.get_events(state.values(), get_prev_content=False) state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = { state = {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map
} }
defer.returnValue(state) defer.returnValue(state)
@ -378,7 +378,7 @@ class StateHandler(object):
new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = { new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items() key: state_map[ev_id] for key, ev_id in new_state.iteritems()
} }
return new_state return new_state
@ -458,15 +458,15 @@ class StateResolutionHandler(object):
# build a map from state key to the event_ids which set that state. # build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str]) # dict[(str, str), set[str])
state = {} state = {}
for st in state_groups_ids.values(): for st in state_groups_ids.itervalues():
for key, e_id in st.items(): for key, e_id in st.iteritems():
state.setdefault(key, set()).add(e_id) state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state, # build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict. # including only those where there are state keys in conflict.
conflicted_state = { conflicted_state = {
k: list(v) k: list(v)
for k, v in state.items() for k, v in state.iteritems()
if len(v) > 1 if len(v) > 1
} }
@ -480,16 +480,17 @@ class StateResolutionHandler(object):
) )
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.iteritems()
} }
with Measure(self.clock, "state.create_group_ids"):
# if the new state matches any of the input state groups, we can # if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id # use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but # which will be used as a cache key for future resolutions, but
# not get persisted. # not get persisted.
state_group = None state_group = None
new_state_event_ids = frozenset(new_state.values()) new_state_event_ids = frozenset(new_state.itervalues())
for sg, events in state_groups_ids.items(): for sg, events in state_groups_ids.iteritems():
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
@ -702,7 +703,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = { auth_events = {
key: state_map[ev_id] key: state_map[ev_id]
for key, ev_id in auth_event_ids.items() for key, ev_id in auth_event_ids.iteritems()
if ev_id in state_map if ev_id in state_map
} }
@ -740,7 +741,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.iteritems():
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events) logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events( resolved_state[key] = _resolve_auth_events(
@ -750,7 +751,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.iteritems():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events) logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events( resolved_state[key] = _resolve_auth_events(
@ -760,7 +761,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.iteritems():
if key not in resolved_state: if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events) logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events( resolved_state[key] = _resolve_normal_events(

View file

@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.storage.devices import DeviceStore from synapse.storage.devices import DeviceStore
from .appservice import ( from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore ApplicationServiceStore, ApplicationServiceTransactionStore
@ -244,13 +242,12 @@ class DataStore(RoomMemberStore, RoomStore,
return [UserPresenceState(**row) for row in rows] return [UserPresenceState(**row) for row in rows]
@defer.inlineCallbacks
def count_daily_users(self): def count_daily_users(self):
""" """
Counts the number of users who used this homeserver in the last 24 hours. Counts the number of users who used this homeserver in the last 24 hours.
""" """
def _count_users(txn): def _count_users(txn):
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24), yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
sql = """ sql = """
SELECT COALESCE(count(*), 0) FROM ( SELECT COALESCE(count(*), 0) FROM (
@ -264,8 +261,91 @@ class DataStore(RoomMemberStore, RoomStore,
count, = txn.fetchone() count, = txn.fetchone()
return count return count
ret = yield self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def count_r30_users(self):
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days
Returns counts globaly for a given user as well as breaking
by platform
"""
def _count_r30_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
sql = """
SELECT platform, COALESCE(count(*), 0) FROM (
SELECT
users.name, platform, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
INNER JOIN (
SELECT
user_id,
last_seen,
CASE
WHEN user_agent LIKE '%%Android%%' THEN 'android'
WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
ELSE 'unknown'
END
AS platform
FROM user_ips
) uip
ON users.name = uip.user_id
AND users.appservice_id is NULL
AND users.creation_ts < ?
AND uip.last_seen/1000 > ?
AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
GROUP BY users.name, platform, users.creation_ts
) u GROUP BY platform
"""
results = {}
txn.execute(sql, (thirty_days_ago_in_secs,
thirty_days_ago_in_secs))
for row in txn:
if row[0] is 'unknown':
pass
results[row[0]] = row[1]
sql = """
SELECT COALESCE(count(*), 0) FROM (
SELECT users.name, users.creation_ts * 1000,
MAX(uip.last_seen)
FROM users
INNER JOIN (
SELECT
user_id,
last_seen
FROM user_ips
) uip
ON users.name = uip.user_id
AND appservice_id is NULL
AND users.creation_ts < ?
AND uip.last_seen/1000 > ?
AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
GROUP BY users.name, users.creation_ts
) u
"""
txn.execute(sql, (thirty_days_ago_in_secs,
thirty_days_ago_in_secs))
count, = txn.fetchone()
results['all'] = count
return results
return self.runInteraction("count_r30_users", _count_r30_users)
def get_users(self): def get_users(self):
"""Function to reterive a list of users in users table. """Function to reterive a list of users in users table.

View file

@ -48,6 +48,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"], columns=["user_id", "device_id", "last_seen"],
) )
self.register_background_index_update(
"user_ips_last_seen_index",
index_name="user_ips_last_seen",
table="user_ips",
columns=["user_id", "last_seen"],
)
# (user_id, access_token, ip) -> (user_agent, device_id, last_seen) # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self._batch_row_update = {} self._batch_row_update = {}

View file

@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.events_worker import EventsWorkerStore from collections import OrderedDict, deque, namedtuple
from functools import wraps
import logging
import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.events import USE_FROZEN_DICTS
from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import ( from synapse.util.logcontext import (
PreserveLoggingContext, make_deferred_yieldable PreserveLoggingContext, make_deferred_yieldable,
) )
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -30,16 +34,8 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict
from functools import wraps
import synapse.metrics import synapse.metrics
import logging
import simplejson as json
# these are only included to make the type annotations work # these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401 from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
@ -53,12 +49,25 @@ event_counter = metrics.register_counter(
"persisted_events_sep", labels=["type", "origin_type", "origin_entity"] "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
) )
# The number of times we are recalculating the current state
state_delta_counter = metrics.register_counter(
"state_delta",
)
# The number of times we are recalculating state when there is only a
# single forward extremity
state_delta_single_event_counter = metrics.register_counter(
"state_delta_single_event",
)
# The number of times we are reculating state when we could have resonably
# calculated the delta when we calculated the state for an event we were
# persisting.
state_delta_reuse_delta_counter = metrics.register_counter(
"state_delta_reuse_delta",
)
def encode_json(json_object): def encode_json(json_object):
if USE_FROZEN_DICTS: return frozendict_json_encoder.encode(json_object)
return encode_canonical_json(json_object)
else:
return json.dumps(json_object, ensure_ascii=False)
class _EventPeristenceQueue(object): class _EventPeristenceQueue(object):
@ -368,7 +377,8 @@ class EventsStore(EventsWorkerStore):
room_id, ev_ctx_rm, latest_event_ids room_id, ev_ctx_rm, latest_event_ids
) )
if new_latest_event_ids == set(latest_event_ids): latest_event_ids = set(latest_event_ids)
if new_latest_event_ids == latest_event_ids:
# No change in extremities, so no change in state # No change in extremities, so no change in state
continue continue
@ -389,6 +399,26 @@ class EventsStore(EventsWorkerStore):
if all_single_prev_not_state: if all_single_prev_not_state:
continue continue
state_delta_counter.inc()
if len(new_latest_event_ids) == 1:
state_delta_single_event_counter.inc()
# This is a fairly handwavey check to see if we could
# have guessed what the delta would have been when
# processing one of these events.
# What we're interested in is if the latest extremities
# were the same when we created the event as they are
# now. When this server creates a new event (as opposed
# to receiving it over federation) it will use the
# forward extremities as the prev_events, so we can
# guess this by looking at the prev_events and checking
# if they match the current forward extremities.
for ev, _ in ev_ctx_rm:
prev_event_ids = set(e for e, _ in ev.prev_events)
if latest_event_ids == prev_event_ids:
state_delta_reuse_delta_counter.inc()
break
logger.info( logger.info(
"Calculating state delta for room %s", room_id, "Calculating state delta for room %s", room_id,
) )

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,7 +20,7 @@ from synapse.api.errors import SynapseError
from ._base import SQLBaseStore from ._base import SQLBaseStore
import ujson as json import simplejson as json
# The category ID for the "default" category. We don't store as null in the # The category ID for the "default" category. We don't store as null in the
@ -29,6 +30,24 @@ _DEFAULT_ROLE_ID = ""
class GroupServerStore(SQLBaseStore): class GroupServerStore(SQLBaseStore):
def set_group_join_policy(self, group_id, join_policy):
"""Set the join policy of a group.
join_policy can be one of:
* "invite"
* "open"
"""
return self._simple_update_one(
table="groups",
keyvalues={
"group_id": group_id,
},
updatevalues={
"join_policy": join_policy,
},
desc="set_group_join_policy",
)
def get_group(self, group_id): def get_group(self, group_id):
return self._simple_select_one( return self._simple_select_one(
table="groups", table="groups",
@ -36,10 +55,11 @@ class GroupServerStore(SQLBaseStore):
"group_id": group_id, "group_id": group_id,
}, },
retcols=( retcols=(
"name", "short_description", "long_description", "avatar_url", "is_public" "name", "short_description", "long_description",
"avatar_url", "is_public", "join_policy",
), ),
allow_none=True, allow_none=True,
desc="is_user_in_group", desc="get_group",
) )
def get_users_in_group(self, group_id, include_private=False): def get_users_in_group(self, group_id, include_private=False):

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 47 SCHEMA_VERSION = 48
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -460,14 +460,12 @@ class RegistrationStore(RegistrationWorkerStore,
""" """
def _find_next_generated_user_id(txn): def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users") txn.execute("SELECT name FROM users")
rows = self.cursor_to_dict(txn)
regex = re.compile("^@(\d+):") regex = re.compile("^@(\d+):")
found = set() found = set()
for r in rows: for user_id, in txn:
user_id = r["name"]
match = regex.search(user_id) match = regex.search(user_id)
if match: if match:
found.add(int(match.group(1))) found.add(int(match.group(1)))

View file

@ -594,7 +594,8 @@ class RoomStore(RoomWorkerStore, SearchStore):
while next_token: while next_token:
sql = """ sql = """
SELECT stream_ordering, content FROM events SELECT stream_ordering, json FROM events
JOIN event_json USING (event_id)
WHERE room_id = ? WHERE room_id = ?
AND stream_ordering < ? AND stream_ordering < ?
AND contains_url = ? AND outlier = ? AND contains_url = ? AND outlier = ?
@ -606,8 +607,8 @@ class RoomStore(RoomWorkerStore, SearchStore):
next_token = None next_token = None
for stream_ordering, content_json in txn: for stream_ordering, content_json in txn:
next_token = stream_ordering next_token = stream_ordering
content = json.loads(content_json) event_json = json.loads(content_json)
content = event_json["content"]
content_url = content.get("url") content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url") thumbnail_url = content.get("info", {}).get("thumbnail_url")

View file

@ -645,8 +645,9 @@ class RoomMemberStore(RoomMemberWorkerStore):
def add_membership_profile_txn(txn): def add_membership_profile_txn(txn):
sql = (""" sql = ("""
SELECT stream_ordering, event_id, events.room_id, content SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events FROM events
INNER JOIN event_json USING (event_id)
INNER JOIN room_memberships USING (event_id) INNER JOIN room_memberships USING (event_id)
WHERE ? <= stream_ordering AND stream_ordering < ? WHERE ? <= stream_ordering AND stream_ordering < ?
AND type = 'm.room.member' AND type = 'm.room.member'
@ -667,7 +668,8 @@ class RoomMemberStore(RoomMemberWorkerStore):
event_id = row["event_id"] event_id = row["event_id"]
room_id = row["room_id"] room_id = row["room_id"]
try: try:
content = json.loads(row["content"]) event_json = json.loads(row["json"])
content = event_json['content']
except Exception: except Exception:
continue continue

View file

@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View file

@ -0,0 +1,17 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('user_ips_last_seen_index', '{}');

View file

@ -0,0 +1,22 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This isn't a real ENUM because sqlite doesn't support it
* and we use a default of NULL for inserted rows and interpret
* NULL at the python store level as necessary so that existing
* rows are given the correct default policy.
*/
ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite';

View file

@ -75,8 +75,9 @@ class SearchStore(BackgroundUpdateStore):
def reindex_search_txn(txn): def reindex_search_txn(txn):
sql = ( sql = (
"SELECT stream_ordering, event_id, room_id, type, content, " "SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events" " origin_server_ts FROM events"
" JOIN event_json USING (event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?" " WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)" " AND (%s)"
" ORDER BY stream_ordering DESC" " ORDER BY stream_ordering DESC"
@ -104,7 +105,8 @@ class SearchStore(BackgroundUpdateStore):
stream_ordering = row["stream_ordering"] stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"] origin_server_ts = row["origin_server_ts"]
try: try:
content = json.loads(row["content"]) event_json = json.loads(row["json"])
content = event_json["content"]
except Exception: except Exception:
continue continue

View file

@ -667,7 +667,7 @@ class UserDirectoryStore(SQLBaseStore):
# The array of numbers are the weights for the various part of the # The array of numbers are the weights for the various part of the
# search: (domain, _, display name, localpart) # search: (domain, _, display name, localpart)
sql = """ sql = """
SELECT d.user_id, display_name, avatar_url SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id) INNER JOIN user_directory AS d USING (user_id)
%s %s
@ -702,7 +702,7 @@ class UserDirectoryStore(SQLBaseStore):
search_query = _parse_query_sqlite(search_term) search_query = _parse_query_sqlite(search_term)
sql = """ sql = """
SELECT d.user_id, display_name, avatar_url SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id) INNER JOIN user_directory AS d USING (user_id)
%s %s

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -39,12 +40,11 @@ _CacheSentinel = object()
class CacheEntry(object): class CacheEntry(object):
__slots__ = [ __slots__ = [
"deferred", "sequence", "callbacks", "invalidated" "deferred", "callbacks", "invalidated"
] ]
def __init__(self, deferred, sequence, callbacks): def __init__(self, deferred, callbacks):
self.deferred = deferred self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks) self.callbacks = set(callbacks)
self.invalidated = False self.invalidated = False
@ -62,7 +62,6 @@ class Cache(object):
"max_entries", "max_entries",
"name", "name",
"keylen", "keylen",
"sequence",
"thread", "thread",
"metrics", "metrics",
"_pending_deferred_cache", "_pending_deferred_cache",
@ -80,7 +79,6 @@ class Cache(object):
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
self.sequence = 0
self.thread = None self.thread = None
self.metrics = register_cache(name, self.cache) self.metrics = register_cache(name, self.cache)
@ -113,7 +111,6 @@ class Cache(object):
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel) val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel: if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks) val.callbacks.update(callbacks)
if update_metrics: if update_metrics:
self.metrics.inc_hits() self.metrics.inc_hits()
@ -137,12 +134,9 @@ class Cache(object):
self.check_thread() self.check_thread()
entry = CacheEntry( entry = CacheEntry(
deferred=value, deferred=value,
sequence=self.sequence,
callbacks=callbacks, callbacks=callbacks,
) )
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry: if existing_entry:
existing_entry.invalidate() existing_entry.invalidate()
@ -150,13 +144,25 @@ class Cache(object):
self._pending_deferred_cache[key] = entry self._pending_deferred_cache[key] = entry
def shuffle(result): def shuffle(result):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry: if existing_entry is entry:
self.cache.set(key, result, entry.callbacks) self.cache.set(key, result, entry.callbacks)
else: else:
entry.invalidate() # oops, the _pending_deferred_cache has been updated since
else: # we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate() entry.invalidate()
return result return result
@ -168,25 +174,29 @@ class Cache(object):
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
self.cache.pop(key, None)
# Increment the sequence number so that any SELECT statements that # if we have a pending lookup for this key, remove it from the
# raced with the INSERT don't update the cache (SYN-369) # _pending_deferred_cache, which will (a) stop it being returned
self.sequence += 1 # for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None) entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry: if entry:
entry.invalidate() entry.invalidate()
self.cache.pop(key, None)
def invalidate_many(self, key): def invalidate_many(self, key):
self.check_thread() self.check_thread()
if not isinstance(key, tuple): if not isinstance(key, tuple):
raise TypeError( raise TypeError(
"The cache key must be a tuple not %r" % (type(key),) "The cache key must be a tuple not %r" % (type(key),)
) )
self.sequence += 1
self.cache.del_multi(key) self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None) entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None: if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict): for entry in iterate_tree_cache_entry(entry_dict):
@ -194,8 +204,10 @@ class Cache(object):
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
self.sequence += 1
self.cache.clear() self.cache.clear()
for entry in self._pending_deferred_cache.itervalues():
entry.invalidate()
self._pending_deferred_cache.clear()
class _CacheDescriptorBase(object): class _CacheDescriptorBase(object):

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from frozendict import frozendict from frozendict import frozendict
import simplejson as json
def freeze(o): def freeze(o):
@ -49,3 +50,21 @@ def unfreeze(o):
pass pass
return o return o
def _handle_frozendict(obj):
"""Helper for EventEncoder. Makes frozendicts serializable by returning
the underlying dict
"""
if type(obj) is frozendict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
raise TypeError('Object of type %s is not JSON serializable' %
obj.__class__.__name__)
# A JSONEncoder which is capable of encoding frozendics without barfing
frozendict_json_encoder = json.JSONEncoder(
default=_handle_frozendict,
)

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.web.resource import Resource from twisted.web.resource import NoResource
import logging import logging
@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource):
for path_seg in full_path.split('/')[1:-1]: for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames(): if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource" # resource doesn't exist, so make a "dummy resource"
child_resource = Resource() child_resource = NoResource()
last_resource.putChild(path_seg, child_resource) last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg) res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource resource_mappings[res_id] = child_resource

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import logging import logging
import mock import mock
@ -25,6 +27,50 @@ from tests import unittest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return the deferreds
self.assertIs(cache.get("key1"), d1)
self.assertIs(cache.get("key2"), d2)
# let one of the lookups complete
d2.callback("result2")
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(
callback_record[0], "Invalidation callback for key1 not called",
)
self.assertTrue(
callback_record[1], "Invalidation callback for key2 not called",
)
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
class DescriptorTestCase(unittest.TestCase): class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache(self): def test_cache(self):