mirror of https://mau.dev/maunium/synapse.git synced 2024-10-03 15:09:12 +02:00

Merge branch 'release-v0.25.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2017-11-15 11:32:24 +00:00
commit 552f123bea
85 changed files with 1777 additions and 683 deletions

View file

@ -1,3 +1,61 @@
Changes in synapse v0.25.0 (2017-11-15)
=======================================
Bug fixes:
* Fix port script (PR #2673)
Changes in synapse v0.25.0-rc1 (2017-11-14)
===========================================
Features:
* Add is_public to groups table to allow for private groups (PR #2582)
* Add a route for determining who you are (PR #2668) Thanks to @turt2live!
* Add more features to the password providers (PR #2608, #2610, #2620, #2622,
#2623, #2624, #2626, #2628, #2629)
* Add a hook for custom rest endpoints (PR #2627)
* Add API to update group room visibility (PR #2651)
Changes:
* Ignore <noscript> tags when generating URL preview descriptions (PR #2576)
Thanks to @maximevaillancourt!
* Register some /unstable endpoints in /r0 as well (PR #2579) Thanks to
@krombel!
* Support /keys/upload on /r0 as well as /unstable (PR #2585)
* Front-end proxy: pass through auth header (PR #2586)
* Allow ASes to deactivate their own users (PR #2589)
* Remove refresh tokens (PR #2613)
* Automatically set default displayname on register (PR #2617)
* Log login requests (PR #2618)
* Always return `is_public` in the `/groups/:group_id/rooms` API (PR #2630)
* Avoid no-op media deletes (PR #2637) Thanks to @spantaleev!
* Fix various embarrassing typos around user_directory and add some doc. (PR
#2643)
* Return whether a user is an admin within a group (PR #2647)
* Namespace visibility options for groups (PR #2657)
* Downcase UserIDs on registration (PR #2662)
* Cache failures when fetching URL previews (PR #2669)
Bug fixes:
* Fix port script (PR #2577)
* Fix error when running synapse with no logfile (PR #2581)
* Fix UI auth when deleting devices (PR #2591)
* Fix typo when checking if user is invited to group (PR #2599)
* Fix the port script to drop NUL values in all tables (PR #2611)
* Fix appservices being backlogged and not receiving new events due to a bug in
notify_interested_services (PR #2631) Thanks to @xyzz!
* Fix updating rooms avatar/display name when modified by admin (PR #2636)
Thanks to @farialima!
* Fix bug in state group storage (PR #2649)
* Fix 500 on invalid utf-8 in request (PR #2663)
Changes in synapse v0.24.1 (2017-10-24)
=======================================

View file

@ -823,7 +823,9 @@ spidering 'internal' URLs on your network. At the very least we recommend that
your loopback and RFC1918 IP addresses are blacklisted.

This also requires the optional lxml and netaddr python dependencies to be
installed. This in turn requires the libxml2 library to be available - on
Debian/Ubuntu this means ``apt-get install libxml2-dev``, or equivalent for
your OS.

Password reset

View file

@ -1,26 +1,13 @@
- Everything should comply with PEP8. Code should pass
  ``pep8 --max-line-length=100`` without any warnings.

- **Indenting**:

  - NEVER tabs. 4 spaces to indent.

  - follow PEP8; either hanging indent or multiline-visual indent depending
    on the size and shape of the arguments and what makes more sense to the
    author. In other words, both this::

        print("I am a fish %s" % "moo")

@ -33,20 +20,100 @@ Basically, PEP8

        print(
            "I am a fish %s" %
            "moo",
        )

    ...are valid, although given each one takes up 2x more vertical space than
    the previous, it's up to the author's discretion as to which layout makes
    most sense for their function invocation. (e.g. if they want to add
    comments per-argument, or put expressions in the arguments, or group
    related arguments together, or want to deliberately extend or preserve
    vertical/horizontal space)

- **Line length**:

  Max line length is 79 chars (with flexibility to overflow by a "few chars" if
  the overflowing content is not semantically significant and avoids an
  explosion of vertical whitespace).

  Use parentheses instead of ``\`` for line continuation where ever possible
  (which is pretty much everywhere).

- **Naming**:

  - Use camel case for class and type names
  - Use underscores for functions and variables.

- Use double quotes ``"foo"`` rather than single quotes ``'foo'``.

- **Blank lines**:

  - There should be max a single new line between:

    - statements
    - functions in a class

  - There should be two new lines between:

    - definitions in a module (e.g., between different classes)

- **Whitespace**:

  There should be spaces where spaces should be and not where there shouldn't
  be:

  - a single space after a comma
  - a single space before and after for '=' when used as assignment
  - no spaces before and after for '=' for default values and keyword arguments.

- **Comments**: should follow the `google code style
  <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
  This is so that we can generate documentation with `sphinx
  <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
  `examples
  <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
  in the sphinx documentation.

- **Imports**:

  - Prefer to import classes and functions than packages or modules.

    Example::

      from synapse.types import UserID
      ...
      user_id = UserID(local, server)

    is preferred over::

      from synapse import types
      ...
      user_id = types.UserID(local, server)

    (or any other variant).

    This goes against the advice in the Google style guide, but it means that
    errors in the name are caught early (at import time).

  - Multiple imports from the same package can be combined onto one line::

      from synapse.types import GroupID, RoomID, UserID

    An effort should be made to keep the individual imports in alphabetical
    order.

    If the list becomes long, wrap it with parentheses and split it over
    multiple lines.

  - As per `PEP-8 <https://www.python.org/dev/peps/pep-0008/#imports>`_,
    imports should be grouped in the following order, with a blank line between
    each group:

    1. standard library imports
    2. related third party imports
    3. local application/library specific imports

  - Imports within each group should be sorted alphabetically by module name.

  - Avoid wildcard imports (``from synapse.types import *``) and relative
    imports (``from .types import UserID``).

View file

@ -0,0 +1,99 @@
Password auth provider modules
==============================
Password auth providers offer a way for server administrators to integrate
their Synapse installation with an existing authentication system.
A password auth provider is a Python class which is dynamically loaded into
Synapse, and provides a number of methods by which it can integrate with the
authentication system.
This document serves as a reference for those looking to implement their own
password auth providers.
Required methods
----------------
Password auth provider classes must provide the following methods:
*class* ``SomeProvider.parse_config``\(*config*)
This method is passed the ``config`` object for this module from the
homeserver configuration file.
It should perform any appropriate sanity checks on the provided
configuration, and return an object which is then passed into ``__init__``.
*class* ``SomeProvider``\(*config*, *account_handler*)
The constructor is passed the config object returned by ``parse_config``,
and a ``synapse.module_api.ModuleApi`` object which allows the
password provider to check if accounts exist and/or create new ones.
Optional methods
----------------
Password auth provider classes may optionally provide the following methods.
*class* ``SomeProvider.get_db_schema_files``\()
This method, if implemented, should return an Iterable of ``(name,
stream)`` pairs of database schema files. Each file is applied in turn at
initialisation, and a record is then made in the database so that it is
not re-applied on the next start.
``someprovider.get_supported_login_types``\()
This method, if implemented, should return a ``dict`` mapping from a login
type identifier (such as ``m.login.password``) to an iterable giving the
fields which must be provided by the user in the submission to the
``/login`` api. These fields are passed in the ``login_dict`` dictionary
to ``check_auth``.
For example, if a password auth provider wants to implement a custom login
type of ``com.example.custom_login``, where the client is expected to pass
the fields ``secret1`` and ``secret2``, the provider should implement this
method and return the following dict::
{"com.example.custom_login": ("secret1", "secret2")}
``someprovider.check_auth``\(*username*, *login_type*, *login_dict*)
This method is the one that does the real work. If implemented, it will be
called for each login attempt where the login type matches one of the keys
returned by ``get_supported_login_types``.
It is passed the (possibly UNqualified) ``user`` provided by the client,
the login type, and a dictionary of login secrets passed by the client.
The method should return a Twisted ``Deferred`` object, which resolves to
the canonical ``@localpart:domain`` user id if authentication is successful,
and ``None`` if not.
Alternatively, the ``Deferred`` can resolve to a ``(str, func)`` tuple, in
which case the second field is a callback which will be called with the
result from the ``/login`` call (including ``access_token``, ``device_id``,
etc.)
``someprovider.check_password``\(*user_id*, *password*)
This method provides a simpler interface than ``get_supported_login_types``
and ``check_auth`` for password auth providers that just want to provide a
mechanism for validating ``m.login.password`` logins.
If implemented, it will be called to check logins with an
``m.login.password`` login type. It is passed a qualified
``@localpart:domain`` user id, and the password provided by the user.
The method should return a Twisted ``Deferred`` object, which resolves to
``True`` if authentication is successful, and ``False`` if not.
``someprovider.on_logged_out``\(*user_id*, *device_id*, *access_token*)
This method, if implemented, is called when a user logs out. It is passed
the qualified user ID, the ID of the deactivated device (if any: access
tokens are occasionally created without an associated device ID), and the
(now deactivated) access token.
It may return a Twisted ``Deferred`` object; the logout request will wait
for the deferred to complete but the result is ignored.
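The interface described above can be illustrated with a minimal, hypothetical provider that validates ``m.login.password`` logins against a static dict. ``DictPasswordProvider`` and its config format are illustrative only, not part of Synapse; and where real Synapse expects ``check_password`` to return a Twisted ``Deferred`` resolving to ``True``/``False``, this sketch returns the bare bool to stay dependency-free:

```python
# Hypothetical provider; "DictPasswordProvider" and its config format are
# illustrative only. Real Synapse expects check_password to return a
# Twisted Deferred resolving to True/False; a bare bool is used here to
# keep the sketch dependency-free.

class ConfigError(Exception):
    """Stand-in for synapse.config._base.ConfigError."""

class DictPasswordProvider(object):
    @staticmethod
    def parse_config(config):
        # sanity-check this module's section of the homeserver config
        users = config.get("users")
        if not isinstance(users, dict):
            raise ConfigError("'users' must be a dict of localpart -> password")
        return users

    def __init__(self, config, account_handler):
        # config is whatever parse_config returned; account_handler is a
        # synapse.module_api.ModuleApi object in real Synapse
        self.users = config
        self.account_handler = account_handler

    def check_password(self, user_id, password):
        # user_id is always fully qualified: @localpart:domain
        localpart = user_id.split(":", 1)[0].lstrip("@")
        return self.users.get(localpart) == password
```

A real provider would typically also implement ``get_supported_login_types``/``check_auth`` for custom login flows, as described above.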

View file

@ -56,6 +56,7 @@ As a first cut, let's do #2 and have the receiver hit the API to calculate its o
API
---

```
GET /_matrix/media/r0/preview_url?url=http://wherever.com
200 OK
{
@ -66,6 +67,7 @@ GET /_matrix/media/r0/preview_url?url=http://wherever.com
    "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
    "og:site_name" : "Twitter"
}
```

* Downloads the URL
* If HTML, just stores it in RAM and parses it for OG meta tags

docs/user_directory.md Normal file
View file

@ -0,0 +1,17 @@
User Directory API Implementation
=================================
The user directory is currently maintained based on the 'visible' users
on this particular server - i.e. ones which your account shares a room with, or
who are present in a publicly viewable room on the server.
The directory info is stored in various tables, which can (typically after
DB corruption) get stale or out of sync. If this happens, for now the
quickest solution to fix it is:
```
UPDATE user_directory_stream_pos SET stream_id = NULL;
```
and restart synapse, which should then start a background task to
flush the current tables and regenerate the directory.

View file

@ -42,6 +42,14 @@ BOOLEAN_COLUMNS = {
    "public_room_list_stream": ["visibility"],
    "device_lists_outbound_pokes": ["sent"],
    "users_who_share_rooms": ["share_private"],
    "groups": ["is_public"],
    "group_rooms": ["is_public"],
    "group_users": ["is_public", "is_admin"],
    "group_summary_rooms": ["is_public"],
    "group_room_categories": ["is_public"],
    "group_summary_users": ["is_public"],
    "group_roles": ["is_public"],
    "local_group_membership": ["is_publicised", "is_admin"],
}
@ -112,6 +120,7 @@ class Store(object):
    _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
    _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
    _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]

    def runInteraction(self, desc, func, *args, **kwargs):
        def r(conn):

@ -318,7 +327,7 @@ class Porter(object):
                backward_chunk = min(row[0] for row in brows) - 1

            rows = frows + brows
            rows = self._convert_rows(table, headers, rows)

            def insert(txn):
                self.postgres_store.insert_many_txn(
@ -554,17 +563,29 @@ class Porter(object):
            i for i, h in enumerate(headers) if h in bool_col_names
        ]

        class BadValueException(Exception):
            pass

        def conv(j, col):
            if j in bool_cols:
                return bool(col)
            elif isinstance(col, basestring) and "\0" in col:
                logger.warn("DROPPING ROW: NUL value in table %s col %s: %r",
                            table, headers[j], col)
                raise BadValueException()
            return col

        outrows = []
        for i, row in enumerate(rows):
            try:
                outrows.append(tuple(
                    conv(j, col)
                    for j, col in enumerate(row)
                    if j > 0
                ))
            except BadValueException:
                pass

        return outrows
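The NUL-dropping conversion this hunk introduces can be sketched standalone, using Python 3 `str` in place of the script's Python 2 `basestring` (the real port script also skips the leading rowid column and logs a warning for each dropped row):

```python
# Standalone sketch of the row conversion above: boolean columns are
# coerced with bool(); any row containing a NUL byte in a text column is
# dropped entirely rather than inserted into postgres.
class BadValueException(Exception):
    pass

def convert_rows(rows, bool_cols):
    def conv(j, col):
        if j in bool_cols:
            return bool(col)
        elif isinstance(col, str) and "\0" in col:
            raise BadValueException()
        return col

    outrows = []
    for row in rows:
        try:
            outrows.append(tuple(
                conv(j, col) for j, col in enumerate(row)
            ))
        except BadValueException:
            pass  # drop the whole row
    return outrows
```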
    @defer.inlineCallbacks
    def _setup_sent_transactions(self):

@ -592,7 +613,7 @@ class Porter(object):
            "select", r,
        )

        rows = self._convert_rows("sent_transactions", headers, rows)

        inserted_rows = len(rows)
        if inserted_rows:

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""

__version__ = "0.25.0"

View file

@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy")
class KeyUploadServlet(RestServlet):
    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")

    def __init__(self, hs):
        """

@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet):
        if body:
            # They're actually trying to upload something, proxy to main synapse.
            # Pass through the auth headers, if any, in case the access token
            # is there.
            auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
            headers = {
                "Authorization": auth_headers,
            }
            result = yield self.http_client.post_json_get_json(
                self.main_uri + request.uri,
                body,
                headers=headers,
            )

            defer.returnValue((200, result))

View file

@ -30,6 +30,8 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.federation.transport.server import TransportLayerServer
from synapse.module_api import ModuleApi
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
from synapse.http.site import SynapseSite
from synapse.metrics import register_memory_metrics

@ -49,6 +51,7 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.application import service
@ -107,9 +110,68 @@ class SynapseHomeServer(HomeServer):
        resources = {}
        for res in listener_config["resources"]:
            for name in res["names"]:
                resources.update(self._configure_named_resource(
                    name, res.get("compress", False),
                ))

        additional_resources = listener_config.get("additional_resources", {})
        logger.debug("Configuring additional resources: %r",
                     additional_resources)
        module_api = ModuleApi(self, self.get_auth_handler())
        for path, resmodule in additional_resources.items():
            handler_cls, config = load_module(resmodule)
            handler = handler_cls(config, module_api)
            resources[path] = AdditionalResource(self, handler.handle_request)

        if WEB_CLIENT_PREFIX in resources:
            root_resource = RootRedirect(WEB_CLIENT_PREFIX)
        else:
            root_resource = Resource()

        root_resource = create_resource_tree(resources, root_resource)

        if tls:
            for address in bind_addresses:
                reactor.listenSSL(
                    port,
                    SynapseSite(
                        "synapse.access.https.%s" % (site_tag,),
                        site_tag,
                        listener_config,
                        root_resource,
                    ),
                    self.tls_server_context_factory,
                    interface=address
                )
        else:
            for address in bind_addresses:
                reactor.listenTCP(
                    port,
                    SynapseSite(
                        "synapse.access.http.%s" % (site_tag,),
                        site_tag,
                        listener_config,
                        root_resource,
                    ),
                    interface=address
                )
        logger.info("Synapse now listening on port %d", port)

    def _configure_named_resource(self, name, compress=False):
        """Build a resource map for a named resource

        Args:
            name (str): named resource: one of "client", "federation", etc
            compress (bool): whether to enable gzip compression for this
                resource

        Returns:
            dict[str, Resource]: map from path to HTTP resource
        """
        resources = {}
        if name == "client":
            client_resource = ClientRestResource(self)
            if compress:
                client_resource = gz_wrap(client_resource)

            resources.update({
@ -154,39 +216,7 @@ class SynapseHomeServer(HomeServer):
        if name == "metrics" and self.get_config().enable_metrics:
            resources[METRICS_PREFIX] = MetricsResource(self)

        return resources

    def start_listening(self):
        config = self.get_config()

View file

@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID

@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient):
            defer.returnValue(None)

        key = (service.id, protocol)
        result = self.protocol_meta_cache.get(key)
        if not result:
            result = self.protocol_meta_cache.set(
                key, preserve_fn(_get)()
            )
        return make_deferred_yieldable(result)

    @defer.inlineCallbacks
    def push_bulk(self, service, events, txn_id=None):
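The hunk above moves to a get-or-set caching pattern: return the cached (possibly still in-flight) result if present, otherwise start the computation and cache it. A dependency-free sketch of just that pattern, with illustrative names (the real ResponseCache stores Twisted Deferreds, wrapped with preserve_fn/make_deferred_yieldable so logcontexts are handled correctly):

```python
# Illustrative sketch of the get-or-set pattern adopted above; plain
# values stand in for the Deferreds the real ResponseCache holds.
class SimpleResponseCache:
    def __init__(self):
        self._cache = {}

    def get(self, key):
        # None means no cached (or in-flight) result for this key
        return self._cache.get(key)

    def set(self, key, value):
        self._cache[key] = value
        return value

def get_protocol_meta(cache, key, compute):
    # hit the cache first; only start the computation on a miss
    result = cache.get(key)
    if result is None:
        result = cache.set(key, compute())
    return result
```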

View file

@ -41,7 +41,7 @@ class CasConfig(Config):
        #cas_config:
        #   enabled: true
        #   server_url: "https://cas-server.com"
        #   service_url: "https://homeserver.domain.com:8448"
        #   #required_attributes:
        #   #    name: value
        """

View file

@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False):
        "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
        " - %(message)s"
    )

    if log_config is None:
        level = logging.INFO
        level_for_storage = logging.INFO
        if config.verbosity:

@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False):
            logger.info("Opened new log file due to SIGHUP")
    else:
        handler = logging.StreamHandler()

        def sighup(signum, stack):
            pass

    handler.setFormatter(formatter)
    handler.addFilter(LoggingContextFilter(request=""))

View file

@ -13,41 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import Config
from synapse.util.module_loader import load_module

LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'


class PasswordAuthProviderConfig(Config):
    def read_config(self, config):
        self.password_providers = []
        providers = []
        provider_config = None

        # We want to be backwards compatible with the old `ldap_config`
        # param.
        ldap_config = config.get("ldap_config", {})
        if ldap_config.get("enabled", False):
            providers.append({
                'module': LDAP_PROVIDER,
                'config': ldap_config,
            })

        providers.extend(config.get("password_providers", []))
        for provider in providers:
            mod_name = provider['module']

            # This is for backwards compat when the ldap auth provider resided
            # in this package.
            if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
                mod_name = LDAP_PROVIDER

            (provider_class, provider_config) = load_module({
                "module": mod_name,
                "config": provider['config'],
            })

            self.password_providers.append((provider_class, provider_config))
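The entries this parser reads come from the ``password_providers`` section of homeserver.yaml; a hypothetical entry (the module path and option name are purely illustrative) looks like:

```yaml
password_providers:
  - module: "my_module.MyAuthProvider"   # hypothetical provider class
    config:
      example_option: "value"            # whatever parse_config expects
```

Each entry's ``config`` dict is handed to the provider class's ``parse_config`` before the provider is instantiated.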

View file

@ -247,6 +247,13 @@ class ServerConfig(Config):
          - names: [federation]  # Federation APIs
            compress: false

        # optional list of additional endpoints which can be loaded via
        # dynamic modules
        # additional_resources:
        #   "/_matrix/my/custom/endpoint":
        #     module: my_module.CustomRequestHandler
        #     config: {}

        # Unsecure HTTP listener,
        # For when matrix traffic passes through loadbalancer that unwraps TLS.
        - port: %(unsecure_port)s

View file

@ -109,6 +109,12 @@ class TlsConfig(Config):
        # key. It may be necessary to publish the fingerprints of a new
        # certificate and wait until the "valid_until_ts" of the previous key
        # responses have passed before deploying it.
        #
        # You can calculate a fingerprint from a given TLS listener via:
        # openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
        #   openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
        # or by checking matrix.org/federationtester/api/report?server_name=$host
        #
        tls_fingerprints: []
        # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
        """ % locals()
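The openssl pipeline in that comment amounts to: SHA-256 over the DER-encoded certificate, base64-encode, strip the `=` padding. The same computation in Python (`tls_fingerprint` is a hypothetical helper name, not a Synapse function):

```python
import base64
import hashlib

def tls_fingerprint(der_cert):
    # sha256 of the DER certificate bytes, base64-encoded with the '='
    # padding stripped -- the format expected in tls_fingerprints
    digest = hashlib.sha256(der_cert).digest()
    return base64.b64encode(digest).decode("ascii").rstrip("=")
```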

View file

@ -18,6 +18,7 @@ from .federation_base import FederationBase
from .units import Transaction, Edu

from synapse.util import async
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent

@ -253,12 +254,13 @@ class FederationServer(FederationBase):
        result = self._state_resp_cache.get((room_id, event_id))
        if not result:
            with (yield self._server_linearizer.queue((origin, room_id))):
                d = self._state_resp_cache.set(
                    (room_id, event_id),
                    preserve_fn(self._on_context_state_request_compute)(room_id, event_id)
                )
                resp = yield make_deferred_yieldable(d)
        else:
            resp = yield make_deferred_yieldable(result)

        defer.returnValue((200, resp))

View file

@ -545,6 +545,20 @@ class TransportLayerClient(object):
            ignore_backoff=True,
        )

    def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
                             config_key, content):
        """Update room in group
        """
        path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,)

        return self.client.post_json(
            destination=destination,
            path=path,
            args={"requester_user_id": requester_user_id},
            data=content,
            ignore_backoff=True,
        )

    def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
        """Remove a room from a group
        """

View file

@@ -676,7 +676,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
 class FederationGroupsAddRoomsServlet(BaseFederationServlet):
     """Add/remove room from group
     """
-    PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
+    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"

     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, group_id, room_id):
@@ -703,6 +703,27 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
         defer.returnValue((200, new_content))


+class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+    """Update room config in group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+        "/config/(?P<config_key>[^/]*)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, room_id, config_key):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        result = yield self.groups_handler.update_room_in_group(
+            group_id, requester_user_id, room_id, config_key, content,
+        )
+
+        defer.returnValue((200, result))
+
+
 class FederationGroupsUsersServlet(BaseFederationServlet):
     """Get the users in a group on behalf of a user
     """
@@ -1142,6 +1163,8 @@ GROUP_SERVER_SERVLET_CLASSES = (
     FederationGroupsRolesServlet,
     FederationGroupsRoleServlet,
     FederationGroupsSummaryUsersServlet,
+    FederationGroupsAddRoomsServlet,
+    FederationGroupsAddRoomsConfigServlet,
 )
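The `PATH` fix in the first hunk matters beyond style: the old pattern is not even a valid Python regex, since named groups require `?P<…>` and `(?<room_id>)` is rejected by the `re` module. A quick check of both patterns:

```python
import re

# The pattern removed by the diff, and its replacement.
BROKEN = r"/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
FIXED = r"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"

try:
    re.compile(BROKEN)
except re.error as e:
    # Python rejects (?<...> because only (?<= and (?<! start with "(?<"
    print("broken pattern rejected:", e)

m = re.compile(FIXED).match("/groups/+bar:b.example.com/room/!abc:b.example.com")
print(m.group("group_id"), m.group("room_id"))
```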

View file

@@ -13,6 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""Attestations ensure that users and groups can't lie about their memberships.
+
+When a user joins a group the HS and GS swap attestations, which allow them
+both to independently prove to third parties their membership. These
+attestations have a validity period so need to be periodically renewed.
+
+If a user leaves (or gets kicked out of) a group, either side can still use
+their attestation to "prove" their membership, until the attestation expires.
+Therefore attestations shouldn't be relied on to prove membership in important
+cases, but can for less important situations, e.g. showing a user's membership
+of groups on their profile, showing flairs, etc.
+
+An attestation is a signed blob of json that looks like:
+
+    {
+        "user_id": "@foo:a.example.com",
+        "group_id": "+bar:b.example.com",
+        "valid_until_ms": 1507994728530,
+        "signatures": {"matrix.org": {"ed25519:auto": "..."}}
+    }
+"""
+
+import logging
+import random
+
 from twisted.internet import defer

 from synapse.api.errors import SynapseError
@@ -22,9 +47,17 @@ from synapse.util.logcontext import preserve_fn
 from signedjson.sign import sign_json

+
+logger = logging.getLogger(__name__)
+
 # Default validity duration for new attestations we create
 DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000

+# We add some jitter to the validity duration of attestations so that if we
+# add lots of users at once we don't need to renew them all at once.
+# The jitter is a multiplier picked randomly between the first and second number
+DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
+
 # Start trying to update our attestations when they come this close to expiring
 UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
@@ -73,10 +106,14 @@ class GroupAttestationSigning(object):
         """Create an attestation for the group_id and user_id with default
         validity length.
         """
+        validity_period = DEFAULT_ATTESTATION_LENGTH_MS
+        validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+        valid_until_ms = int(self.clock.time_msec() + validity_period)
+
         return sign_json({
             "group_id": group_id,
             "user_id": user_id,
-            "valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
+            "valid_until_ms": valid_until_ms,
         }, self.server_name, self.signing_key)
@@ -128,12 +165,19 @@ class GroupAttestionRenewer(object):
         @defer.inlineCallbacks
         def _renew_attestation(group_id, user_id):
-            attestation = self.attestations.create_attestation(group_id, user_id)
-
-            if self.is_mine_id(group_id):
-                destination = get_domain_from_id(user_id)
+            if not self.is_mine_id(group_id):
+                destination = get_domain_from_id(group_id)
+            elif not self.is_mine_id(user_id):
+                destination = get_domain_from_id(user_id)
             else:
-                destination = get_domain_from_id(group_id)
+                logger.warn(
+                    "Incorrectly trying to do attestations for user: %r in %r",
+                    user_id, group_id,
+                )
+                yield self.store.remove_attestation_renewal(group_id, user_id)
+                return
+
+            attestation = self.attestations.create_attestation(group_id, user_id)

             yield self.transport_client.renew_group_attestation(
                 destination, group_id, user_id,
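The jitter change above scales the base validity by a random multiplier so that a batch of attestations created together does not all need renewing at the same instant. The computation can be sketched in isolation (the helper name is illustrative; synapse does this inline in `create_attestation`):

```python
import random

# Constants copied from the diff above.
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)

def attestation_valid_until_ms(now_ms):
    # Base validity period, scaled by a random multiplier in [0.9, 1.3)
    validity_period = DEFAULT_ATTESTATION_LENGTH_MS
    validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
    return int(now_ms + validity_period)

now = 1507994728530
v = attestation_valid_until_ms(now)
```

With these constants, `v - now` always falls between roughly 2.7 and 3.9 days of milliseconds, spreading renewal work over time.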

View file

@@ -49,7 +49,8 @@ class GroupsServerHandler(object):
         hs.get_groups_attestation_renewer()

     @defer.inlineCallbacks
-    def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
+    def check_group_is_ours(self, group_id, requester_user_id,
+                            and_exists=False, and_is_admin=None):
         """Check that the group is ours, and optionally if it exists.

         If group does exist then return group.
@@ -67,6 +68,10 @@ class GroupsServerHandler(object):
         if and_exists and not group:
             raise SynapseError(404, "Unknown group")

+        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+        if group and not is_user_in_group and not group["is_public"]:
+            raise SynapseError(404, "Unknown group")
+
         if and_is_admin:
             is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
             if not is_admin:
@@ -84,7 +89,7 @@ class GroupsServerHandler(object):
         A user/room may appear in multiple roles/categories.
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@@ -153,10 +158,16 @@ class GroupsServerHandler(object):
         })

     @defer.inlineCallbacks
-    def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
+    def update_group_summary_room(self, group_id, requester_user_id,
+                                  room_id, category_id, content):
         """Add/update a room to the group summary
         """
-        yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )

         RoomID.from_string(room_id)  # Ensure valid room id
@@ -175,10 +186,16 @@ class GroupsServerHandler(object):
         defer.returnValue({})

     @defer.inlineCallbacks
-    def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
+    def delete_group_summary_room(self, group_id, requester_user_id,
+                                  room_id, category_id):
         """Remove a room from the summary
         """
-        yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )

         yield self.store.remove_room_from_summary(
             group_id=group_id,
@@ -189,10 +206,10 @@ class GroupsServerHandler(object):
         defer.returnValue({})

     @defer.inlineCallbacks
-    def get_group_categories(self, group_id, user_id):
+    def get_group_categories(self, group_id, requester_user_id):
         """Get all categories in a group (as seen by user)
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         categories = yield self.store.get_group_categories(
             group_id=group_id,
@@ -200,10 +217,10 @@ class GroupsServerHandler(object):
         defer.returnValue({"categories": categories})

     @defer.inlineCallbacks
-    def get_group_category(self, group_id, user_id, category_id):
+    def get_group_category(self, group_id, requester_user_id, category_id):
         """Get a specific category in a group (as seen by user)
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         res = yield self.store.get_group_category(
             group_id=group_id,
@@ -213,10 +230,15 @@ class GroupsServerHandler(object):
         defer.returnValue(res)

     @defer.inlineCallbacks
-    def update_group_category(self, group_id, user_id, category_id, content):
+    def update_group_category(self, group_id, requester_user_id, category_id, content):
         """Add/Update a group category
         """
-        yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )

         is_public = _parse_visibility_from_contents(content)
         profile = content.get("profile")
@@ -231,10 +253,15 @@ class GroupsServerHandler(object):
         defer.returnValue({})

     @defer.inlineCallbacks
-    def delete_group_category(self, group_id, user_id, category_id):
+    def delete_group_category(self, group_id, requester_user_id, category_id):
         """Delete a group category
         """
-        yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id
+        )

         yield self.store.remove_group_category(
             group_id=group_id,
@@ -244,10 +271,10 @@ class GroupsServerHandler(object):
         defer.returnValue({})

     @defer.inlineCallbacks
-    def get_group_roles(self, group_id, user_id):
+    def get_group_roles(self, group_id, requester_user_id):
         """Get all roles in a group (as seen by user)
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         roles = yield self.store.get_group_roles(
             group_id=group_id,
@@ -255,10 +282,10 @@ class GroupsServerHandler(object):
         defer.returnValue({"roles": roles})

     @defer.inlineCallbacks
-    def get_group_role(self, group_id, user_id, role_id):
+    def get_group_role(self, group_id, requester_user_id, role_id):
         """Get a specific role in a group (as seen by user)
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         res = yield self.store.get_group_role(
             group_id=group_id,
@@ -267,10 +294,15 @@ class GroupsServerHandler(object):
         defer.returnValue(res)

     @defer.inlineCallbacks
-    def update_group_role(self, group_id, user_id, role_id, content):
+    def update_group_role(self, group_id, requester_user_id, role_id, content):
         """Add/update a role in a group
         """
-        yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )

         is_public = _parse_visibility_from_contents(content)
@@ -286,10 +318,15 @@ class GroupsServerHandler(object):
         defer.returnValue({})

     @defer.inlineCallbacks
-    def delete_group_role(self, group_id, user_id, role_id):
+    def delete_group_role(self, group_id, requester_user_id, role_id):
         """Remove role from group
         """
-        yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )

         yield self.store.remove_group_role(
             group_id=group_id,
@@ -304,7 +341,7 @@ class GroupsServerHandler(object):
         """Add/update a user's entry in the group summary
         """
         yield self.check_group_is_ours(
-            group_id, and_exists=True, and_is_admin=requester_user_id,
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
         )

         order = content.get("order", None)
@@ -326,7 +363,7 @@ class GroupsServerHandler(object):
         """Remove a user from the group summary
         """
         yield self.check_group_is_ours(
-            group_id, and_exists=True, and_is_admin=requester_user_id,
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
         )

         yield self.store.remove_user_from_summary(
@@ -342,7 +379,7 @@ class GroupsServerHandler(object):
         """Get the group profile as seen by requester_user_id
         """
-        yield self.check_group_is_ours(group_id)
+        yield self.check_group_is_ours(group_id, requester_user_id)

         group_description = yield self.store.get_group(group_id)
@@ -356,7 +393,7 @@ class GroupsServerHandler(object):
         """Update the group profile
         """
         yield self.check_group_is_ours(
-            group_id, and_exists=True, and_is_admin=requester_user_id,
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
         )

         profile = {}
@@ -377,7 +414,7 @@ class GroupsServerHandler(object):
         The ordering is arbitrary at the moment
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@@ -389,14 +426,15 @@ class GroupsServerHandler(object):
         for user_result in user_results:
             g_user_id = user_result["user_id"]
             is_public = user_result["is_public"]
+            is_privileged = user_result["is_admin"]

             entry = {"user_id": g_user_id}

             profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
             entry.update(profile)

-            if not is_public:
-                entry["is_public"] = False
+            entry["is_public"] = bool(is_public)
+            entry["is_privileged"] = bool(is_privileged)

             if not self.is_mine_id(g_user_id):
                 attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
@@ -425,7 +463,7 @@ class GroupsServerHandler(object):
         The ordering is arbitrary at the moment
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@@ -459,7 +497,7 @@ class GroupsServerHandler(object):
         This returns rooms in order of decreasing number of joined users
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@@ -470,7 +508,6 @@ class GroupsServerHandler(object):
         chunk = []
         for room_result in room_results:
             room_id = room_result["room_id"]
-            is_public = room_result["is_public"]

             joined_users = yield self.store.get_users_in_room(room_id)

             entry = yield self.room_list_handler.generate_room_entry(
@@ -481,8 +518,7 @@ class GroupsServerHandler(object):
             if not entry:
                 continue

-            if not is_public:
-                entry["is_public"] = False
+            entry["is_public"] = bool(room_result["is_public"])

             chunk.append(entry)
@@ -500,7 +536,7 @@ class GroupsServerHandler(object):
         RoomID.from_string(room_id)  # Ensure valid room id

         yield self.check_group_is_ours(
-            group_id, and_exists=True, and_is_admin=requester_user_id
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )

         is_public = _parse_visibility_from_contents(content)
@@ -509,12 +545,35 @@ class GroupsServerHandler(object):
         defer.returnValue({})

+    @defer.inlineCallbacks
+    def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
+                             content):
+        """Update room in group
+        """
+        RoomID.from_string(room_id)  # Ensure valid room id
+
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+        )
+
+        if config_key == "m.visibility":
+            is_public = _parse_visibility_dict(content)
+
+            yield self.store.update_room_in_group_visibility(
+                group_id, room_id,
+                is_public=is_public,
+            )
+        else:
+            raise SynapseError(400, "Unknown config option")
+
+        defer.returnValue({})
+
     @defer.inlineCallbacks
     def remove_room_from_group(self, group_id, requester_user_id, room_id):
         """Remove room from group
         """
         yield self.check_group_is_ours(
-            group_id, and_exists=True, and_is_admin=requester_user_id
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )

         yield self.store.remove_room_from_group(group_id, room_id)
@@ -527,7 +586,7 @@ class GroupsServerHandler(object):
         """
         group = yield self.check_group_is_ours(
-            group_id, and_exists=True, and_is_admin=requester_user_id
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )

         # TODO: Check if user knocked
@@ -596,35 +655,40 @@ class GroupsServerHandler(object):
             raise SynapseError(502, "Unknown state returned by HS")

     @defer.inlineCallbacks
-    def accept_invite(self, group_id, user_id, content):
+    def accept_invite(self, group_id, requester_user_id, content):
         """User tries to accept an invite to the group.

         This is different from them asking to join, and so should error if no
         invite exists (and they're not a member of the group)
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

-        if not self.store.is_user_invited_to_local_group(group_id, user_id):
+        is_invited = yield self.store.is_user_invited_to_local_group(
+            group_id, requester_user_id,
+        )
+        if not is_invited:
             raise SynapseError(403, "User not invited to group")

-        if not self.hs.is_mine_id(user_id):
+        if not self.hs.is_mine_id(requester_user_id):
+            local_attestation = self.attestations.create_attestation(
+                group_id, requester_user_id,
+            )
             remote_attestation = content["attestation"]

             yield self.attestations.verify_attestation(
                 remote_attestation,
-                user_id=user_id,
+                user_id=requester_user_id,
                 group_id=group_id,
             )
         else:
+            local_attestation = None
             remote_attestation = None

-        local_attestation = self.attestations.create_attestation(group_id, user_id)
-
         is_public = _parse_visibility_from_contents(content)

         yield self.store.add_user_to_group(
-            group_id, user_id,
+            group_id, requester_user_id,
             is_admin=False,
             is_public=is_public,
             local_attestation=local_attestation,
@@ -637,31 +701,31 @@ class GroupsServerHandler(object):
         })

     @defer.inlineCallbacks
-    def knock(self, group_id, user_id, content):
+    def knock(self, group_id, requester_user_id, content):
         """A user requests becoming a member of the group
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         raise NotImplementedError()

     @defer.inlineCallbacks
-    def accept_knock(self, group_id, user_id, content):
+    def accept_knock(self, group_id, requester_user_id, content):
         """Accept a user's knock to the room.

         Errors if the user hasn't knocked, rather than inviting them.
         """
-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         raise NotImplementedError()

     @defer.inlineCallbacks
     def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
-        """Remove a user from the group; either a user is leaving or and admin
-        kicked htem.
+        """Remove a user from the group; either a user is leaving or an admin
+        kicked them.
         """

-        yield self.check_group_is_ours(group_id, and_exists=True)
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)

         is_kick = False
         if requester_user_id != user_id:
@@ -692,8 +756,8 @@ class GroupsServerHandler(object):
         defer.returnValue({})

     @defer.inlineCallbacks
-    def create_group(self, group_id, user_id, content):
-        group = yield self.check_group_is_ours(group_id)
+    def create_group(self, group_id, requester_user_id, content):
+        group = yield self.check_group_is_ours(group_id, requester_user_id)

         logger.info("Attempting to create group with ID: %r", group_id)
@@ -703,11 +767,11 @@ class GroupsServerHandler(object):
         if group:
             raise SynapseError(400, "Group already exists")

-        is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
+        is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
         if not is_admin:
             if not self.hs.config.enable_group_creation:
                 raise SynapseError(
-                    403, "Only server admin can create group on this server",
+                    403, "Only a server admin can create groups on this server",
                 )
             localpart = group_id_obj.localpart
             if not localpart.startswith(self.hs.config.group_creation_prefix):
@@ -727,38 +791,41 @@ class GroupsServerHandler(object):
         yield self.store.create_group(
             group_id,
-            user_id,
+            requester_user_id,
             name=name,
             avatar_url=avatar_url,
             short_description=short_description,
             long_description=long_description,
         )

-        if not self.hs.is_mine_id(user_id):
+        if not self.hs.is_mine_id(requester_user_id):
             remote_attestation = content["attestation"]

             yield self.attestations.verify_attestation(
                 remote_attestation,
-                user_id=user_id,
+                user_id=requester_user_id,
                 group_id=group_id,
             )

-            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            local_attestation = self.attestations.create_attestation(
+                group_id,
+                requester_user_id,
+            )
         else:
             local_attestation = None
             remote_attestation = None

         yield self.store.add_user_to_group(
-            group_id, user_id,
+            group_id, requester_user_id,
             is_admin=True,
             is_public=True,  # TODO
             local_attestation=local_attestation,
             remote_attestation=remote_attestation,
         )

-        if not self.hs.is_mine_id(user_id):
+        if not self.hs.is_mine_id(requester_user_id):
             yield self.store.add_remote_profile_cache(
-                user_id,
+                requester_user_id,
                 displayname=user_profile.get("displayname"),
                 avatar_url=user_profile.get("avatar_url"),
             )
@@ -773,15 +840,25 @@ def _parse_visibility_from_contents(content):
     public or not
     """

-    visibility = content.get("visibility")
+    visibility = content.get("m.visibility")
     if visibility:
-        vis_type = visibility["type"]
-        if vis_type not in ("public", "private"):
-            raise SynapseError(
-                400, "Synapse only supports 'public'/'private' visibility"
-            )
-        is_public = vis_type == "public"
+        return _parse_visibility_dict(visibility)
     else:
         is_public = True

     return is_public


+def _parse_visibility_dict(visibility):
+    """Given a dict for the "m.visibility" config return if the entity should
+    be public or not
+    """
+    vis_type = visibility.get("type")
+    if not vis_type:
+        return True
+
+    if vis_type not in ("public", "private"):
+        raise SynapseError(
+            400, "Synapse only supports 'public'/'private' visibility"
+        )
+
+    return vis_type == "public"
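The visibility refactor above splits the parsing into a pure function, which is easy to sanity-check in isolation. A sketch of the same logic, raising `ValueError` where synapse raises `SynapseError`:

```python
def parse_visibility_dict(visibility):
    # Given the "m.visibility" dict, return True for public, False for private.
    vis_type = visibility.get("type")
    if not vis_type:
        return True
    if vis_type not in ("public", "private"):
        raise ValueError("Synapse only supports 'public'/'private' visibility")
    return vis_type == "public"

def parse_visibility_from_contents(content):
    # Entities default to public when no "m.visibility" key is present.
    visibility = content.get("m.visibility")
    if visibility:
        return parse_visibility_dict(visibility)
    return True

print(parse_visibility_from_contents({"m.visibility": {"type": "private"}}))  # False
```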

View file

@@ -70,11 +70,10 @@ class ApplicationServicesHandler(object):
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
-                upper_bound = self.current_max
                 limit = 100
                 while True:
                     upper_bound, events = yield self.store.get_new_events_for_appservice(
-                        upper_bound, limit
+                        self.current_max, limit
                     )

                     if not events:
@@ -105,9 +104,6 @@ class ApplicationServicesHandler(object):
                     )

                     yield self.store.set_appservice_last_pos(upper_bound)
-
-                    if len(events) < limit:
-                        break
             finally:
                 self.is_processing = False
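This appservice change drops the local `upper_bound` cursor and the early `len(events) < limit` break: each iteration now pages toward a fixed `current_max`, and the loop exits only when the store returns an empty batch. The control flow can be sketched with a fake store (the helper names here are illustrative, not synapse's):

```python
def drain_new_events(get_new_events_for_appservice, current_max, limit=100):
    # Page through events until the store reports none left below current_max.
    handled = []
    while True:
        upper_bound, events = get_new_events_for_appservice(current_max, limit)
        if not events:
            break
        handled.extend(events)
        # synapse persists progress here via set_appservice_last_pos(upper_bound)
    return handled

def make_fake_store(all_events):
    # Stateful stand-in for the store: hands out batches of up to `limit`
    # events between the last position and current_max.
    state = {"pos": 0}
    def get_new(current_max, limit):
        start = state["pos"]
        upper = min(start + limit, current_max)
        state["pos"] = upper
        return upper, all_events[start:upper]
    return get_new

events = list(range(250))
print(len(drain_new_events(make_fake_store(events), current_max=len(events))))
```

Exiting on an empty batch rather than a short batch avoids the corner case where exactly `limit` events remain.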

View file

@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -63,10 +63,7 @@ class AuthHandler(BaseHandler):
reset_expiry_on_get=True, reset_expiry_on_get=True,
) )
account_handler = _AccountHandler( account_handler = ModuleApi(hs, self)
hs, check_user_exists=self.check_user_exists
)
self.password_providers = [ self.password_providers = [
module(config=config, account_handler=account_handler) module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers for module, config in hs.config.password_providers
@ -75,14 +72,24 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers) logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
login_types = set()
if self._password_enabled:
login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
login_types.update(
provider.get_supported_login_types().keys()
)
self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow. protocol and handles the User-Interactive Auth flow.
As a side effect, this function fills in the 'creds' key on the user's As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant session with a map, which maps each auth-type (str) to the relevant
@@ -260,16 +267,19 @@ class AuthHandler(BaseHandler):
         sess = self._get_session_info(session_id)
         return sess.setdefault('serverdict', {}).get(key, default)

+    @defer.inlineCallbacks
     def _check_password_auth(self, authdict, _):
         if "user" not in authdict or "password" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)

         user_id = authdict["user"]
         password = authdict["password"]
-        if not user_id.startswith('@'):
-            user_id = UserID(user_id, self.hs.hostname).to_string()

-        return self._check_password(user_id, password)
+        (canonical_id, callback) = yield self.validate_login(user_id, {
+            "type": LoginType.PASSWORD,
+            "password": password,
+        })
+        defer.returnValue(canonical_id)

     @defer.inlineCallbacks
     def _check_recaptcha(self, authdict, clientip):
@@ -398,26 +408,8 @@ class AuthHandler(BaseHandler):
         return self.sessions[session_id]

-    def validate_password_login(self, user_id, password):
-        """
-        Authenticates the user with their username and password.
-
-        Used only by the v1 login API.
-
-        Args:
-            user_id (str): complete @user:id
-            password (str): Password
-
-        Returns:
-            defer.Deferred: (str) canonical user id
-
-        Raises:
-            StoreError if there was a problem accessing the database
-            LoginError if there was an authentication problem.
-        """
-        return self._check_password(user_id, password)
-
     @defer.inlineCallbacks
-    def get_access_token_for_user_id(self, user_id, device_id=None,
-                                     initial_display_name=None):
+    def get_access_token_for_user_id(self, user_id, device_id=None):
         """
         Creates a new access token for the user with the given user ID.
@@ -431,13 +423,10 @@ class AuthHandler(BaseHandler):
             device_id (str|None): the device ID to associate with the tokens.
                 None to leave the tokens unassociated with a device (deprecated:
                 we should always have a device ID)
-            initial_display_name (str): display name to associate with the
-                device if it needs re-registering
         Returns:
             The access token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
-            LoginError if there was an authentication problem.
         """
         logger.info("Logging in user %s on device %s", user_id, device_id)
         access_token = yield self.issue_access_token(user_id, device_id)
@@ -447,9 +436,11 @@ class AuthHandler(BaseHandler):
         # really don't want is active access_tokens without a record of the
         # device, so we double-check it here.
         if device_id is not None:
-            yield self.device_handler.check_device_registered(
-                user_id, device_id, initial_display_name
-            )
+            try:
+                yield self.store.get_device(user_id, device_id)
+            except StoreError:
+                yield self.store.delete_access_token(access_token)
+                raise StoreError(400, "Login raced against device deletion")

         defer.returnValue(access_token)
@@ -501,29 +492,115 @@ class AuthHandler(BaseHandler):
         )
         defer.returnValue(result)

-    @defer.inlineCallbacks
-    def _check_password(self, user_id, password):
-        """Authenticate a user against the LDAP and local databases.
-
-        user_id is checked case insensitively against the local database, but
-        will throw if there are multiple inexact matches.
+    def get_supported_login_types(self):
+        """Get the login types supported for the /login API
+
+        By default this is just 'm.login.password' (unless password_enabled is
+        False in the config file), but password auth providers can provide
+        other login types.
+
+        Returns:
+            Iterable[str]: login types
+        """
+        return self._supported_login_types
+
+    @defer.inlineCallbacks
+    def validate_login(self, username, login_submission):
+        """Authenticates the user for the /login API
+
+        Also used by the user-interactive auth flow to validate
+        m.login.password auth types.

         Args:
-            user_id (str): complete @user:id
+            username (str): username supplied by the user
+            login_submission (dict): the whole of the login submission
+                (including 'type' and other relevant fields)
+
         Returns:
-            (str) the canonical_user_id
+            Deferred[str, func]: canonical user id, and optional callback
+                to be called once the access token and device id are issued
+
         Raises:
-            LoginError if login fails
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
         """
-        for provider in self.password_providers:
-            is_valid = yield provider.check_password(user_id, password)
-            if is_valid:
-                defer.returnValue(user_id)
-
-        canonical_user_id = yield self._check_local_password(user_id, password)
+
+        if username.startswith('@'):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(
+                username, self.hs.hostname
+            ).to_string()
+
+        login_type = login_submission.get("type")
+        known_login_type = False
+
+        # special case to check for "password" for the check_password interface
+        # for the auth providers
+        password = login_submission.get("password")
+        if login_type == LoginType.PASSWORD:
+            if not self._password_enabled:
+                raise SynapseError(400, "Password login has been disabled.")
+            if not password:
+                raise SynapseError(400, "Missing parameter: password")
+
+        for provider in self.password_providers:
+            if (hasattr(provider, "check_password")
+                    and login_type == LoginType.PASSWORD):
+                known_login_type = True
+                is_valid = yield provider.check_password(
+                    qualified_user_id, password,
+                )
+                if is_valid:
+                    defer.returnValue(qualified_user_id)
+
+            if (not hasattr(provider, "get_supported_login_types")
+                    or not hasattr(provider, "check_auth")):
+                # this password provider doesn't understand custom login types
+                continue
+
+            supported_login_types = provider.get_supported_login_types()
+            if login_type not in supported_login_types:
+                # this password provider doesn't understand this login type
+                continue
+
+            known_login_type = True
+            login_fields = supported_login_types[login_type]
+
+            missing_fields = []
+            login_dict = {}
+            for f in login_fields:
+                if f not in login_submission:
+                    missing_fields.append(f)
+                else:
+                    login_dict[f] = login_submission[f]
+            if missing_fields:
+                raise SynapseError(
+                    400, "Missing parameters for login type %s: %s" % (
+                        login_type,
+                        missing_fields,
+                    ),
+                )
+
+            result = yield provider.check_auth(
+                username, login_type, login_dict,
+            )
+            if result:
+                if isinstance(result, str):
+                    result = (result, None)
+                defer.returnValue(result)
+
+        if login_type == LoginType.PASSWORD:
+            known_login_type = True
+
+            canonical_user_id = yield self._check_local_password(
+                qualified_user_id, password,
+            )

-        if canonical_user_id:
-            defer.returnValue(canonical_user_id)
+            if canonical_user_id:
+                defer.returnValue((canonical_user_id, None))
+
+        if not known_login_type:
+            raise SynapseError(400, "Unknown login type %s" % login_type)

         # unknown username or invalid password. We raise a 403 here, but note
         # that if we're doing user-interactive login, it turns all LoginErrors
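The new `validate_login` path dispatches to password providers through two optional methods, `get_supported_login_types` and `check_auth`. A minimal provider satisfying that shape might look like the sketch below. The class name, the `example.login.token` type, and the token table are invented for illustration, and the real Synapse hooks return Deferreds rather than plain values.

```python
class DemoTokenProvider(object):
    """Hypothetical password provider accepting a hard-coded token."""

    _USERS = {"secret-token": "@demo:example.com"}

    def get_supported_login_types(self):
        # maps each custom login type to the fields the client must submit
        return {"example.login.token": ("token",)}

    def check_auth(self, username, login_type, login_dict):
        # Return the canonical user id on success, or None on failure.
        # (Synapse's interface would wrap this in a Deferred.)
        if login_type != "example.login.token":
            return None
        return self._USERS.get(login_dict.get("token"))
```

`validate_login` collects the keys of `get_supported_login_types()` into `_supported_login_types`, checks the submitted fields against the tuple declared for the login type, and then calls `check_auth` with only those fields.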
@@ -584,13 +661,80 @@ class AuthHandler(BaseHandler):
             if e.code == 404:
                 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
             raise e
-        yield self.store.user_delete_access_tokens(
-            user_id, except_access_token_id
+        yield self.delete_access_tokens_for_user(
+            user_id, except_token_id=except_access_token_id,
         )
         yield self.hs.get_pusherpool().remove_pushers_by_user(
             user_id, except_access_token_id
         )

+    @defer.inlineCallbacks
+    def deactivate_account(self, user_id):
+        """Deactivate a user's account
+
+        Args:
+            user_id (str): ID of user to be deactivated
+
+        Returns:
+            Deferred
+        """
+        # FIXME: Theoretically there is a race here wherein user resets
+        # password using threepid.
+        yield self.delete_access_tokens_for_user(user_id)
+        yield self.store.user_delete_threepids(user_id)
+        yield self.store.user_set_password_hash(user_id, None)
+
+    @defer.inlineCallbacks
+    def delete_access_token(self, access_token):
+        """Invalidate a single access token
+
+        Args:
+            access_token (str): access token to be deleted
+
+        Returns:
+            Deferred
+        """
+        user_info = yield self.auth.get_user_by_access_token(access_token)
+        yield self.store.delete_access_token(access_token)
+
+        # see if any of our auth providers want to know about this
+        for provider in self.password_providers:
+            if hasattr(provider, "on_logged_out"):
+                yield provider.on_logged_out(
+                    user_id=str(user_info["user"]),
+                    device_id=user_info["device_id"],
+                    access_token=access_token,
+                )
+
+    @defer.inlineCallbacks
+    def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+                                      device_id=None):
+        """Invalidate access tokens belonging to a user
+
+        Args:
+            user_id (str): ID of user the tokens belong to
+            except_token_id (str|None): access_token ID which should *not* be
+                deleted
+            device_id (str|None): ID of device the tokens are associated with.
+                If None, tokens associated with any device (or no device) will
+                be deleted
+        Returns:
+            Deferred
+        """
+        tokens_and_devices = yield self.store.user_delete_access_tokens(
+            user_id, except_token_id=except_token_id, device_id=device_id,
+        )
+
+        # see if any of our auth providers want to know about this
+        for provider in self.password_providers:
+            if hasattr(provider, "on_logged_out"):
+                for token, device_id in tokens_and_devices:
+                    yield provider.on_logged_out(
+                        user_id=user_id,
+                        device_id=device_id,
+                        access_token=token,
+                    )
+
     @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
         # 'Canonicalise' email addresses down to lower case.
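The token-deletion methods discover the optional `on_logged_out` hook with `hasattr` rather than through a formal interface. That duck-typed dispatch can be sketched as follows; the provider classes and `notify_logged_out` helper are hypothetical, and the Deferred plumbing is omitted.

```python
def notify_logged_out(providers, user_id, tokens_and_devices):
    """Call on_logged_out on every provider that defines it."""
    notified = []
    for provider in providers:
        if not hasattr(provider, "on_logged_out"):
            continue  # this provider doesn't care about logouts
        for token, device_id in tokens_and_devices:
            provider.on_logged_out(
                user_id=user_id,
                device_id=device_id,
                access_token=token,
            )
            notified.append(token)
    return notified


class RecordingProvider(object):
    """Test double that records the hook calls it receives."""
    def __init__(self):
        self.calls = []

    def on_logged_out(self, user_id, device_id, access_token):
        self.calls.append((user_id, device_id, access_token))


class SilentProvider(object):
    """A provider with no logout hook; it is simply skipped."""
```

The `hasattr` convention keeps old providers working unchanged: a provider that predates the hook is silently skipped instead of raising `AttributeError`.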
@@ -696,30 +840,3 @@ class MacaroonGeneartor(object):
             macaroon.add_first_party_caveat("gen = 1")
             macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
             return macaroon
-
-
-class _AccountHandler(object):
-    """A proxy object that gets passed to password auth providers so they
-    can register new users etc if necessary.
-    """
-
-    def __init__(self, hs, check_user_exists):
-        self.hs = hs
-        self._check_user_exists = check_user_exists
-
-    def check_user_exists(self, user_id):
-        """Check if user exissts.
-
-        Returns:
-            Deferred(bool)
-        """
-        return self._check_user_exists(user_id)
-
-    def register(self, localpart):
-        """Registers a new user with given localpart
-
-        Returns:
-            Deferred: a 2-tuple of (user_id, access_token)
-        """
-        reg = self.hs.get_handlers().registration_handler
-        return reg.register(localpart=localpart)


@@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler):
         self.hs = hs
         self.state = hs.get_state_handler()
+        self._auth_handler = hs.get_auth_handler()
         self.federation_sender = hs.get_federation_sender()
         self.federation = hs.get_replication_layer()
@@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler):
         else:
             raise

-        yield self.store.user_delete_access_tokens(
+        yield self._auth_handler.delete_access_tokens_for_user(
             user_id, device_id=device_id,
-            delete_refresh_tokens=True,
         )

         yield self.store.delete_e2e_keys_by_device(
@@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler):
         # Delete access tokens and e2e keys for each device. Not optimised as it is not
         # considered as part of a critical path.
         for device_id in device_ids:
-            yield self.store.user_delete_access_tokens(
+            yield self._auth_handler.delete_access_tokens_for_user(
                 user_id, device_id=device_id,
-                delete_refresh_tokens=True,
             )
             yield self.store.delete_e2e_keys_by_device(
                 user_id=user_id, device_id=device_id


@@ -1706,6 +1706,17 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     @log_function
     def do_auth(self, origin, event, context, auth_events):
+        """
+        Args:
+            origin (str):
+            event (synapse.events.FrozenEvent):
+            context (synapse.events.snapshot.EventContext):
+            auth_events (dict[(str, str)->str]):
+
+        Returns:
+            defer.Deferred[None]
+        """
         # Check if we have all the auth events.
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
@@ -1817,16 +1828,9 @@ class FederationHandler(BaseHandler):
             current_state = set(e.event_id for e in auth_events.values())
             different_auth = event_auth_events - current_state

-            context.current_state_ids = dict(context.current_state_ids)
-            context.current_state_ids.update({
-                k: a.event_id for k, a in auth_events.items()
-                if k != event_key
-            })
-            context.prev_state_ids = dict(context.prev_state_ids)
-            context.prev_state_ids.update({
-                k: a.event_id for k, a in auth_events.items()
-            })
-            context.state_group = self.store.get_next_state_group()
+            self._update_context_for_auth_events(
+                context, auth_events, event_key,
+            )

         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1906,16 +1910,9 @@ class FederationHandler(BaseHandler):
         # 4. Look at rejects and their proofs.
         # TODO.

-        context.current_state_ids = dict(context.current_state_ids)
-        context.current_state_ids.update({
-            k: a.event_id for k, a in auth_events.items()
-            if k != event_key
-        })
-        context.prev_state_ids = dict(context.prev_state_ids)
-        context.prev_state_ids.update({
-            k: a.event_id for k, a in auth_events.items()
-        })
-        context.state_group = self.store.get_next_state_group()
+        self._update_context_for_auth_events(
+            context, auth_events, event_key,
+        )

         try:
             self.auth.check(event, auth_events=auth_events)
@@ -1923,6 +1920,35 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed auth resolution for %r because %s", event, e)
             raise e

+    def _update_context_for_auth_events(self, context, auth_events,
+                                        event_key):
+        """Update the state_ids in an event context after auth event resolution
+
+        Args:
+            context (synapse.events.snapshot.EventContext): event context
+                to be updated
+
+            auth_events (dict[(str, str)->str]): Events to update in the event
+                context.
+
+            event_key ((str, str)): (type, state_key) for the current event.
+                this will not be included in the current_state in the context.
+        """
+        state_updates = {
+            k: a.event_id for k, a in auth_events.iteritems()
+            if k != event_key
+        }
+        context.current_state_ids = dict(context.current_state_ids)
+        context.current_state_ids.update(state_updates)
+        if context.delta_ids is not None:
+            context.delta_ids = dict(context.delta_ids)
+            context.delta_ids.update(state_updates)
+        context.prev_state_ids = dict(context.prev_state_ids)
+        context.prev_state_ids.update({
+            k: a.event_id for k, a in auth_events.iteritems()
+        })
+        context.state_group = self.store.get_next_state_group()
+
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
         """ Given a local and remote auth chain, find the differences. This
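The factored-out `_update_context_for_auth_events` applies one dict-update rule at both call sites: the event's own `(type, state_key)` is excluded from `current_state_ids` but included in `prev_state_ids`. A simplified pure-Python version of that rule, using plain event-id strings in place of event objects, looks like this (`update_state_maps` is an illustrative helper, not Synapse code):

```python
def update_state_maps(current_state_ids, prev_state_ids, auth_events, event_key):
    """Apply auth-event resolution to copies of the two state maps."""
    # every resolved auth event except the event's own key
    state_updates = {
        k: event_id for k, event_id in auth_events.items()
        if k != event_key
    }
    current = dict(current_state_ids)   # copy, don't mutate the caller's map
    current.update(state_updates)
    prev = dict(prev_state_ids)
    prev.update(auth_events)            # prev_state includes the event's own key
    return current, prev
```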


@@ -71,6 +71,7 @@ class GroupsLocalHandler(object):
     get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")

     add_room_to_group = _create_rerouter("add_room_to_group")
+    update_room_in_group = _create_rerouter("update_room_in_group")
     remove_room_from_group = _create_rerouter("remove_room_from_group")

     update_group_summary_room = _create_rerouter("update_group_summary_room")


@@ -17,7 +17,6 @@ import logging
 from twisted.internet import defer

-import synapse.types
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.types import UserID, get_domain_from_id
 from ._base import BaseHandler
@@ -140,7 +139,7 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_displayname
         )

-        yield self._update_join_states(requester)
+        yield self._update_join_states(requester, target_user)

     @defer.inlineCallbacks
     def get_avatar_url(self, target_user):
@@ -184,7 +183,7 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_avatar_url
         )

-        yield self._update_join_states(requester)
+        yield self._update_join_states(requester, target_user)

     @defer.inlineCallbacks
     def on_profile_query(self, args):
@@ -209,28 +208,24 @@ class ProfileHandler(BaseHandler):
         defer.returnValue(response)

     @defer.inlineCallbacks
-    def _update_join_states(self, requester):
-        user = requester.user
-        if not self.hs.is_mine(user):
+    def _update_join_states(self, requester, target_user):
+        if not self.hs.is_mine(target_user):
             return

         yield self.ratelimit(requester)

         room_ids = yield self.store.get_rooms_for_user(
-            user.to_string(),
+            target_user.to_string(),
         )

         for room_id in room_ids:
             handler = self.hs.get_handlers().room_member_handler
             try:
-                # Assume the user isn't a guest because we don't let guests set
-                # profile or avatar data.
-                # XXX why are we recreating `requester` here for each room?
-                # what was wrong with the `requester` we were passed?
-                requester = synapse.types.create_requester(user)
+                # Assume the target_user isn't a guest,
+                # because we don't let guests set profile or avatar data.
                 yield handler.update_membership(
                     requester,
-                    user,
+                    target_user,
                     room_id,
                     "join",  # We treat a profile update like a join.
                     ratelimit=False,  # Try to hide that these events aren't atomic.


@@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
         super(RegistrationHandler, self).__init__(hs)
         self.auth = hs.get_auth()
+        self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
         self.captcha_client = CaptchaServerHttpClient(hs)
@@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_localpart=user.localpart,
             )
         else:
-            yield self.store.user_delete_access_tokens(user_id=user_id)
+            yield self._auth_handler.delete_access_tokens_for_user(user_id)
             yield self.store.add_access_token_to_user(user_id=user_id, token=token)

         if displayname is not None:


@@ -20,6 +20,7 @@ from ._base import BaseHandler
 from synapse.api.constants import (
     EventTypes, JoinRules,
 )
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.util.async import concurrently_execute
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.util.caches.response_cache import ResponseCache
@@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler):
         if search_filter:
             # We explicitly don't bother caching searches or requests for
             # appservice specific lists.
+            logger.info("Bypassing cache as search request.")
             return self._get_public_room_list(
                 limit, since_token, search_filter, network_tuple=network_tuple,
             )
@@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler):
         key = (limit, since_token, network_tuple)
         result = self.response_cache.get(key)
         if not result:
+            logger.info("No cached result, calculating one.")
             result = self.response_cache.set(
                 key,
-                self._get_public_room_list(
+                preserve_fn(self._get_public_room_list)(
                     limit, since_token, network_tuple=network_tuple
                 )
             )
-        return result
+        else:
+            logger.info("Using cached deferred result.")
+        return make_deferred_yieldable(result)
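The `preserve_fn`/`make_deferred_yieldable` wrappers are Twisted log-context plumbing, but the caching pattern they protect, one in-flight computation shared by all concurrent requesters, can be sketched with asyncio in place of Twisted. This is an analogy under invented names, not Synapse code:

```python
import asyncio


class ResponseCache:
    """First caller computes the result; concurrent callers share the task."""

    def __init__(self):
        self._cache = {}

    def get(self, key):
        return self._cache.get(key)

    def set(self, key, coro):
        task = asyncio.ensure_future(coro)
        self._cache[key] = task
        # drop the entry once the result has been delivered
        task.add_done_callback(lambda _: self._cache.pop(key, None))
        return task


async def get_public_rooms(cache, key, calls):
    result = cache.get(key)
    if result is None:
        # cache miss: start the (pretend) expensive computation
        async def compute():
            calls.append(key)
            await asyncio.sleep(0)
            return "rooms for %s" % (key,)
        result = cache.set(key, compute())
    return await result


async def demo():
    cache = ResponseCache()
    calls = []
    results = await asyncio.gather(
        get_public_rooms(cache, "public", calls),
        get_public_rooms(cache, "public", calls),
    )
    return results, calls
```

Two concurrent callers with the same key await the same task, so the expensive list is computed once; the entry is evicted as soon as the task completes.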
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None, def _get_public_room_list(self, limit=None, since_token=None,


@@ -15,7 +15,7 @@
 from synapse.api.constants import Membership, EventTypes
 from synapse.util.async import concurrently_execute
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
 from synapse.util.metrics import Measure, measure_func
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.push.clientformat import format_push_rules_for_user
@@ -184,11 +184,11 @@ class SyncHandler(object):
         if not result:
             result = self.response_cache.set(
                 sync_config.request_key,
-                self._wait_for_sync_for_user(
+                preserve_fn(self._wait_for_sync_for_user)(
                     sync_config, since_token, timeout, full_state
                 )
             )
-        return result
+        return make_deferred_yieldable(result)

     @defer.inlineCallbacks
     def _wait_for_sync_for_user(self, sync_config, since_token, timeout,


@@ -152,7 +152,7 @@ class UserDirectoyHandler(object):
         for room_id in room_ids:
             logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
-            yield self._handle_intial_room(room_id)
+            yield self._handle_initial_room(room_id)
             num_processed_rooms += 1
             yield sleep(self.INITIAL_SLEEP_MS / 1000.)
@@ -166,7 +166,7 @@ class UserDirectoyHandler(object):
         yield self.store.update_user_directory_stream_pos(new_pos)

     @defer.inlineCallbacks
-    def _handle_intial_room(self, room_id):
+    def _handle_initial_room(self, room_id):
         """Called when we initially fill out user_directory one room at a time
         """
         is_in_room = yield self.store.is_host_joined(room_id, self.server_name)


@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import wrap_request_handler
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+
+class AdditionalResource(Resource):
+    """Resource wrapper for additional_resources
+
+    If the user has configured additional_resources, we need to wrap the
+    handler class with a Resource so that we can map it into the resource tree.
+
+    This class is also where we wrap the request handler with logging, metrics,
+    and exception handling.
+    """
+    def __init__(self, hs, handler):
+        """Initialise AdditionalResource
+
+        The ``handler`` should return a deferred which completes when it has
+        done handling the request. It should write a response with
+        ``request.write()``, and call ``request.finish()``.
+
+        Args:
+            hs (synapse.server.HomeServer): homeserver
+            handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
+                function to be called to handle the request.
+        """
+        Resource.__init__(self)
+        self._handler = handler
+
+        # these are required by the request_handler wrapper
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+
+    def render(self, request):
+        self._async_render(request)
+        return NOT_DONE_YET
+
+    @wrap_request_handler
+    def _async_render(self, request):
+        return self._handler(request)


@@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
 from synapse.api.errors import (
     CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util import logcontext
 import synapse.metrics
 from synapse.http.endpoint import SpiderEndpoint
@@ -114,43 +114,73 @@ class SimpleHttpClient(object):
             raise e

     @defer.inlineCallbacks
-    def post_urlencoded_get_json(self, uri, args={}):
+    def post_urlencoded_get_json(self, uri, args={}, headers=None):
+        """
+        Args:
+            uri (str):
+            args (dict[str, str|List[str]]): query params
+            headers (dict[str, List[str]]|None): If not None, a map from
+                header name to a list of values for that header
+
+        Returns:
+            Deferred[object]: parsed json
+        """
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)

         query_bytes = urllib.urlencode(encode_urlencode_args(args), True)

+        actual_headers = {
+            b"Content-Type": [b"application/x-www-form-urlencoded"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/x-www-form-urlencoded"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(query_bytes))
         )

-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))

         defer.returnValue(json.loads(body))

     @defer.inlineCallbacks
-    def post_json_get_json(self, uri, post_json):
+    def post_json_get_json(self, uri, post_json, headers=None):
+        """
+        Args:
+            uri (str):
+            post_json (object):
+            headers (dict[str, List[str]]|None): If not None, a map from
+                header name to a list of values for that header
+
+        Returns:
+            Deferred[object]: parsed json
+        """
         json_str = encode_canonical_json(post_json)

         logger.debug("HTTP POST %s -> %s", json_str, uri)

+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/json"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )

-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))

         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -160,7 +190,7 @@ class SimpleHttpClient(object):
             defer.returnValue(json.loads(body))

     @defer.inlineCallbacks
-    def get_json(self, uri, args={}):
+    def get_json(self, uri, args={}, headers=None):
         """ Gets some json from the given URI.

         Args:
@@ -169,6 +199,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+                header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -177,13 +209,13 @@ class SimpleHttpClient(object):
             error message.
         """
         try:
-            body = yield self.get_raw(uri, args)
+            body = yield self.get_raw(uri, args, headers=headers)
             defer.returnValue(json.loads(body))
         except CodeMessageException as e:
             raise self._exceptionFromFailedRequest(e.code, e.msg)

     @defer.inlineCallbacks
-    def put_json(self, uri, json_body, args={}):
+    def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.

         Args:
@@ -193,6 +225,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+                header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -205,17 +239,21 @@ class SimpleHttpClient(object):
         json_str = encode_canonical_json(json_body)

+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "PUT",
             uri.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-                "Content-Type": ["application/json"]
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )

-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))

         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@ -226,7 +264,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body) raise CodeMessageException(response.code, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_raw(self, uri, args={}): def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI. """ Gets raw text from the given URI.
Args: Args:
@ -235,6 +273,8 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as text. HTTP body as text.
@ -246,15 +286,19 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request( response = yield self.request(
"GET", "GET",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers(actual_headers),
b"User-Agent": [self.user_agent],
})
) )
body = yield preserve_context_over_fn(readBody, response) body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(body) defer.returnValue(body)
@ -274,27 +318,33 @@ class SimpleHttpClient(object):
# The two should be factored out. # The two should be factored out.
@defer.inlineCallbacks @defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None): def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL """GETs a file from a given URL
Args: Args:
url (str): The URL to GET url (str): The URL to GET
output_stream (file): File to write the response body to. output_stream (file): File to write the response body to.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns: Returns:
A (int,dict,string,int) tuple of the file length, dict of the response A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code. headers, absolute URI of the response and HTTP response code.
""" """
actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request( response = yield self.request(
"GET", "GET",
url.encode("ascii"), url.encode("ascii"),
headers=Headers({ headers=Headers(actual_headers),
b"User-Agent": [self.user_agent],
})
) )
headers = dict(response.headers.getAllRawHeaders()) resp_headers = dict(response.headers.getAllRawHeaders())
if 'Content-Length' in headers and headers['Content-Length'] > max_size: if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError( raise SynapseError(
502, 502,
@ -315,10 +365,9 @@ class SimpleHttpClient(object):
# straight back in again # straight back in again
try: try:
length = yield preserve_context_over_fn( length = yield make_deferred_yieldable(_readBodyToFile(
_readBodyToFile, response, output_stream, max_size,
response, output_stream, max_size ))
)
except Exception as e: except Exception as e:
logger.exception("Failed to download body") logger.exception("Failed to download body")
raise SynapseError( raise SynapseError(
@ -327,7 +376,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN, Codes.UNKNOWN,
) )
defer.returnValue((length, headers, response.request.absoluteURI, response.code)) defer.returnValue(
(length, resp_headers, response.request.absoluteURI, response.code),
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
) )
try: try:
body = yield preserve_context_over_fn(readBody, response) body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(body) defer.returnValue(body)
except PartialDownloadError as e: except PartialDownloadError as e:
# twisted dislikes google's response, no content length. # twisted dislikes google's response, no content length.
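The pattern repeated throughout this change — start from the client's default headers, then overlay any caller-supplied ones — can be sketched outside Twisted as a plain dict merge. `build_headers` is an illustrative stand-in, not the Synapse API; the real code builds `actual_headers` inline in each method:

```python
def build_headers(user_agent, headers=None):
    """Merge caller-supplied headers over the client defaults.

    Mirrors the `actual_headers` pattern in SimpleHttpClient: keys and
    values are bytes, and each value is a *list* of header values.
    """
    actual_headers = {
        b"User-Agent": [user_agent],
    }
    if headers:
        # caller-supplied entries replace defaults wholesale, per key
        actual_headers.update(headers)
    return actual_headers

# A caller can add e.g. an Authorization header, or override the default UA:
merged = build_headers(b"Synapse", {b"Authorization": [b"Bearer token"]})
```

Note that `dict.update` replaces the whole value list for a duplicated key, so passing `User-Agent` suppresses the default rather than appending to it.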


@ -167,7 +167,8 @@ def parse_json_value_from_request(request):
try: try:
content = simplejson.loads(content_bytes) content = simplejson.loads(content_bytes)
except simplejson.JSONDecodeError: except Exception as e:
logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content return content
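The switch from catching only `simplejson.JSONDecodeError` to a broad `except Exception` reflects that JSON decoders can raise other errors (e.g. `TypeError` on unexpected input types). A minimal, self-contained sketch of the same shape, using the stdlib `json` module and a stand-in error class rather than Synapse's real `SynapseError`:

```python
import json
import logging

logger = logging.getLogger(__name__)


class SynapseError(Exception):
    """Minimal stand-in for synapse.api.errors.SynapseError."""
    def __init__(self, code, msg):
        super().__init__(msg)
        self.code = code


def parse_json_value(content_bytes):
    """Parse request bytes as JSON, mapping *any* parse failure to a 400."""
    try:
        return json.loads(content_bytes)
    except Exception as e:
        # broad catch: decoders can raise more than JSONDecodeError
        logger.warning("Unable to parse JSON: %s", e)
        raise SynapseError(400, "Content not JSON.")
```

The log line is added so failed parses are diagnosable without leaking the request body into the error returned to the client.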


@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.types import UserID
class ModuleApi(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, auth_handler):
self.hs = hs
self._store = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
def get_user_by_req(self, req, allow_guest=False):
"""Check the access_token provided for a request
Args:
req (twisted.web.server.Request): Incoming HTTP request
allow_guest (bool): True if guest users should be allowed. If this
is False, and the access token is for a guest user, an
AuthError will be thrown
Returns:
twisted.internet.defer.Deferred[synapse.types.Requester]:
the requester for this request
Raises:
synapse.api.errors.AuthError: if no user by that token exists,
or the token is invalid.
"""
return self._auth.get_user_by_req(req, allow_guest)
def get_qualified_user_id(self, username):
"""Qualify a user id, if necessary
Takes a user id provided by the user and adds the @ and :domain to
qualify it, if necessary
Args:
username (str): provided user id
Returns:
str: qualified @user:id
"""
if username.startswith('@'):
return username
return UserID(username, self.hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
Args:
user_id (str): Complete @user:id
Returns:
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return self._auth_handler.check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)
def invalidate_access_token(self, access_token):
"""Invalidate an access token for a user
Args:
access_token(str): access token
Returns:
twisted.internet.defer.Deferred - resolves once the access token
has been removed.
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
return self._auth_handler.delete_access_token(access_token)
def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection
Args:
desc (str): description for the transaction, for metrics etc
func (func): function to be run. Passed a database cursor object
as well as *args and **kwargs
*args: positional args to be passed to func
**kwargs: named args to be passed to func
Returns:
Deferred[object]: result of func
"""
return self._store.runInteraction(desc, func, *args, **kwargs)
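The ID-qualification logic in `ModuleApi.get_qualified_user_id` is small enough to show standalone. This sketch inlines the formatting that `UserID(...).to_string()` performs, so it omits the validation the real `UserID` type does:

```python
def qualify_user_id(username, hostname):
    """Turn a bare localpart into a full @user:domain ID, as
    ModuleApi.get_qualified_user_id does. Already-qualified IDs
    (anything starting with '@') pass through unchanged.
    """
    if username.startswith('@'):
        return username
    return "@%s:%s" % (username, hostname)
```

Password providers use this so users can log in with either `alice` or `@alice:example.com`.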


@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore): class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs) super(BaseSlavedStore, self).__init__(db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker( self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id", db_conn, "cache_invalidation_stream", "stream_id",


@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs) super(DeactivateAccountRestServlet, self).__init__(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
# FIXME: Theoretically there is a race here wherein user resets password yield self._auth_handler.deactivate_account(target_user_id)
# using threepid.
yield self.store.user_delete_access_tokens(target_user_id)
yield self.store.user_delete_threepids(target_user_id)
yield self.store.user_set_password_hash(target_user_id, None)
defer.returnValue((200, {})) defer.returnValue((200, {}))


@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$") PATTERNS = client_path_patterns("/login$")
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token" TOKEN_TYPE = "m.login.token"
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret self.jwt_secret = hs.config.jwt_secret
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the # fall back to the fallback API if they don't understand one of the
# login flow types returned. # login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE}) flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE}) flows.extend((
{"type": t} for t in self.auth_handler.get_supported_login_types()
))
return (200, {"flows": flows}) return (200, {"flows": flows})
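The hard-coded `m.login.password` flow is replaced by asking the auth handler which login types its password providers support. A hypothetical standalone helper with the same shape as the new `on_GET` logic (the real code lives inline in `LoginRestServlet.on_GET`, interleaved with the SAML2/CAS branches):

```python
def build_login_flows(supported_types, jwt_enabled=False):
    """Assemble the /login flows list from whatever login types the
    auth handler (and its password providers) report, instead of
    hard-coding m.login.password.
    """
    flows = []
    if jwt_enabled:
        flows.append({"type": "m.login.jwt"})
    # m.login.token must come after other flows: clients are expected
    # to fall back to fallback auth if they don't understand a type
    flows.append({"type": "m.login.token"})
    flows.extend({"type": t} for t in supported_types)
    return flows
```

With this shape, a password provider that advertises a custom login type gets it listed automatically, with no servlet changes.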
@ -133,13 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
login_submission = parse_json_object_from_request(request) login_submission = parse_json_object_from_request(request)
try: try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if self.saml2_enabled and (login_submission["type"] ==
if not self.password_enabled:
raise SynapseError(400, "Password login has been disabled.")
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE): LoginRestServlet.SAML2_TYPE):
relay_state = "" relay_state = ""
if "relay_state" in login_submission: if "relay_state" in login_submission:
@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission) result = yield self.do_token_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") result = yield self._do_other_login(login_submission)
defer.returnValue(result)
except KeyError: except KeyError:
raise SynapseError(400, "Missing JSON keys.") raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks @defer.inlineCallbacks
def do_password_login(self, login_submission): def _do_other_login(self, login_submission):
if "password" not in login_submission: """Handle non-token/saml/jwt logins
raise SynapseError(400, "Missing parameter: password")
Args:
login_submission:
Returns:
(int, object): HTTP code/response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
# field)
logger.info(
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
login_submission.get('identifier'),
login_submission.get('medium'),
login_submission.get('address'),
login_submission.get('user'),
)
login_submission_legacy_convert(login_submission) login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission: if "identifier" not in login_submission:
@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier: if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key") raise SynapseError(400, "User identifier is missing 'user' key")
user_id = identifier["user"]
if not user_id.startswith('@'):
user_id = UserID(
user_id, self.hs.hostname
).to_string()
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_id = yield auth_handler.validate_password_login( canonical_user_id, callback = yield auth_handler.validate_login(
user_id=user_id, identifier["user"],
password=login_submission["password"], login_submission,
)
device_id = yield self._register_device(
canonical_user_id, login_submission,
) )
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id( access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id, canonical_user_id, device_id,
login_submission.get("initial_device_display_name"),
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": canonical_user_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
if callback is not None:
yield callback(result)
defer.returnValue((200, result)) defer.returnValue((200, result))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id( access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id, user_id, device_id,
login_submission.get("initial_device_display_name"),
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet):
) )
access_token = yield auth_handler.get_access_token_for_user_id( access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id, registered_user_id, device_id,
login_submission.get("initial_device_display_name"),
) )
result = { result = {


@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__(hs)
self.store = hs.get_datastore() self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
access_token = get_access_token_from_request(request) access_token = get_access_token_from_request(request)
yield self.store.delete_access_token(access_token) yield self._auth_handler.delete_access_token(access_token)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.store.user_delete_access_tokens(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))


@ -359,7 +359,7 @@ class RegisterRestServlet(ClientV1RestServlet):
if compare_digest(want_mac, got_mac): if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
user_id, token = yield handler.register( user_id, token = yield handler.register(
localpart=user, localpart=user.lower(),
password=password, password=password,
admin=bool(admin), admin=bool(admin),
) )
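The `localpart=user.lower()` change makes admin-registered user IDs case-insensitive at the point of entry. A sketch of the effect, with a hypothetical `register_localpart` helper and an in-memory `existing` set standing in for the user store:

```python
def register_localpart(existing, localpart):
    """Downcase a localpart on registration (as `user.lower()` above
    does), so 'Alice' and 'alice' map to the same user and can't
    both be registered. `existing` is the set of taken localparts.
    """
    canonical = localpart.lower()
    if canonical in existing:
        raise ValueError("User ID already taken: %s" % canonical)
    existing.add(canonical)
    return canonical
```

This complements the case-corrected lookup that `check_user_exists` performs on the login side.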


@ -13,22 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request RestServlet, assert_params_in_request,
parse_json_object_from_request,
) )
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -163,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__() super(DeactivateAccountRestServlet, self).__init__()
@ -172,6 +170,20 @@ class DeactivateAccountRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# if the caller provides an access token, it ought to be valid.
requester = None
if has_access_token(request):
requester = yield self.auth.get_user_by_req(
request,
) # type: synapse.types.Requester
# allow ASes to deactivate their own users
if requester and requester.app_service:
yield self.auth_handler.deactivate_account(
requester.user.to_string()
)
defer.returnValue((200, {}))
authed, result, params, _ = yield self.auth_handler.check_auth([ authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request)) ], body, self.hs.get_ip_from_request(request))
@ -179,25 +191,22 @@ class DeactivateAccountRestServlet(RestServlet):
if not authed: if not authed:
defer.returnValue((401, result)) defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result: if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in # if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request) if requester is None:
user_id = requester.user.to_string() raise SynapseError(
if user_id != result[LoginType.PASSWORD]: 400,
"Deactivate account requires an access_token",
errcode=Codes.MISSING_TOKEN
)
if requester.user.to_string() != user_id:
raise LoginError(400, "", Codes.UNKNOWN) raise LoginError(400, "", Codes.UNKNOWN)
else: else:
logger.error("Auth succeeded but no known type!", result.keys()) logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN) raise SynapseError(500, "", Codes.UNKNOWN)
# FIXME: Theoretically there is a race here wherein user resets password yield self.auth_handler.deactivate_account(user_id)
# using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -373,6 +382,20 @@ class ThreepidDeleteRestServlet(RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class WhoamiRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/whoami$")
def __init__(self, hs):
super(WhoamiRestServlet, self).__init__()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
defer.returnValue((200, {'user_id': requester.user.to_string()}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server) EmailPasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
@ -382,3 +405,4 @@ def register_servlets(hs, http_server):
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)


@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet): class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
""" """
@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth. key which lists the device_ids to delete. Requires user interactive auth.
""" """
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False) PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__() super(DeleteDevicesRestServlet, self).__init__()
@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
class DeviceRestServlet(servlet.RestServlet): class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
releases=[], v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
""" """
@ -118,6 +117,8 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, device_id): def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = servlet.parse_json_object_from_request(request)
@ -136,11 +137,12 @@ class DeviceRestServlet(servlet.RestServlet):
if not authed: if not authed:
defer.returnValue((401, result)) defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request) # check that the UI auth matched the access token
yield self.device_handler.delete_device( user_id = result[constants.LoginType.PASSWORD]
requester.user.to_string(), if user_id != requester.user.to_string():
device_id, raise errors.AuthError(403, "Invalid auth")
)
yield self.device_handler.delete_device(user_id, device_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks


@ -39,20 +39,23 @@ class GroupServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile(group_id, user_id) group_description = yield self.groups_handler.get_group_profile(
group_id,
requester_user_id,
)
defer.returnValue((200, group_description)) defer.returnValue((200, group_description))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, group_id): def on_POST(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
yield self.groups_handler.update_group_profile( yield self.groups_handler.update_group_profile(
group_id, user_id, content, group_id, requester_user_id, content,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -72,9 +75,12 @@ class GroupSummaryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id) get_group_summary = yield self.groups_handler.get_group_summary(
group_id,
requester_user_id,
)
defer.returnValue((200, get_group_summary)) defer.returnValue((200, get_group_summary))
@ -101,11 +107,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id, room_id): def on_PUT(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_room( resp = yield self.groups_handler.update_group_summary_room(
group_id, user_id, group_id, requester_user_id,
room_id=room_id, room_id=room_id,
category_id=category_id, category_id=category_id,
content=content, content=content,
@ -116,10 +122,10 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id, room_id): def on_DELETE(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_room( resp = yield self.groups_handler.delete_group_summary_room(
group_id, user_id, group_id, requester_user_id,
room_id=room_id, room_id=room_id,
category_id=category_id, category_id=category_id,
) )
@@ -143,10 +149,10 @@ class GroupCategoryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, group_id, category_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_category(
-            group_id, user_id,
+            group_id, requester_user_id,
             category_id=category_id,
         )
@@ -155,11 +161,11 @@ class GroupCategoryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, group_id, category_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
         resp = yield self.groups_handler.update_group_category(
-            group_id, user_id,
+            group_id, requester_user_id,
             category_id=category_id,
             content=content,
         )
@@ -169,10 +175,10 @@ class GroupCategoryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_DELETE(self, request, group_id, category_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         resp = yield self.groups_handler.delete_group_category(
-            group_id, user_id,
+            group_id, requester_user_id,
             category_id=category_id,
         )
@@ -195,10 +201,10 @@ class GroupCategoriesServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_categories(
-            group_id, user_id,
+            group_id, requester_user_id,
         )
         defer.returnValue((200, category))
@@ -220,10 +226,10 @@ class GroupRoleServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, group_id, role_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_role(
-            group_id, user_id,
+            group_id, requester_user_id,
             role_id=role_id,
         )
@@ -232,11 +238,11 @@ class GroupRoleServlet(RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, group_id, role_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
         resp = yield self.groups_handler.update_group_role(
-            group_id, user_id,
+            group_id, requester_user_id,
             role_id=role_id,
             content=content,
         )
@@ -246,10 +252,10 @@ class GroupRoleServlet(RestServlet):
     @defer.inlineCallbacks
     def on_DELETE(self, request, group_id, role_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         resp = yield self.groups_handler.delete_group_role(
-            group_id, user_id,
+            group_id, requester_user_id,
             role_id=role_id,
         )
@@ -272,10 +278,10 @@ class GroupRolesServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_roles(
-            group_id, user_id,
+            group_id, requester_user_id,
         )
         defer.returnValue((200, category))
@@ -343,9 +349,9 @@ class GroupRoomServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
-        result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
+        result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
 
         defer.returnValue((200, result))
@@ -364,9 +370,9 @@ class GroupUsersServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
-        result = yield self.groups_handler.get_users_in_group(group_id, user_id)
+        result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
 
         defer.returnValue((200, result))
@@ -385,9 +391,12 @@ class GroupInvitedUsersServlet(RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
-        result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
+        result = yield self.groups_handler.get_invited_users_in_group(
+            group_id,
+            requester_user_id,
+        )
 
         defer.returnValue((200, result))
@@ -407,14 +416,18 @@ class GroupCreateServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         # TODO: Create group on remote server
         content = parse_json_object_from_request(request)
         localpart = content.pop("localpart")
         group_id = GroupID(localpart, self.server_name).to_string()
 
-        result = yield self.groups_handler.create_group(group_id, user_id, content)
+        result = yield self.groups_handler.create_group(
+            group_id,
+            requester_user_id,
+            content,
+        )
 
         defer.returnValue((200, result))
@@ -435,11 +448,11 @@ class GroupAdminRoomsServlet(RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, group_id, room_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
         result = yield self.groups_handler.add_room_to_group(
-            group_id, user_id, room_id, content,
+            group_id, requester_user_id, room_id, content,
         )
 
         defer.returnValue((200, result))
@@ -447,10 +460,37 @@ class GroupAdminRoomsServlet(RestServlet):
     @defer.inlineCallbacks
     def on_DELETE(self, request, group_id, room_id):
         requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         result = yield self.groups_handler.remove_room_from_group(
-            group_id, user_id, room_id,
+            group_id, requester_user_id, room_id,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupAdminRoomsConfigServlet(RestServlet):
+    """Update the config of a room in a group
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
+        "/config/(?P<config_key>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupAdminRoomsConfigServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, room_id, config_key):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        result = yield self.groups_handler.update_room_in_group(
+            group_id, requester_user_id, room_id, config_key, content,
         )
 
         defer.returnValue((200, result))
@@ -685,9 +725,9 @@ class GroupsForUserServlet(RestServlet):
    @defer.inlineCallbacks
    def on_GET(self, request):
        requester = yield self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
-        result = yield self.groups_handler.get_joined_groups(user_id)
+        result = yield self.groups_handler.get_joined_groups(requester_user_id)
 
        defer.returnValue((200, result))
@@ -700,6 +740,7 @@ def register_servlets(hs, http_server):
     GroupRoomServlet(hs).register(http_server)
     GroupCreateServlet(hs).register(http_server)
     GroupAdminRoomsServlet(hs).register(http_server)
+    GroupAdminRoomsConfigServlet(hs).register(http_server)
     GroupAdminUsersInviteServlet(hs).register(http_server)
     GroupAdminUsersKickServlet(hs).register(http_server)
     GroupSelfLeaveServlet(hs).register(http_server)

View file

@@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
         },
     }
     """
-    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
 
     def __init__(self, hs):
         """
@@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
             } } } } } }
     """
-    PATTERNS = client_v2_patterns(
-        "/keys/query$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/query$")
 
     def __init__(self, hs):
         """
@@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
         200 OK
         { "changed": ["@foo:example.com"] }
     """
-    PATTERNS = client_v2_patterns(
-        "/keys/changes$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/changes$")
 
     def __init__(self, hs):
         """
@@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
         } } } }
     """
-    PATTERNS = client_v2_patterns(
-        "/keys/claim$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/claim$")
 
     def __init__(self, hs):
         super(OneTimeKeyServlet, self).__init__()

View file

@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 class NotificationsServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/notifications$", releases=())
+    PATTERNS = client_v2_patterns("/notifications$")
 
     def __init__(self, hs):
         super(NotificationsServlet, self).__init__()

View file

@@ -224,6 +224,12 @@ class RegisterRestServlet(RestServlet):
             # 'user' key not 'username'). Since this is a new addition, we'll
             # fallback to 'username' if they gave one.
             desired_username = body.get("user", desired_username)
 
+            # XXX we should check that desired_username is valid. Currently
+            # we give appservices carte blanche for any insanity in mxids,
+            # because the IRC bridges rely on being able to register stupid
+            # IDs.
+
             access_token = get_access_token_from_request(request)
             if isinstance(desired_username, basestring):
@@ -233,6 +239,15 @@ class RegisterRestServlet(RestServlet):
                 defer.returnValue((200, result))  # we throw for non 200 responses
             return
 
+        # for either shared secret or regular registration, downcase the
+        # provided username before attempting to register it. This should mean
+        # that people who try to register with upper-case in their usernames
+        # don't get a nasty surprise. (Note that we treat username
+        # case-insensitively in login, so they are free to carry on imagining
+        # that their username is CrAzYh4cKeR if that keeps them happy)
+        if desired_username is not None:
+            desired_username = desired_username.lower()
+
         # == Shared Secret Registration == (e.g. create new user scripts)
         if 'mac' in body:
             # FIXME: Should we really be determining if this is shared secret
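The downcasing rule above, paired with case-insensitive login, can be illustrated standalone (the helper names here are illustrative, not Synapse APIs):

```python
def normalise_localpart(desired_username):
    # registration stores the lower-cased form, as in the change above
    return desired_username.lower() if desired_username is not None else None

def login_matches(stored_localpart, attempted):
    # login treats usernames case-insensitively, so "CrAzYh4cKeR"
    # still finds the account registered as "crazyh4cker"
    return stored_localpart == attempted.lower()
```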
@@ -336,6 +351,9 @@ class RegisterRestServlet(RestServlet):
         new_password = params.get("password", None)
         guest_access_token = params.get("guest_access_token", None)
 
+        if desired_username is not None:
+            desired_username = desired_username.lower()
+
         (registered_user_id, _) = yield self.registration_handler.register(
             localpart=desired_username,
             password=new_password,
@@ -417,13 +435,22 @@ class RegisterRestServlet(RestServlet):
     def _do_shared_secret_registration(self, username, password, body):
         if not self.hs.config.registration_shared_secret:
             raise SynapseError(400, "Shared secret registration is not enabled")
+        if not username:
+            raise SynapseError(
+                400, "username must be specified", errcode=Codes.BAD_JSON,
+            )
 
-        user = username.encode("utf-8")
+        # use the username from the original request rather than the
+        # downcased one in `username` for the mac calculation
+        user = body["username"].encode("utf-8")
 
         # str() because otherwise hmac complains that 'unicode' does not
         # have the buffer interface
         got_mac = str(body["mac"])
 
+        # FIXME this is different to the /v1/register endpoint, which
+        # includes the password and admin flag in the hashed text. Why are
+        # these different?
+
         want_mac = hmac.new(
             key=self.hs.config.registration_shared_secret,
             msg=user,
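The `want_mac` above is a standard keyed HMAC over the original-case username bytes, which is why downcasing must not touch the value used for the mac. A standalone sketch; the SHA-1 digest choice is an assumption here, not something visible in the hunk:

```python
import hashlib
import hmac

def calculate_registration_mac(shared_secret, username):
    # keyed HMAC over the raw username bytes, keyed with the server's
    # registration_shared_secret (digestmod is an assumption in this sketch)
    mac = hmac.new(
        key=shared_secret.encode("utf-8"),
        msg=username.encode("utf-8"),
        digestmod=hashlib.sha1,
    )
    return mac.hexdigest()
```

Note that `calculate_registration_mac("secret", "Alice")` differs from `calculate_registration_mac("secret", "alice")`: case changes the mac, hence the comment in the diff about using the original request's username.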
@@ -557,25 +584,28 @@ class RegisterRestServlet(RestServlet):
         Args:
             (str) user_id: full canonical @user:id
             (object) params: registration parameters, from which we pull
-                device_id and initial_device_name
+                device_id, initial_device_name and inhibit_login
         Returns:
             defer.Deferred: (object) dictionary for response from /register
         """
-        device_id = yield self._register_device(user_id, params)
-
-        access_token = (
-            yield self.auth_handler.get_access_token_for_user_id(
-                user_id, device_id=device_id,
-                initial_display_name=params.get("initial_device_display_name")
-            )
-        )
-
-        defer.returnValue({
-            "user_id": user_id,
-            "access_token": access_token,
-            "home_server": self.hs.hostname,
-            "device_id": device_id,
-        })
+        result = {
+            "user_id": user_id,
+            "home_server": self.hs.hostname,
+        }
+        if not params.get("inhibit_login", False):
+            device_id = yield self._register_device(user_id, params)
+
+            access_token = (
+                yield self.auth_handler.get_access_token_for_user_id(
+                    user_id, device_id=device_id,
+                )
+            )
+
+            result.update({
+                "access_token": access_token,
+                "device_id": device_id,
+            })
+        defer.returnValue(result)
 
     def _register_device(self, user_id, params):
         """Register a device for a user.
View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet): class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
releases=[], v2_alpha=False v2_alpha=False
) )
def __init__(self, hs): def __init__(self, hs):

View file

@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class ThirdPartyProtocolsServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/protocols")
 
     def __init__(self, hs):
         super(ThirdPartyProtocolsServlet, self).__init__()
@@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
 
 class ThirdPartyProtocolServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
 
     def __init__(self, hs):
         super(ThirdPartyProtocolServlet, self).__init__()
@@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
 
 class ThirdPartyUserServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
         super(ThirdPartyUserServlet, self).__init__()
@@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
 
 class ThirdPartyLocationServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
         super(ThirdPartyLocationServlet, self).__init__()

View file

@@ -20,6 +20,7 @@ from twisted.web.resource import Resource
 from synapse.api.errors import (
     SynapseError, Codes,
 )
+from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
 from synapse.util.stringutils import random_string
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.http.client import SpiderHttpClient
@@ -63,16 +64,15 @@ class PreviewUrlResource(Resource):
         self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
 
-        # simple memory cache mapping urls to OG metadata
-        self.cache = ExpiringCache(
+        # memory cache mapping urls to an ObservableDeferred returning
+        # JSON-encoded OG metadata
+        self._cache = ExpiringCache(
             cache_name="url_previews",
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=60 * 60 * 1000,
         )
-        self.cache.start()
-
-        self.downloads = {}
+        self._cache.start()
 
         self._cleaner_loop = self.clock.looping_call(
             self._expire_url_cache_data, 10 * 1000
@@ -94,6 +94,7 @@ class PreviewUrlResource(Resource):
         else:
             ts = self.clock.time_msec()
 
+        # XXX: we could move this into _do_preview if we wanted.
         url_tuple = urlparse.urlsplit(url)
         for entry in self.url_preview_url_blacklist:
             match = True
@@ -126,14 +127,42 @@ class PreviewUrlResource(Resource):
                 Codes.UNKNOWN
             )
 
-        # first check the memory cache - good to handle all the clients on this
-        # HS thundering away to preview the same URL at the same time.
-        og = self.cache.get(url)
-        if og:
-            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
-            return
+        # the in-memory cache:
+        # * ensures that only one request is active at a time
+        # * takes load off the DB for the thundering herds
+        # * also caches any failures (unlike the DB) so we don't keep
+        #   requesting the same endpoint
 
-        # then check the URL cache in the DB (which will also provide us with
+        observable = self._cache.get(url)
+
+        if not observable:
+            download = preserve_fn(self._do_preview)(
+                url, requester.user, ts,
+            )
+            observable = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
+            self._cache[url] = observable
+        else:
+            logger.info("Returning cached response")
+
+        og = yield make_deferred_yieldable(observable.observe())
+        respond_with_json_bytes(request, 200, og, send_cors=True)
+
+    @defer.inlineCallbacks
+    def _do_preview(self, url, user, ts):
+        """Check the db, and download the URL and build a preview
+
+        Args:
+            url (str):
+            user (str):
+            ts (int):
+
+        Returns:
+            Deferred[str]: json-encoded og data
+        """
+        # check the URL cache in the DB (which will also provide us with
         # historical previews, if we have any)
         cache_result = yield self.store.get_url_cache(url, ts)
         if (
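The pattern in the hunk above — cache the in-flight request itself rather than its result, so a thundering herd shares one download and failures are remembered too — translates directly to asyncio. A minimal sketch (eviction, which `ExpiringCache` provides in the real code, is omitted):

```python
import asyncio

class PreviewCache:
    """Cache the pending task, not just its result: concurrent callers
    for the same URL await one shared fetch, and a cached failed task
    re-raises its exception for every later caller."""

    def __init__(self):
        self._cache = {}  # url -> asyncio.Task (no expiry in this sketch)

    async def get(self, url, fetch):
        task = self._cache.get(url)
        if task is None:
            task = asyncio.ensure_future(fetch(url))
            self._cache[url] = task
        # every caller awaits the same shared task
        return await task
```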
@@ -141,32 +170,10 @@ class PreviewUrlResource(Resource):
             cache_result["expires_ts"] > ts and
             cache_result["response_code"] / 100 == 2
         ):
-            respond_with_json_bytes(
-                request, 200, cache_result["og"].encode('utf-8'),
-                send_cors=True
-            )
+            defer.returnValue(cache_result["og"])
             return
 
-        # Ensure only one download for a given URL is active at a time
-        download = self.downloads.get(url)
-        if download is None:
-            download = self._download_url(url, requester.user)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
-            )
-            self.downloads[url] = download
-
-            @download.addBoth
-            def callback(media_info):
-                del self.downloads[url]
-                return media_info
-
-        media_info = yield download.observe()
-
-        # FIXME: we should probably update our cache now anyway, so that
-        # even if the OG calculation raises, we don't keep hammering on the
-        # remote server. For now, leave it uncached to aid debugging OG
-        # calculation problems
+        media_info = yield self._download_url(url, user)
 
         logger.debug("got media_info of '%s'" % media_info)
@@ -212,7 +219,7 @@ class PreviewUrlResource(Resource):
         # just rely on the caching on the master request to speed things up.
         if 'og:image' in og and og['og:image']:
             image_info = yield self._download_url(
-                _rebase_url(og['og:image'], media_info['uri']), requester.user
+                _rebase_url(og['og:image'], media_info['uri']), user
             )
 
             if _is_media(image_info['media_type']):
@@ -239,8 +246,7 @@ class PreviewUrlResource(Resource):
 
         logger.debug("Calculated OG for %s as %s" % (url, og))
 
-        # store OG in ephemeral in-memory cache
-        self.cache[url] = og
+        jsonog = json.dumps(og)
 
         # store OG in history-aware DB cache
         yield self.store.store_url_cache(
@@ -248,12 +254,12 @@ class PreviewUrlResource(Resource):
             media_info["response_code"],
             media_info["etag"],
             media_info["expires"] + media_info["created_ts"],
-            json.dumps(og),
+            jsonog,
             media_info["filesystem_id"],
             media_info["created_ts"],
         )
 
-        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+        defer.returnValue(jsonog)
 
     @defer.inlineCallbacks
     def _download_url(self, url, user):
@@ -520,7 +526,14 @@ def _calc_og(tree, media_uri):
     from lxml import etree
 
     TAGS_TO_REMOVE = (
-        "header", "nav", "aside", "footer", "script", "style", etree.Comment
+        "header",
+        "nav",
+        "aside",
+        "footer",
+        "script",
+        "noscript",
+        "style",
+        etree.Comment
     )
 
     # Split all the text nodes into paragraphs (by splitting on new
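The effect of adding `noscript` to `TAGS_TO_REMOVE` can be demonstrated without lxml; this stdlib-only parser is a stand-in for the real `_calc_og` logic, not the code from the diff:

```python
from html.parser import HTMLParser

# tags whose content shouldn't feed a preview description, mirroring
# TAGS_TO_REMOVE above (now including <noscript>)
SKIP = {"header", "nav", "aside", "footer", "script", "noscript", "style"}

class TextExtractor(HTMLParser):
    def __init__(self):
        super().__init__()
        self._depth = 0  # >0 while inside a skipped tag
        self.chunks = []

    def handle_starttag(self, tag, attrs):
        if tag in SKIP:
            self._depth += 1

    def handle_endtag(self, tag):
        if tag in SKIP and self._depth:
            self._depth -= 1

    def handle_data(self, data):
        if self._depth == 0 and data.strip():
            self.chunks.append(data.strip())

def extract_text(html):
    parser = TextExtractor()
    parser.feed(html)
    return " ".join(parser.chunks)
```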

View file

@@ -268,7 +268,7 @@ class DataStore(RoomMemberStore, RoomStore,
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
-        super(DataStore, self).__init__(hs)
+        super(DataStore, self).__init__(db_conn, hs)
 
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
active_on_startup = self._presence_on_startup active_on_startup = self._presence_on_startup

View file

@@ -162,7 +162,7 @@ class PerformanceCounters(object):
 class SQLBaseStore(object):
     _TXN_ID = 0
 
-    def __init__(self, hs):
+    def __init__(self, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
         self._db_pool = hs.get_db_pool()

View file

@@ -63,7 +63,7 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
-    @cachedInlineCallbacks(num_args=2)
+    @cachedInlineCallbacks(num_args=2, max_entries=5000)
     def get_global_account_data_by_type_for_user(self, data_type, user_id):
         """
         Returns:

View file

@@ -48,8 +48,8 @@ def _make_exclusive_regex(services_cache):
 
 class ApplicationServiceStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(ApplicationServiceStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(ApplicationServiceStore, self).__init__(db_conn, hs)
         self.hostname = hs.hostname
         self.services_cache = load_appservices(
             hs.hostname,
@@ -173,8 +173,8 @@ class ApplicationServiceStore(SQLBaseStore):
 
 class ApplicationServiceTransactionStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(ApplicationServiceTransactionStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
 
     @defer.inlineCallbacks
     def get_appservices_by_state(self, state):

View file

@@ -80,8 +80,8 @@ class BackgroundUpdateStore(SQLBaseStore):
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, hs):
-        super(BackgroundUpdateStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(BackgroundUpdateStore, self).__init__(db_conn, hs)
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}

View file

@@ -32,14 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
 
 class ClientIpStore(background_updates.BackgroundUpdateStore):
-    def __init__(self, hs):
+    def __init__(self, db_conn, hs):
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
             max_entries=50000 * CACHE_SIZE_FACTOR,
         )
 
-        super(ClientIpStore, self).__init__(hs)
+        super(ClientIpStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             "user_ips_device_index",

View file

@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
 class DeviceInboxStore(BackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, hs):
-        super(DeviceInboxStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             "device_inbox_stream_index",

View file

@@ -26,8 +26,8 @@ logger = logging.getLogger(__name__)
 
 class DeviceStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(DeviceStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(DeviceStore, self).__init__(db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.

View file

@@ -39,8 +39,8 @@ class EventFederationStore(SQLBaseStore):
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, hs):
-        super(EventFederationStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(EventFederationStore, self).__init__(db_conn, hs)
 
         self.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY,

View file

@@ -65,8 +65,8 @@ def _deserialize_action(actions, is_highlight):
 
 class EventPushActionsStore(SQLBaseStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, hs):
-        super(EventPushActionsStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,

View file

@@ -197,8 +197,8 @@ class EventsStore(SQLBaseStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
-    def __init__(self, hs):
-        super(EventsStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(EventsStore, self).__init__(db_conn, hs)
         self._clock = hs.get_clock()
         self.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts

View file

@@ -35,7 +35,9 @@ class GroupServerStore(SQLBaseStore):
             keyvalues={
                 "group_id": group_id,
             },
-            retcols=("name", "short_description", "long_description", "avatar_url",),
+            retcols=(
+                "name", "short_description", "long_description", "avatar_url", "is_public"
+            ),
             allow_none=True,
             desc="is_user_in_group",
         )
@@ -52,7 +54,7 @@ class GroupServerStore(SQLBaseStore):
         return self._simple_select_list(
             table="group_users",
             keyvalues=keyvalues,
-            retcols=("user_id", "is_public",),
+            retcols=("user_id", "is_public", "is_admin",),
             desc="get_users_in_group",
         )
@@ -855,6 +857,19 @@ class GroupServerStore(SQLBaseStore):
             desc="add_room_to_group",
         )
 
+    def update_room_in_group_visibility(self, group_id, room_id, is_public):
+        return self._simple_update(
+            table="group_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "room_id": room_id,
+            },
+            updatevalues={
+                "is_public": is_public,
+            },
+            desc="update_room_in_group_visibility",
+        )
+
     def remove_room_from_group(self, group_id, room_id):
         def _remove_room_from_group_txn(txn):
             self._simple_delete_txn(
@ -1026,6 +1041,7 @@ class GroupServerStore(SQLBaseStore):
"avatar_url": avatar_url, "avatar_url": avatar_url,
"short_description": short_description, "short_description": short_description,
"long_description": long_description, "long_description": long_description,
"is_public": True,
}, },
desc="create_group", desc="create_group",
) )
@ -1086,6 +1102,24 @@ class GroupServerStore(SQLBaseStore):
desc="update_remote_attestion", desc="update_remote_attestion",
) )
def remove_attestation_renewal(self, group_id, user_id):
"""Remove an attestation that we thought we should renew, but actually
shouldn't. Ideally this would never get called as we would never
incorrectly try and do attestations for local users on local groups.
Args:
group_id (str)
user_id (str)
"""
return self._simple_delete(
table="group_attestations_renewals",
keyvalues={
"group_id": group_id,
"user_id": user_id,
},
desc="remove_attestation_renewal",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_remote_attestation(self, group_id, user_id): def get_remote_attestation(self, group_id, user_id):
"""Get the attestation that proves the remote agrees that the user is """Get the attestation that proves the remote agrees that the user is


@@ -254,6 +254,9 @@ class MediaRepositoryStore(SQLBaseStore):
         return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)

     def delete_url_cache(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
         sql = (
             "DELETE FROM local_media_repository_url_cache"
             " WHERE media_id = ?"
@@ -281,6 +284,9 @@ class MediaRepositoryStore(SQLBaseStore):
         )

     def delete_url_cache_media(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
         def _delete_url_cache_media_txn(txn):
             sql = (
                 "DELETE FROM local_media_repository"


@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 45
+SCHEMA_VERSION = 46

 dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
     If `config` is None then prepare_database will assert that no upgrade is
     necessary, *or* will create a fresh database if the database is empty.
+
+    Args:
+        db_conn:
+        database_engine:
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
     """
     try:
         cur = db_conn.cursor()
@@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
         else:
             _setup_new_database(cur, database_engine)

+        # check if any of our configured dynamic modules want a database
+        if config is not None:
+            _apply_module_schemas(cur, database_engine, config)
+
         cur.close()
         db_conn.commit()
     except Exception:
@@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
         )


+def _apply_module_schemas(txn, database_engine, config):
+    """Apply the module schemas for the dynamic modules, if any
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        config (synapse.config.homeserver.HomeServerConfig):
+            application config
+    """
+    for (mod, _config) in config.password_providers:
+        if not hasattr(mod, 'get_db_schema_files'):
+            continue
+        modname = ".".join((mod.__module__, mod.__name__))
+        _apply_module_schema_files(
+            txn, database_engine, modname, mod.get_db_schema_files(),
+        )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+    """Apply the module schemas for a single module
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        modname (str): fully qualified name of the module
+        names_and_streams (Iterable[(str, file)]): the names and streams of
+            schemas to be applied
+    """
+    cur.execute(
+        database_engine.convert_param_style(
+            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+        ),
+        (modname,)
+    )
+    applied_deltas = set(d for d, in cur)
+    for (name, stream) in names_and_streams:
+        if name in applied_deltas:
+            continue
+
+        root_name, ext = os.path.splitext(name)
+        if ext != '.sql':
+            raise PrepareDatabaseException(
+                "only .sql files are currently supported for module schemas",
+            )
+
+        logger.info("applying schema %s for %s", name, modname)
+        for statement in get_statements(stream):
+            cur.execute(statement)
+
+        # Mark as done.
+        cur.execute(
+            database_engine.convert_param_style(
+                "INSERT INTO applied_module_schemas (module_name, file)"
+                " VALUES (?,?)",
+            ),
+            (modname, name)
+        )
+
+
 def get_statements(f):
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment


@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)

 class ReceiptsStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(ReceiptsStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(ReceiptsStore, self).__init__(db_conn, hs)

         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()


@@ -24,8 +24,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks

 class RegistrationStore(background_updates.BackgroundUpdateStore):
-    def __init__(self, hs):
-        super(RegistrationStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(RegistrationStore, self).__init__(db_conn, hs)

         self.clock = hs.get_clock()
@@ -36,12 +36,15 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             columns=["user_id", "device_id"],
         )

-        self.register_background_index_update(
-            "refresh_tokens_device_index",
-            index_name="refresh_tokens_device_id",
-            table="refresh_tokens",
-            columns=["user_id", "device_id"],
-        )
+        # we no longer use refresh tokens, but it's possible that some people
+        # might have a background update queued to build this index. Just
+        # clear the background update.
+        @defer.inlineCallbacks
+        def noop_update(progress, batch_size):
+            yield self._end_background_update("refresh_tokens_device_index")
+            defer.returnValue(1)
+        self.register_background_update_handler(
+            "refresh_tokens_device_index", noop_update)

     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
@@ -177,9 +180,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             )

             if create_profile_with_localpart:
+                # set a default displayname serverside to avoid ugly race
+                # between auto-joins and clients trying to set displaynames
                 txn.execute(
-                    "INSERT INTO profiles(user_id) VALUES (?)",
-                    (create_profile_with_localpart,)
+                    "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
+                    (create_profile_with_localpart, create_profile_with_localpart)
                 )

             self._invalidate_cache_and_stream(
@@ -236,12 +241,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             "user_set_password_hash", user_set_password_hash_txn
         )

-    @defer.inlineCallbacks
     def user_delete_access_tokens(self, user_id, except_token_id=None,
-                                  device_id=None,
-                                  delete_refresh_tokens=False):
+                                  device_id=None):
         """
-        Invalidate access/refresh tokens belonging to a user
+        Invalidate access tokens belonging to a user

         Args:
             user_id (str): ID of user the tokens belong to
@@ -250,10 +253,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             device_id (str|None): ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
                 be deleted
-            delete_refresh_tokens (bool): True to delete refresh tokens as
-                well as access tokens.
         Returns:
-            defer.Deferred:
+            defer.Deferred[list[str, str|None]]: a list of the deleted tokens
+                and device IDs
        """
        def f(txn):
            keyvalues = {
@@ -262,13 +264,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
            if device_id is not None:
                keyvalues["device_id"] = device_id

-            if delete_refresh_tokens:
-                self._simple_delete_txn(
-                    txn,
-                    table="refresh_tokens",
-                    keyvalues=keyvalues,
-                )
-
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
             values = [v for _, v in items]
@@ -277,14 +272,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values.append(except_token_id)

             txn.execute(
-                "SELECT token FROM access_tokens WHERE %s" % where_clause,
+                "SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
                 values
             )

-            rows = self.cursor_to_dict(txn)
+            tokens_and_devices = [(r[0], r[1]) for r in txn]

-            for row in rows:
+            for token, _ in tokens_and_devices:
                 self._invalidate_cache_and_stream(
-                    txn, self.get_user_by_access_token, (row["token"],)
+                    txn, self.get_user_by_access_token, (token,)
                 )

             txn.execute(
@@ -292,7 +287,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values
             )

-        yield self.runInteraction(
+            return tokens_and_devices
+
+        return self.runInteraction(
             "user_delete_access_tokens", f,
         )


@@ -49,8 +49,8 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"

 class RoomMemberStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(RoomMemberStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(RoomMemberStore, self).__init__(db_conn, hs)
         self.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )


@@ -1,17 +0,0 @@
-/* Copyright 2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-INSERT INTO background_updates (update_name, progress_json) VALUES
-  ('refresh_tokens_device_index', '{}');


@@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
 CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);

--- Make sure that we popualte the table initially
+-- Make sure that we populate the table initially
 UPDATE user_directory_stream_pos SET stream_id = NULL;


@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2017 New Vector Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,4 +13,5 @@
  * limitations under the License.
  */

-ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
+/* we no longer use (or create) the refresh_tokens table */
+DROP TABLE IF EXISTS refresh_tokens;


@@ -0,0 +1,32 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups_new (
+    group_id TEXT NOT NULL,
+    name TEXT,  -- the display name of the room
+    avatar_url TEXT,
+    short_description TEXT,
+    long_description TEXT,
+    is_public BOOL NOT NULL  -- whether non-members can access group APIs
+);
+
+-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
+INSERT INTO groups_new
+    SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
+
+DROP TABLE groups;
+ALTER TABLE groups_new RENAME TO groups;
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);


@@ -1,4 +1,4 @@
-/* Copyright 2015, 2016 OpenMarket Ltd
+/* Copyright 2017 New Vector Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,9 +13,12 @@
  * limitations under the License.
  */

-CREATE TABLE IF NOT EXISTS refresh_tokens(
-    id INTEGER PRIMARY KEY,
-    token TEXT NOT NULL,
-    user_id TEXT NOT NULL,
-    UNIQUE (token)
-);
+-- this is just embarassing :|
+ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
+
+-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
+-- so is hopefully not going to block anyone else for that long...
+CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
+CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
+
+DROP INDEX users_in_pubic_room_room_idx;
+DROP INDEX users_in_pubic_room_user_idx;


@@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
     file TEXT NOT NULL,
     UNIQUE(version, file)
 );
+
+-- a list of schema files we have loaded on behalf of dynamic modules
+CREATE TABLE IF NOT EXISTS applied_module_schemas(
+    module_name TEXT NOT NULL,
+    file TEXT NOT NULL,
+    UNIQUE(module_name, file)
+);


@@ -33,8 +33,8 @@ class SearchStore(BackgroundUpdateStore):
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"

-    def __init__(self, hs):
-        super(SearchStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(SearchStore, self).__init__(db_conn, hs)
         self.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )


@@ -63,8 +63,8 @@ class StateStore(SQLBaseStore):
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"

-    def __init__(self, hs):
-        super(StateStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(StateStore, self).__init__(db_conn, hs)
         self.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
             self._background_deduplicate_state,


@@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """

-    def __init__(self, hs):
-        super(TransactionStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(TransactionStore, self).__init__(db_conn, hs)

         self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)


@@ -63,7 +63,7 @@ class UserDirectoryStore(SQLBaseStore):
             user_ids (list(str)): Users to add
         """
         yield self._simple_insert_many(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             values=[
                 {
                     "user_id": user_id,
@@ -219,7 +219,7 @@ class UserDirectoryStore(SQLBaseStore):
     @defer.inlineCallbacks
     def update_user_in_public_user_list(self, user_id, room_id):
         yield self._simple_update_one(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             updatevalues={"room_id": room_id},
             desc="update_user_in_public_user_list",
@@ -240,7 +240,7 @@ class UserDirectoryStore(SQLBaseStore):
             )
             self._simple_delete_txn(
                 txn,
-                table="users_in_pubic_room",
+                table="users_in_public_rooms",
                 keyvalues={"user_id": user_id},
             )
             txn.call_after(
@@ -256,7 +256,7 @@ class UserDirectoryStore(SQLBaseStore):
     @defer.inlineCallbacks
     def remove_from_user_in_public_room(self, user_id):
         yield self._simple_delete(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             desc="remove_from_user_in_public_room",
         )
@@ -267,7 +267,7 @@ class UserDirectoryStore(SQLBaseStore):
             in the given room_id
         """
         return self._simple_select_onecol(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_public_due_to_room",
@@ -286,7 +286,7 @@ class UserDirectoryStore(SQLBaseStore):
         )

         user_ids_pub = yield self._simple_select_onecol(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_dir_due_to_room",
@@ -514,7 +514,7 @@ class UserDirectoryStore(SQLBaseStore):
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
-            txn.execute("DELETE FROM users_in_pubic_room")
+            txn.execute("DELETE FROM users_in_public_rooms")
             txn.execute("DELETE FROM users_who_share_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
             txn.call_after(self.get_user_in_public_room.invalidate_all)
@@ -537,7 +537,7 @@ class UserDirectoryStore(SQLBaseStore):
     @cached()
     def get_user_in_public_room(self, user_id):
         return self._simple_select_one(
-            table="users_in_pubic_room",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             retcols=("room_id",),
             allow_none=True,
@@ -641,7 +641,7 @@ class UserDirectoryStore(SQLBaseStore):
                 SELECT d.user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_pubic_room AS p USING (user_id)
+                LEFT JOIN users_in_public_rooms AS p USING (user_id)
                 LEFT JOIN (
                     SELECT other_user_id AS user_id FROM users_who_share_rooms
                     WHERE user_id = ? AND share_private
@@ -680,7 +680,7 @@ class UserDirectoryStore(SQLBaseStore):
                 SELECT d.user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_pubic_room AS p USING (user_id)
+                LEFT JOIN users_in_public_rooms AS p USING (user_id)
                 LEFT JOIN (
                     SELECT other_user_id AS user_id FROM users_who_share_rooms
                     WHERE user_id = ? AND share_private


@@ -278,8 +278,13 @@ class Limiter(object):
         if entry[0] >= self.max_count:
             new_defer = defer.Deferred()
             entry[1].append(new_defer)
+
+            logger.info("Waiting to acquire limiter lock for key %r", key)
             with PreserveLoggingContext():
                 yield new_defer
+            logger.info("Acquired limiter lock for key %r", key)
+        else:
+            logger.info("Acquired uncontended limiter lock for key %r", key)

         entry[0] += 1
@@ -288,16 +293,21 @@ class Limiter(object):
             try:
                 yield
             finally:
+                logger.info("Releasing limiter lock for key %r", key)
+
                 # We've finished executing so check if there are any things
                 # blocked waiting to execute and start one of them
                 entry[0] -= 1
-                try:
-                    entry[1].pop(0).callback(None)
-                except IndexError:
-                    # If nothing else is executing for this key then remove it
-                    # from the map
-                    if entry[0] == 0:
-                        self.key_to_defer.pop(key, None)
+
+                if entry[1]:
+                    next_def = entry[1].pop(0)
+
+                    with PreserveLoggingContext():
+                        next_def.callback(None)
+                elif entry[0] == 0:
+                    # We were the last thing for this key: remove it from the
+                    # map.
+                    del self.key_to_defer[key]

         defer.returnValue(_ctx_manager())


@@ -53,7 +53,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             type="m.room.message",
             room_id="!foo:bar"
         )
-        self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+        self.mock_store.get_new_events_for_appservice.side_effect = [
+            (0, [event]),
+            (0, [])
+        ]
         self.mock_as_api.push = Mock()
         yield self.handler.notify_interested_services(0)
         self.mock_scheduler.submit_event_for_as.assert_called_once_with(
@@ -75,7 +78,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
         self.mock_as_api.push = Mock()
         self.mock_as_api.query_user = Mock()
-        self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+        self.mock_store.get_new_events_for_appservice.side_effect = [
+            (0, [event]),
+            (0, [])
+        ]
         yield self.handler.notify_interested_services(0)
         self.mock_as_api.query_user.assert_called_once_with(
             services[0], user_id
@@ -98,7 +104,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         )
         self.mock_as_api.push = Mock()
         self.mock_as_api.query_user = Mock()
-        self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
+        self.mock_store.get_new_events_for_appservice.side_effect = [
+            (0, [event]),
+            (0, [])
+        ]
         yield self.handler.notify_interested_services(0)
         self.assertFalse(
             self.mock_as_api.query_user.called,


@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs)
+        self.store = ApplicationServiceStore(None, hs)

     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -150,7 +150,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.as_yaml_files = []

-        self.store = TestTransactionStore(hs)
+        self.store = TestTransactionStore(None, hs)

     def _add_service(self, url, as_token, id):
         as_yaml = dict(url=url, as_token=as_token, hs_token="something",
@@ -420,8 +420,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 class TestTransactionStore(ApplicationServiceTransactionStore,
                            ApplicationServiceStore):

-    def __init__(self, hs):
-        super(TestTransactionStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(TestTransactionStore, self).__init__(db_conn, hs)


 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -458,7 +458,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
             replication_layer=Mock(),
         )

-        ApplicationServiceStore(hs)
+        ApplicationServiceStore(None, hs)

     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -477,7 +477,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )

         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs)
+            ApplicationServiceStore(None, hs)

         e = cm.exception
         self.assertIn(f1, e.message)
@@ -501,7 +501,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )

         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs)
+            ApplicationServiceStore(None, hs)

         e = cm.exception
         self.assertIn(f1, e.message)


@@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             database_engine=create_engine(config.database_config),
         )

-        self.datastore = SQLBaseStore(hs)
+        self.datastore = SQLBaseStore(None, hs)

     @defer.inlineCallbacks
     def test_insert_1col(self):


@@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()

-        self.store = DirectoryStore(hs)
+        self.store = DirectoryStore(None, hs)

         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")


@@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(clock=MockClock())

-        self.store = PresenceStore(hs)
+        self.store = PresenceStore(None, hs)

         self.u_apple = UserID.from_string("@apple:test")
         self.u_banana = UserID.from_string("@banana:test")


@@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()

-        self.store = ProfileStore(hs)
+        self.store = ProfileStore(None, hs)

         self.u_frank = UserID.from_string("@frank:test")


@@ -86,7 +86,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
         # now delete some
         yield self.store.user_delete_access_tokens(
-            self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
+            self.user_id, device_id=self.device_id,
+        )

         # check they were deleted
         user = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -97,8 +98,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         self.assertEqual(self.user_id, user["name"])

         # now delete the rest
-        yield self.store.user_delete_access_tokens(
-            self.user_id, delete_refresh_tokens=True)
+        yield self.store.user_delete_access_tokens(self.user_id)

         user = yield self.store.get_user_by_access_token(self.tokens[0])
         self.assertIsNone(user,


@@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
         )

         self.config = Mock()
+        self.config.password_providers = []
         self.config.database_config = {"name": "sqlite3"}

     def prepare(self):