0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-04 02:18:25 +02:00

Merge branch 'release-v0.25.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2017-11-15 11:32:24 +00:00
commit 552f123bea
85 changed files with 1777 additions and 683 deletions

View file

@ -1,3 +1,61 @@
Changes in synapse v0.25.0 (2017-11-15)
=======================================
Bug fixes:
* Fix port script (PR #2673)
Changes in synapse v0.25.0-rc1 (2017-11-14)
===========================================
Features:
* Add is_public to groups table to allow for private groups (PR #2582)
* Add a route for determining who you are (PR #2668) Thanks to @turt2live!
* Add more features to the password providers (PR #2608, #2610, #2620, #2622,
#2623, #2624, #2626, #2628, #2629)
* Add a hook for custom rest endpoints (PR #2627)
* Add API to update group room visibility (PR #2651)
Changes:
* Ignore <noscript> tags when generating URL preview descriptions (PR #2576)
Thanks to @maximevaillancourt!
* Register some /unstable endpoints in /r0 as well (PR #2579) Thanks to
@krombel!
* Support /keys/upload on /r0 as well as /unstable (PR #2585)
* Front-end proxy: pass through auth header (PR #2586)
* Allow ASes to deactivate their own users (PR #2589)
* Remove refresh tokens (PR #2613)
* Automatically set default displayname on register (PR #2617)
* Log login requests (PR #2618)
* Always return `is_public` in the `/groups/:group_id/rooms` API (PR #2630)
* Avoid no-op media deletes (PR #2637) Thanks to @spantaleev!
* Fix various embarrassing typos around user_directory and add some doc. (PR
#2643)
* Return whether a user is an admin within a group (PR #2647)
* Namespace visibility options for groups (PR #2657)
* Downcase UserIDs on registration (PR #2662)
* Cache failures when fetching URL previews (PR #2669)
Bug fixes:
* Fix port script (PR #2577)
* Fix error when running synapse with no logfile (PR #2581)
* Fix UI auth when deleting devices (PR #2591)
* Fix typo when checking if user is invited to group (PR #2599)
* Fix the port script to drop NUL values in all tables (PR #2611)
* Fix appservices being backlogged and not receiving new events due to a bug in
notify_interested_services (PR #2631) Thanks to @xyzz!
* Fix updating rooms avatar/display name when modified by admin (PR #2636)
Thanks to @farialima!
* Fix bug in state group storage (PR #2649)
* Fix 500 on invalid utf-8 in request (PR #2663)
Changes in synapse v0.24.1 (2017-10-24)
=======================================

View file

@ -823,7 +823,9 @@ spidering 'internal' URLs on your network. At the very least we recommend that
your loopback and RFC1918 IP addresses are blacklisted.
This also requires the optional lxml and netaddr python dependencies to be
installed.
installed. This in turn requires the libxml2 library to be available - on
Debian/Ubuntu this means ``apt-get install libxml2-dev``, or equivalent for
your OS.
Password reset

View file

@ -1,52 +1,119 @@
Basically, PEP8
- Everything should comply with PEP8. Code should pass
``pep8 --max-line-length=100`` without any warnings.
- NEVER tabs. 4 spaces to indent.
- Max line width: 79 chars (with flexibility to overflow by a "few chars" if
- **Indenting**:
- NEVER tabs. 4 spaces to indent.
- follow PEP8; either hanging indent or multiline-visual indent depending
on the size and shape of the arguments and what makes more sense to the
author. In other words, both this::
print("I am a fish %s" % "moo")
and this::
print("I am a fish %s" %
"moo")
and this::
print(
"I am a fish %s" %
"moo",
)
...are valid, although given each one takes up 2x more vertical space than
the previous, it's up to the author's discretion as to which layout makes
most sense for their function invocation. (e.g. if they want to add
comments per-argument, or put expressions in the arguments, or group
related arguments together, or want to deliberately extend or preserve
vertical/horizontal space)
- **Line length**:
Max line length is 79 chars (with flexibility to overflow by a "few chars" if
the overflowing content is not semantically significant and avoids an
explosion of vertical whitespace).
- Use camel case for class and type names
- Use underscores for functions and variables.
- Use double quotes.
- Use parentheses instead of '\\' for line continuation where ever possible
(which is pretty much everywhere)
- There should be max a single new line between:
Use parentheses instead of ``\`` for line continuation where ever possible
(which is pretty much everywhere).
- **Naming**:
- Use camel case for class and type names
- Use underscores for functions and variables.
- Use double quotes ``"foo"`` rather than single quotes ``'foo'``.
- **Blank lines**:
- There should be max a single new line between:
- statements
- functions in a class
- There should be two new lines between:
- There should be two new lines between:
- definitions in a module (e.g., between different classes)
- There should be spaces where spaces should be and not where there shouldn't be:
- a single space after a comma
- a single space before and after for '=' when used as assignment
- no spaces before and after for '=' for default values and keyword arguments.
- Indenting must follow PEP8; either hanging indent or multiline-visual indent
depending on the size and shape of the arguments and what makes more sense to
the author. In other words, both this::
print("I am a fish %s" % "moo")
- **Whitespace**:
and this::
There should be spaces where spaces should be and not where there shouldn't
be:
print("I am a fish %s" %
"moo")
- a single space after a comma
- a single space before and after for '=' when used as assignment
- no spaces before and after for '=' for default values and keyword arguments.
and this::
- **Comments**: should follow the `google code style
<http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
This is so that we can generate documentation with `sphinx
<http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
`examples
<http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
in the sphinx documentation.
print(
"I am a fish %s" %
"moo"
)
- **Imports**:
...are valid, although given each one takes up 2x more vertical space than
the previous, it's up to the author's discretion as to which layout makes most
sense for their function invocation. (e.g. if they want to add comments
per-argument, or put expressions in the arguments, or group related arguments
together, or want to deliberately extend or preserve vertical/horizontal
space)
- Prefer to import classes and functions than packages or modules.
Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
This is so that we can generate documentation with
`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
in the sphinx documentation.
Example::
Code should pass pep8 --max-line-length=100 without any warnings.
from synapse.types import UserID
...
user_id = UserID(local, server)
is preferred over::
from synapse import types
...
user_id = types.UserID(local, server)
(or any other variant).
This goes against the advice in the Google style guide, but it means that
errors in the name are caught early (at import time).
- Multiple imports from the same package can be combined onto one line::
from synapse.types import GroupID, RoomID, UserID
An effort should be made to keep the individual imports in alphabetical
order.
If the list becomes long, wrap it with parentheses and split it over
multiple lines.
- As per `PEP-8 <https://www.python.org/dev/peps/pep-0008/#imports>`_,
imports should be grouped in the following order, with a blank line between
each group:
1. standard library imports
2. related third party imports
3. local application/library specific imports
- Imports within each group should be sorted alphabetically by module name.
- Avoid wildcard imports (``from synapse.types import *``) and relative
imports (``from .types import UserID``).

View file

@ -0,0 +1,99 @@
Password auth provider modules
==============================
Password auth providers offer a way for server administrators to integrate
their Synapse installation with an existing authentication system.
A password auth provider is a Python class which is dynamically loaded into
Synapse, and provides a number of methods by which it can integrate with the
authentication system.
This document serves as a reference for those looking to implement their own
password auth providers.
Required methods
----------------
Password auth provider classes must provide the following methods:
*class* ``SomeProvider.parse_config``\(*config*)
This method is passed the ``config`` object for this module from the
homeserver configuration file.
It should perform any appropriate sanity checks on the provided
configuration, and return an object which is then passed into ``__init__``.
*class* ``SomeProvider``\(*config*, *account_handler*)
The constructor is passed the config object returned by ``parse_config``,
and a ``synapse.module_api.ModuleApi`` object which allows the
password provider to check if accounts exist and/or create new ones.
Optional methods
----------------
Password auth provider classes may optionally provide the following methods.
*class* ``SomeProvider.get_db_schema_files``\()
This method, if implemented, should return an Iterable of ``(name,
stream)`` pairs of database schema files. Each file is applied in turn at
initialisation, and a record is then made in the database so that it is
not re-applied on the next start.
``someprovider.get_supported_login_types``\()
This method, if implemented, should return a ``dict`` mapping from a login
type identifier (such as ``m.login.password``) to an iterable giving the
fields which must be provided by the user in the submission to the
``/login`` api. These fields are passed in the ``login_dict`` dictionary
to ``check_auth``.
For example, if a password auth provider wants to implement a custom login
type of ``com.example.custom_login``, where the client is expected to pass
the fields ``secret1`` and ``secret2``, the provider should implement this
method and return the following dict::
{"com.example.custom_login": ("secret1", "secret2")}
``someprovider.check_auth``\(*username*, *login_type*, *login_dict*)
This method is the one that does the real work. If implemented, it will be
called for each login attempt where the login type matches one of the keys
returned by ``get_supported_login_types``.
It is passed the (possibly UNqualified) ``user`` provided by the client,
the login type, and a dictionary of login secrets passed by the client.
The method should return a Twisted ``Deferred`` object, which resolves to
the canonical ``@localpart:domain`` user id if authentication is successful,
and ``None`` if not.
Alternatively, the ``Deferred`` can resolve to a ``(str, func)`` tuple, in
which case the second field is a callback which will be called with the
result from the ``/login`` call (including ``access_token``, ``device_id``,
etc.)
``someprovider.check_password``\(*user_id*, *password*)
This method provides a simpler interface than ``get_supported_login_types``
and ``check_auth`` for password auth providers that just want to provide a
mechanism for validating ``m.login.password`` logins.
Iif implemented, it will be called to check logins with an
``m.login.password`` login type. It is passed a qualified
``@localpart:domain`` user id, and the password provided by the user.
The method should return a Twisted ``Deferred`` object, which resolves to
``True`` if authentication is successful, and ``False`` if not.
``someprovider.on_logged_out``\(*user_id*, *device_id*, *access_token*)
This method, if implemented, is called when a user logs out. It is passed
the qualified user ID, the ID of the deactivated device (if any: access
tokens are occasionally created without an associated device ID), and the
(now deactivated) access token.
It may return a Twisted ``Deferred`` object; the logout request will wait
for the deferred to complete but the result is ignored.

View file

@ -56,6 +56,7 @@ As a first cut, let's do #2 and have the receiver hit the API to calculate its o
API
---
```
GET /_matrix/media/r0/preview_url?url=http://wherever.com
200 OK
{
@ -66,6 +67,7 @@ GET /_matrix/media/r0/preview_url?url=http://wherever.com
"og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp;amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
"og:site_name" : "Twitter"
}
```
* Downloads the URL
* If HTML, just stores it in RAM and parses it for OG meta tags

17
docs/user_directory.md Normal file
View file

@ -0,0 +1,17 @@
User Directory API Implementation
=================================
The user directory is currently maintained based on the 'visible' users
on this particular server - i.e. ones which your account shares a room with, or
who are present in a publicly viewable room present on the server.
The directory info is stored in various tables, which can (typically after
DB corruption) get stale or out of sync. If this happens, for now the
quickest solution to fix it is:
```
UPDATE user_directory_stream_pos SET stream_id = NULL;
```
and restart the synapse, which should then start a background task to
flush the current tables and regenerate the directory.

View file

@ -42,6 +42,14 @@ BOOLEAN_COLUMNS = {
"public_room_list_stream": ["visibility"],
"device_lists_outbound_pokes": ["sent"],
"users_who_share_rooms": ["share_private"],
"groups": ["is_public"],
"group_rooms": ["is_public"],
"group_users": ["is_public", "is_admin"],
"group_summary_rooms": ["is_public"],
"group_room_categories": ["is_public"],
"group_summary_users": ["is_public"],
"group_roles": ["is_public"],
"local_group_membership": ["is_publicised", "is_admin"],
}
@ -112,6 +120,7 @@ class Store(object):
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
_simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
def runInteraction(self, desc, func, *args, **kwargs):
def r(conn):
@ -318,7 +327,7 @@ class Porter(object):
backward_chunk = min(row[0] for row in brows) - 1
rows = frows + brows
self._convert_rows(table, headers, rows)
rows = self._convert_rows(table, headers, rows)
def insert(txn):
self.postgres_store.insert_many_txn(
@ -554,17 +563,29 @@ class Porter(object):
i for i, h in enumerate(headers) if h in bool_col_names
]
class BadValueException(Exception):
pass
def conv(j, col):
if j in bool_cols:
return bool(col)
elif isinstance(col, basestring) and "\0" in col:
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
raise BadValueException();
return col
outrows = []
for i, row in enumerate(rows):
rows[i] = tuple(
conv(j, col)
for j, col in enumerate(row)
if j > 0
)
try:
outrows.append(tuple(
conv(j, col)
for j, col in enumerate(row)
if j > 0
))
except BadValueException:
pass
return outrows
@defer.inlineCallbacks
def _setup_sent_transactions(self):
@ -592,7 +613,7 @@ class Porter(object):
"select", r,
)
self._convert_rows("sent_transactions", headers, rows)
rows = self._convert_rows("sent_transactions", headers, rows)
inserted_rows = len(rows)
if inserted_rows:

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.24.1"
__version__ = "0.25.0"

View file

@ -50,8 +50,7 @@ logger = logging.getLogger("synapse.app.frontend_proxy")
class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
releases=())
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@ -89,9 +88,16 @@ class KeyUploadServlet(RestServlet):
if body:
# They're actually trying to upload something, proxy to main synapse.
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
headers = {
"Authorization": auth_headers,
}
result = yield self.http_client.post_json_get_json(
self.main_uri + request.uri,
body,
headers=headers,
)
defer.returnValue((200, result))

View file

@ -30,6 +30,8 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.federation.transport.server import TransportLayerServer
from synapse.module_api import ModuleApi
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
from synapse.http.site import SynapseSite
from synapse.metrics import register_memory_metrics
@ -49,6 +51,7 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.application import service
@ -107,52 +110,18 @@ class SynapseHomeServer(HomeServer):
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "client":
client_resource = ClientRestResource(self)
if res["compress"]:
client_resource = gz_wrap(client_resource)
resources.update(self._configure_named_resource(
name, res.get("compress", False),
))
resources.update({
"/_matrix/client/api/v1": client_resource,
"/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
})
if name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
})
if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
if name in ["keys", "federation"]:
resources.update({
SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
})
if name == "webclient":
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(self)
additional_resources = listener_config.get("additional_resources", {})
logger.debug("Configuring additional resources: %r",
additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler = handler_cls(config, module_api)
resources[path] = AdditionalResource(self, handler.handle_request)
if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
@ -188,6 +157,67 @@ class SynapseHomeServer(HomeServer):
)
logger.info("Synapse now listening on port %d", port)
def _configure_named_resource(self, name, compress=False):
"""Build a resource map for a named resource
Args:
name (str): named resource: one of "client", "federation", etc
compress (bool): whether to enable gzip compression for this
resource
Returns:
dict[str, Resource]: map from path to HTTP resource
"""
resources = {}
if name == "client":
client_resource = ClientRestResource(self)
if compress:
client_resource = gz_wrap(client_resource)
resources.update({
"/_matrix/client/api/v1": client_resource,
"/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
})
if name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: File(
os.path.join(os.path.dirname(synapse.__file__), "static")
),
})
if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
if name in ["keys", "federation"]:
resources.update({
SERVER_KEY_PREFIX: LocalKey(self),
SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
})
if name == "webclient":
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and self.get_config().enable_metrics:
resources[METRICS_PREFIX] = MetricsResource(self)
return resources
def start_listening(self):
config = self.get_config()

View file

@ -18,6 +18,7 @@ from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
@ -192,9 +193,12 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(None)
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
self.protocol_meta_cache.set(key, _get())
)
result = self.protocol_meta_cache.get(key)
if not result:
result = self.protocol_meta_cache.set(
key, preserve_fn(_get)()
)
return make_deferred_yieldable(result)
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):

View file

@ -41,7 +41,7 @@ class CasConfig(Config):
#cas_config:
# enabled: true
# server_url: "https://cas-server.com"
# service_url: "https://homesever.domain.com:8448"
# service_url: "https://homeserver.domain.com:8448"
# #required_attributes:
# # name: value
"""

View file

@ -148,8 +148,8 @@ def setup_logging(config, use_worker_options=False):
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
)
if log_config is None:
if log_config is None:
level = logging.INFO
level_for_storage = logging.INFO
if config.verbosity:
@ -176,6 +176,10 @@ def setup_logging(config, use_worker_options=False):
logger.info("Opened new log file due to SIGHUP")
else:
handler = logging.StreamHandler()
def sighup(signum, stack):
pass
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))

View file

@ -13,41 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
from ._base import Config
from synapse.util.module_loader import load_module
LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
provider_config = None
providers = []
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
self.ldap_enabled = ldap_config.get("enabled", False)
if self.ldap_enabled:
from ldap_auth_provider import LdapAuthProvider
parsed_config = LdapAuthProvider.parse_config(ldap_config)
self.password_providers.append((LdapAuthProvider, parsed_config))
if ldap_config.get("enabled", False):
providers.append[{
'module': LDAP_PROVIDER,
'config': ldap_config,
}]
providers = config.get("password_providers", [])
providers.extend(config.get("password_providers", []))
for provider in providers:
mod_name = provider['module']
# This is for backwards compat when the ldap auth provider resided
# in this package.
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
from ldap_auth_provider import LdapAuthProvider
provider_class = LdapAuthProvider
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
raise ConfigError(
"Failed to parse config for %r: %r" % (provider['module'], e)
)
else:
(provider_class, provider_config) = load_module(provider)
if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
mod_name = LDAP_PROVIDER
(provider_class, provider_config) = load_module({
"module": mod_name,
"config": provider['config'],
})
self.password_providers.append((provider_class, provider_config))

View file

@ -247,6 +247,13 @@ class ServerConfig(Config):
- names: [federation] # Federation APIs
compress: false
# optional list of additional endpoints which can be loaded via
# dynamic modules
# additional_resources:
# "/_matrix/my/custom/endpoint":
# module: my_module.CustomRequestHandler
# config: {}
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s

View file

@ -109,6 +109,12 @@ class TlsConfig(Config):
# key. It may be necessary to publish the fingerprints of a new
# certificate and wait until the "valid_until_ts" of the previous key
# responses have passed before deploying it.
#
# You can calculate a fingerprint from a given TLS listener via:
# openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
# openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
# or by checking matrix.org/federationtester/api/report?server_name=$host
#
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals()

View file

@ -18,6 +18,7 @@ from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util import async
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
@ -253,12 +254,13 @@ class FederationServer(FederationBase):
result = self._state_resp_cache.get((room_id, event_id))
if not result:
with (yield self._server_linearizer.queue((origin, room_id))):
resp = yield self._state_resp_cache.set(
d = self._state_resp_cache.set(
(room_id, event_id),
self._on_context_state_request_compute(room_id, event_id)
preserve_fn(self._on_context_state_request_compute)(room_id, event_id)
)
resp = yield make_deferred_yieldable(d)
else:
resp = yield result
resp = yield make_deferred_yieldable(result)
defer.returnValue((200, resp))

View file

@ -545,6 +545,20 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
config_key, content):
"""Update room in group
"""
path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,)
return self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
data=content,
ignore_backoff=True,
)
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""

View file

@ -676,7 +676,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
"""Add/remove room from group
"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, room_id):
@ -703,6 +703,27 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
defer.returnValue((200, new_content))
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
"""Update room config in group
"""
PATH = (
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
@defer.inlineCallbacks
def on_POST(self, origin, content, query, group_id, room_id, config_key):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
result = yield self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content,
)
defer.returnValue((200, result))
class FederationGroupsUsersServlet(BaseFederationServlet):
"""Get the users in a group on behalf of a user
"""
@ -1142,6 +1163,8 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsRolesServlet,
FederationGroupsRoleServlet,
FederationGroupsSummaryUsersServlet,
FederationGroupsAddRoomsServlet,
FederationGroupsAddRoomsConfigServlet,
)

View file

@ -13,6 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attestations ensure that users and groups can't lie about their memberships.
When a user joins a group the HS and GS swap attestations, which allow them
both to independently prove to third parties their membership.These
attestations have a validity period so need to be periodically renewed.
If a user leaves (or gets kicked out of) a group, either side can still use
their attestation to "prove" their membership, until the attestation expires.
Therefore attestations shouldn't be relied on to prove membership in important
cases, but can for less important situtations, e.g. showing a users membership
of groups on their profile, showing flairs, etc.abs
An attestsation is a signed blob of json that looks like:
{
"user_id": "@foo:a.example.com",
"group_id": "+bar:b.example.com",
"valid_until_ms": 1507994728530,
"signatures":{"matrix.org":{"ed25519:auto":"..."}}
}
"""
import logging
import random
from twisted.internet import defer
from synapse.api.errors import SynapseError
@ -22,9 +47,17 @@ from synapse.util.logcontext import preserve_fn
from signedjson.sign import sign_json
logger = logging.getLogger(__name__)
# Default validity duration for new attestations we create
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
# We add some jitter to the validity duration of attestations so that if we
# add lots of users at once we don't need to renew them all at once.
# The jitter is a multiplier picked randomly between the first and second number
DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
# Start trying to update our attestations when they come this close to expiring
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
@ -73,10 +106,14 @@ class GroupAttestationSigning(object):
"""Create an attestation for the group_id and user_id with default
validity length.
"""
validity_period = DEFAULT_ATTESTATION_LENGTH_MS
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
valid_until_ms = int(self.clock.time_msec() + validity_period)
return sign_json({
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
"valid_until_ms": valid_until_ms,
}, self.server_name, self.signing_key)
@ -128,12 +165,19 @@ class GroupAttestionRenewer(object):
@defer.inlineCallbacks
def _renew_attestation(group_id, user_id):
attestation = self.attestations.create_attestation(group_id, user_id)
if self.is_mine_id(group_id):
if not self.is_mine_id(group_id):
destination = get_domain_from_id(group_id)
elif not self.is_mine_id(user_id):
destination = get_domain_from_id(user_id)
else:
destination = get_domain_from_id(group_id)
logger.warn(
"Incorrectly trying to do attestations for user: %r in %r",
user_id, group_id,
)
yield self.store.remove_attestation_renewal(group_id, user_id)
return
attestation = self.attestations.create_attestation(group_id, user_id)
yield self.transport_client.renew_group_attestation(
destination, group_id, user_id,

View file

@ -49,7 +49,8 @@ class GroupsServerHandler(object):
hs.get_groups_attestation_renewer()
@defer.inlineCallbacks
def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
def check_group_is_ours(self, group_id, requester_user_id,
and_exists=False, and_is_admin=None):
"""Check that the group is ours, and optionally if it exists.
If group does exist then return group.
@ -67,6 +68,10 @@ class GroupsServerHandler(object):
if and_exists and not group:
raise SynapseError(404, "Unknown group")
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
if group and not is_user_in_group and not group["is_public"]:
raise SynapseError(404, "Unknown group")
if and_is_admin:
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
if not is_admin:
@ -84,7 +89,7 @@ class GroupsServerHandler(object):
A user/room may appear in multiple roles/categories.
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@ -153,10 +158,16 @@ class GroupsServerHandler(object):
})
@defer.inlineCallbacks
def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
def update_group_summary_room(self, group_id, requester_user_id,
room_id, category_id, content):
"""Add/update a room to the group summary
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
)
RoomID.from_string(room_id) # Ensure valid room id
@ -175,10 +186,16 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
def delete_group_summary_room(self, group_id, requester_user_id,
room_id, category_id):
"""Remove a room from the summary
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
)
yield self.store.remove_room_from_summary(
group_id=group_id,
@ -189,10 +206,10 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def get_group_categories(self, group_id, user_id):
def get_group_categories(self, group_id, requester_user_id):
"""Get all categories in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = yield self.store.get_group_categories(
group_id=group_id,
@ -200,10 +217,10 @@ class GroupsServerHandler(object):
defer.returnValue({"categories": categories})
@defer.inlineCallbacks
def get_group_category(self, group_id, user_id, category_id):
def get_group_category(self, group_id, requester_user_id, category_id):
"""Get a specific category in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = yield self.store.get_group_category(
group_id=group_id,
@ -213,10 +230,15 @@ class GroupsServerHandler(object):
defer.returnValue(res)
@defer.inlineCallbacks
def update_group_category(self, group_id, user_id, category_id, content):
def update_group_category(self, group_id, requester_user_id, category_id, content):
"""Add/Update a group category
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
)
is_public = _parse_visibility_from_contents(content)
profile = content.get("profile")
@ -231,10 +253,15 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_category(self, group_id, user_id, category_id):
def delete_group_category(self, group_id, requester_user_id, category_id):
"""Delete a group category
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id
)
yield self.store.remove_group_category(
group_id=group_id,
@ -244,10 +271,10 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def get_group_roles(self, group_id, user_id):
def get_group_roles(self, group_id, requester_user_id):
"""Get all roles in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = yield self.store.get_group_roles(
group_id=group_id,
@ -255,10 +282,10 @@ class GroupsServerHandler(object):
defer.returnValue({"roles": roles})
@defer.inlineCallbacks
def get_group_role(self, group_id, user_id, role_id):
def get_group_role(self, group_id, requester_user_id, role_id):
"""Get a specific role in a group (as seen by user)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = yield self.store.get_group_role(
group_id=group_id,
@ -267,10 +294,15 @@ class GroupsServerHandler(object):
defer.returnValue(res)
@defer.inlineCallbacks
def update_group_role(self, group_id, user_id, role_id, content):
def update_group_role(self, group_id, requester_user_id, role_id, content):
"""Add/update a role in a group
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
)
is_public = _parse_visibility_from_contents(content)
@ -286,10 +318,15 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def delete_group_role(self, group_id, user_id, role_id):
def delete_group_role(self, group_id, requester_user_id, role_id):
"""Remove role from group
"""
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
yield self.check_group_is_ours(
group_id,
requester_user_id,
and_exists=True,
and_is_admin=requester_user_id,
)
yield self.store.remove_group_role(
group_id=group_id,
@ -304,7 +341,7 @@ class GroupsServerHandler(object):
"""Add/update a users entry in the group summary
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
)
order = content.get("order", None)
@ -326,7 +363,7 @@ class GroupsServerHandler(object):
"""Remove a user from the group summary
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
)
yield self.store.remove_user_from_summary(
@ -342,7 +379,7 @@ class GroupsServerHandler(object):
"""Get the group profile as seen by requester_user_id
"""
yield self.check_group_is_ours(group_id)
yield self.check_group_is_ours(group_id, requester_user_id)
group_description = yield self.store.get_group(group_id)
@ -356,7 +393,7 @@ class GroupsServerHandler(object):
"""Update the group profile
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id,
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
)
profile = {}
@ -377,7 +414,7 @@ class GroupsServerHandler(object):
The ordering is arbitrary at the moment
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@ -389,14 +426,15 @@ class GroupsServerHandler(object):
for user_result in user_results:
g_user_id = user_result["user_id"]
is_public = user_result["is_public"]
is_privileged = user_result["is_admin"]
entry = {"user_id": g_user_id}
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
entry.update(profile)
if not is_public:
entry["is_public"] = False
entry["is_public"] = bool(is_public)
entry["is_privileged"] = bool(is_privileged)
if not self.is_mine_id(g_user_id):
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
@ -425,7 +463,7 @@ class GroupsServerHandler(object):
The ordering is arbitrary at the moment
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@ -459,7 +497,7 @@ class GroupsServerHandler(object):
This returns rooms in order of decreasing number of joined users
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
@ -470,7 +508,6 @@ class GroupsServerHandler(object):
chunk = []
for room_result in room_results:
room_id = room_result["room_id"]
is_public = room_result["is_public"]
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
@ -481,8 +518,7 @@ class GroupsServerHandler(object):
if not entry:
continue
if not is_public:
entry["is_public"] = False
entry["is_public"] = bool(room_result["is_public"])
chunk.append(entry)
@ -500,7 +536,7 @@ class GroupsServerHandler(object):
RoomID.from_string(room_id) # Ensure valid room id
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
@ -509,12 +545,35 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
content):
"""Update room in group
"""
RoomID.from_string(room_id) # Ensure valid room id
yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
if config_key == "m.visibility":
is_public = _parse_visibility_dict(content)
yield self.store.update_room_in_group_visibility(
group_id, room_id,
is_public=is_public,
)
else:
raise SynapseError(400, "Uknown config option")
defer.returnValue({})
@defer.inlineCallbacks
def remove_room_from_group(self, group_id, requester_user_id, room_id):
"""Remove room from group
"""
yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_room_from_group(group_id, room_id)
@ -527,7 +586,7 @@ class GroupsServerHandler(object):
"""
group = yield self.check_group_is_ours(
group_id, and_exists=True, and_is_admin=requester_user_id
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
# TODO: Check if user knocked
@ -596,35 +655,40 @@ class GroupsServerHandler(object):
raise SynapseError(502, "Unknown state returned by HS")
@defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content):
def accept_invite(self, group_id, requester_user_id, content):
"""User tries to accept an invite to the group.
This is different from them asking to join, and so should error if no
invite exists (and they're not a member of the group)
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
if not self.store.is_user_invited_to_local_group(group_id, user_id):
is_invited = yield self.store.is_user_invited_to_local_group(
group_id, requester_user_id,
)
if not is_invited:
raise SynapseError(403, "User not invited to group")
if not self.hs.is_mine_id(user_id):
if not self.hs.is_mine_id(requester_user_id):
local_attestation = self.attestations.create_attestation(
group_id, requester_user_id,
)
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
user_id=requester_user_id,
group_id=group_id,
)
else:
local_attestation = None
remote_attestation = None
local_attestation = self.attestations.create_attestation(group_id, user_id)
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_group(
group_id, user_id,
group_id, requester_user_id,
is_admin=False,
is_public=is_public,
local_attestation=local_attestation,
@ -637,31 +701,31 @@ class GroupsServerHandler(object):
})
@defer.inlineCallbacks
def knock(self, group_id, user_id, content):
def knock(self, group_id, requester_user_id, content):
"""A user requests becoming a member of the group
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError()
@defer.inlineCallbacks
def accept_knock(self, group_id, user_id, content):
def accept_knock(self, group_id, requester_user_id, content):
"""Accept a users knock to the room.
Errors if the user hasn't knocked, rather than inviting them.
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError()
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
"""Remove a user from the group; either a user is leaving or and admin
kicked htem.
"""Remove a user from the group; either a user is leaving or an admin
kicked them.
"""
yield self.check_group_is_ours(group_id, and_exists=True)
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_kick = False
if requester_user_id != user_id:
@ -692,8 +756,8 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
group = yield self.check_group_is_ours(group_id)
def create_group(self, group_id, requester_user_id, content):
group = yield self.check_group_is_ours(group_id, requester_user_id)
logger.info("Attempting to create group with ID: %r", group_id)
@ -703,11 +767,11 @@ class GroupsServerHandler(object):
if group:
raise SynapseError(400, "Group already exists")
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
if not is_admin:
if not self.hs.config.enable_group_creation:
raise SynapseError(
403, "Only server admin can create group on this server",
403, "Only a server admin can create groups on this server",
)
localpart = group_id_obj.localpart
if not localpart.startswith(self.hs.config.group_creation_prefix):
@ -727,38 +791,41 @@ class GroupsServerHandler(object):
yield self.store.create_group(
group_id,
user_id,
requester_user_id,
name=name,
avatar_url=avatar_url,
short_description=short_description,
long_description=long_description,
)
if not self.hs.is_mine_id(user_id):
if not self.hs.is_mine_id(requester_user_id):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
user_id=user_id,
user_id=requester_user_id,
group_id=group_id,
)
local_attestation = self.attestations.create_attestation(group_id, user_id)
local_attestation = self.attestations.create_attestation(
group_id,
requester_user_id,
)
else:
local_attestation = None
remote_attestation = None
yield self.store.add_user_to_group(
group_id, user_id,
group_id, requester_user_id,
is_admin=True,
is_public=True, # TODO
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
if not self.hs.is_mine_id(user_id):
if not self.hs.is_mine_id(requester_user_id):
yield self.store.add_remote_profile_cache(
user_id,
requester_user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
)
@ -773,15 +840,25 @@ def _parse_visibility_from_contents(content):
public or not
"""
visibility = content.get("visibility")
visibility = content.get("m.visibility")
if visibility:
vis_type = visibility["type"]
if vis_type not in ("public", "private"):
raise SynapseError(
400, "Synapse only supports 'public'/'private' visibility"
)
is_public = vis_type == "public"
return _parse_visibility_dict(visibility)
else:
is_public = True
return is_public
def _parse_visibility_dict(visibility):
"""Given a dict for the "m.visibility" config return if the entity should
be public or not
"""
vis_type = visibility.get("type")
if not vis_type:
return True
if vis_type not in ("public", "private"):
raise SynapseError(
400, "Synapse only supports 'public'/'private' visibility"
)
return vis_type == "public"

View file

@ -70,11 +70,10 @@ class ApplicationServicesHandler(object):
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
upper_bound, limit
self.current_max, limit
)
if not events:
@ -105,9 +104,6 @@ class ApplicationServicesHandler(object):
)
yield self.store.set_appservice_last_pos(upper_bound)
if len(events) < limit:
break
finally:
self.is_processing = False

View file

@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache
@ -63,10 +63,7 @@ class AuthHandler(BaseHandler):
reset_expiry_on_get=True,
)
account_handler = _AccountHandler(
hs, check_user_exists=self.check_user_exists
)
account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers
@ -75,14 +72,24 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
login_types = set()
if self._password_enabled:
login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
login_types.update(
provider.get_supported_login_types().keys()
)
self._supported_login_types = frozenset(login_types)
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
protocol and handles the User-Interactive Auth flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
@ -260,16 +267,19 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user_id = authdict["user"]
password = authdict["password"]
if not user_id.startswith('@'):
user_id = UserID(user_id, self.hs.hostname).to_string()
return self._check_password(user_id, password)
(canonical_id, callback) = yield self.validate_login(user_id, {
"type": LoginType.PASSWORD,
"password": password,
})
defer.returnValue(canonical_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@ -398,26 +408,8 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): complete @user:id
password (str): Password
Returns:
defer.Deferred: (str) canonical user id
Raises:
StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
return self._check_password(user_id, password)
@defer.inlineCallbacks
def get_access_token_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
def get_access_token_for_user_id(self, user_id, device_id=None):
"""
Creates a new access token for the user with the given user ID.
@ -431,13 +423,10 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
@ -447,9 +436,11 @@ class AuthHandler(BaseHandler):
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
try:
yield self.store.get_device(user_id, device_id)
except StoreError:
yield self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
defer.returnValue(access_token)
@ -501,29 +492,115 @@ class AuthHandler(BaseHandler):
)
defer.returnValue(result)
@defer.inlineCallbacks
def _check_password(self, user_id, password):
"""Authenticate a user against the LDAP and local databases.
def get_supported_login_types(self):
"""Get a the login types supported for the /login API
user_id is checked case insensitively against the local database, but
will throw if there are multiple inexact matches.
By default this is just 'm.login.password' (unless password_enabled is
False in the config file), but password auth providers can provide
other login types.
Returns:
Iterable[str]: login types
"""
return self._supported_login_types
@defer.inlineCallbacks
def validate_login(self, username, login_submission):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
user_id (str): complete @user:id
username (str): username supplied by the user
login_submission (dict): the whole of the login submission
(including 'type' and other relevant fields)
Returns:
(str) the canonical_user_id
Deferred[str, func]: canonical user id, and optional callback
to be called once the access token and device id are issued
Raises:
LoginError if login fails
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
if username.startswith('@'):
qualified_user_id = username
else:
qualified_user_id = UserID(
username, self.hs.hostname
).to_string()
login_type = login_submission.get("type")
known_login_type = False
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
if not password:
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
is_valid = yield provider.check_password(user_id, password)
if is_valid:
defer.returnValue(user_id)
if (hasattr(provider, "check_password")
and login_type == LoginType.PASSWORD):
known_login_type = True
is_valid = yield provider.check_password(
qualified_user_id, password,
)
if is_valid:
defer.returnValue(qualified_user_id)
canonical_user_id = yield self._check_local_password(user_id, password)
if (not hasattr(provider, "get_supported_login_types")
or not hasattr(provider, "check_auth")):
# this password provider doesn't understand custom login types
continue
if canonical_user_id:
defer.returnValue(canonical_user_id)
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
continue
known_login_type = True
login_fields = supported_login_types[login_type]
missing_fields = []
login_dict = {}
for f in login_fields:
if f not in login_submission:
missing_fields.append(f)
else:
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
400, "Missing parameters for login type %s: %s" % (
login_type,
missing_fields,
),
)
result = yield provider.check_auth(
username, login_type, login_dict,
)
if result:
if isinstance(result, str):
result = (result, None)
defer.returnValue(result)
if login_type == LoginType.PASSWORD:
known_login_type = True
canonical_user_id = yield self._check_local_password(
qualified_user_id, password,
)
if canonical_user_id:
defer.returnValue((canonical_user_id, None))
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
@ -584,13 +661,80 @@ class AuthHandler(BaseHandler):
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens(
user_id, except_access_token_id
yield self.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id,
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id
)
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
yield self.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
@defer.inlineCallbacks
def delete_access_token(self, access_token):
"""Invalidate a single access token
Args:
access_token (str): access token to be deleted
Returns:
Deferred
"""
user_info = yield self.auth.get_user_by_access_token(access_token)
yield self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
yield provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
)
@defer.inlineCallbacks
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
device_id=None):
"""Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_id (str|None): access_token ID which should *not* be
deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
Deferred
"""
tokens_and_devices = yield self.store.user_delete_access_tokens(
user_id, except_token_id=except_token_id, device_id=device_id,
)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, device_id in tokens_and_devices:
yield provider.on_logged_out(
user_id=user_id,
device_id=device_id,
access_token=token,
)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case.
@ -696,30 +840,3 @@ class MacaroonGeneartor(object):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, check_user_exists):
self.hs = hs
self._check_user_exists = check_user_exists
def check_user_exists(self, user_id):
"""Check if user exissts.
Returns:
Deferred(bool)
"""
return self._check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)

View file

@ -34,6 +34,7 @@ class DeviceHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
@ -159,9 +160,8 @@ class DeviceHandler(BaseHandler):
else:
raise
yield self.store.user_delete_access_tokens(
yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
@ -194,9 +194,8 @@ class DeviceHandler(BaseHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
yield self.store.user_delete_access_tokens(
yield self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id

View file

@ -1706,6 +1706,17 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
"""
Args:
origin (str):
event (synapse.events.FrozenEvent):
context (synapse.events.snapshot.EventContext):
auth_events (dict[(str, str)->str]):
Returns:
defer.Deferred[None]
"""
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
@ -1817,16 +1828,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
self._update_context_for_auth_events(
context, auth_events, event_key,
)
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
@ -1906,16 +1910,9 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
self._update_context_for_auth_events(
context, auth_events, event_key,
)
try:
self.auth.check(event, auth_events=auth_events)
@ -1923,6 +1920,35 @@ class FederationHandler(BaseHandler):
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
def _update_context_for_auth_events(self, context, auth_events,
event_key):
"""Update the state_ids in an event context after auth event resolution
Args:
context (synapse.events.snapshot.EventContext): event context
to be updated
auth_events (dict[(str, str)->str]): Events to update in the event
context.
event_key ((str, str)): (type, state_key) for the current event.
this will not be included in the current_state in the context.
"""
state_updates = {
k: a.event_id for k, a in auth_events.iteritems()
if k != event_key
}
context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update(state_updates)
if context.delta_ids is not None:
context.delta_ids = dict(context.delta_ids)
context.delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.iteritems()
})
context.state_group = self.store.get_next_state_group()
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This

View file

@ -71,6 +71,7 @@ class GroupsLocalHandler(object):
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
add_room_to_group = _create_rerouter("add_room_to_group")
update_room_in_group = _create_rerouter("update_room_in_group")
remove_room_from_group = _create_rerouter("remove_room_from_group")
update_group_summary_room = _create_rerouter("update_group_summary_room")

View file

@ -17,7 +17,6 @@ import logging
from twisted.internet import defer
import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, get_domain_from_id
from ._base import BaseHandler
@ -140,7 +139,7 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
yield self._update_join_states(requester)
yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@ -184,7 +183,7 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
yield self._update_join_states(requester)
yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@ -209,28 +208,24 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response)
@defer.inlineCallbacks
def _update_join_states(self, requester):
user = requester.user
if not self.hs.is_mine(user):
def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user):
return
yield self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(
user.to_string(),
target_user.to_string(),
)
for room_id in room_ids:
handler = self.hs.get_handlers().room_member_handler
try:
# Assume the user isn't a guest because we don't let guests set
# profile or avatar data.
# XXX why are we recreating `requester` here for each room?
# what was wrong with the `requester` we were passed?
requester = synapse.types.create_requester(user)
# Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data.
yield handler.update_membership(
requester,
user,
target_user,
room_id,
"join", # We treat a profile update like a join.
ratelimit=False, # Try to hide that these events aren't atomic.

View file

@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
@ -416,7 +417,7 @@ class RegistrationHandler(BaseHandler):
create_profile_with_localpart=user.localpart,
)
else:
yield self.store.user_delete_access_tokens(user_id=user_id)
yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:

View file

@ -20,6 +20,7 @@ from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.async import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
@ -70,6 +71,7 @@ class RoomListHandler(BaseHandler):
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
@ -77,13 +79,16 @@ class RoomListHandler(BaseHandler):
key = (limit, since_token, network_tuple)
result = self.response_cache.get(key)
if not result:
logger.info("No cached result, calculating one.")
result = self.response_cache.set(
key,
self._get_public_room_list(
preserve_fn(self._get_public_room_list)(
limit, since_token, network_tuple=network_tuple
)
)
return result
else:
logger.info("Using cached deferred result.")
return make_deferred_yieldable(result)
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,

View file

@ -15,7 +15,7 @@
from synapse.api.constants import Membership, EventTypes
from synapse.util.async import concurrently_execute
from synapse.util.logcontext import LoggingContext
from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure, measure_func
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
@ -184,11 +184,11 @@ class SyncHandler(object):
if not result:
result = self.response_cache.set(
sync_config.request_key,
self._wait_for_sync_for_user(
preserve_fn(self._wait_for_sync_for_user)(
sync_config, since_token, timeout, full_state
)
)
return result
return make_deferred_yieldable(result)
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,

View file

@ -152,7 +152,7 @@ class UserDirectoyHandler(object):
for room_id in room_ids:
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
yield self._handle_intial_room(room_id)
yield self._handle_initial_room(room_id)
num_processed_rooms += 1
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
@ -166,7 +166,7 @@ class UserDirectoyHandler(object):
yield self.store.update_user_directory_stream_pos(new_pos)
@defer.inlineCallbacks
def _handle_intial_room(self, room_id):
def _handle_initial_room(self, room_id):
"""Called when we initially fill out user_directory one room at a time
"""
is_in_room = yield self.store.is_host_joined(room_id, self.server_name)

View file

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import wrap_request_handler
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
class AdditionalResource(Resource):
"""Resource wrapper for additional_resources
If the user has configured additional_resources, we need to wrap the
handler class with a Resource so that we can map it into the resource tree.
This class is also where we wrap the request handler with logging, metrics,
and exception handling.
"""
def __init__(self, hs, handler):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
done handling the request. It should write a response with
``request.write()``, and call ``request.finish()``.
Args:
hs (synapse.server.HomeServer): homeserver
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
Resource.__init__(self)
self._handler = handler
# these are required by the request_handler wrapper
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render(self, request):
self._async_render(request)
return NOT_DONE_YET
@wrap_request_handler
def _async_render(self, request):
return self._handler(request)

View file

@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util import logcontext
import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
@ -114,43 +114,73 @@ class SimpleHttpClient(object):
raise e
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}):
def post_urlencoded_get_json(self, uri, args={}, headers=None):
"""
Args:
uri (str):
args (dict[str, str|List[str]]): query params
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred[object]: parsed json
"""
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request(
"POST",
uri.encode("ascii"),
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
}),
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
body = yield preserve_context_over_fn(readBody, response)
body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json):
def post_json_get_json(self, uri, post_json, headers=None):
"""
Args:
uri (str):
post_json (object):
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred[object]: parsed json
"""
json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri)
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request(
"POST",
uri.encode("ascii"),
headers=Headers({
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}),
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield preserve_context_over_fn(readBody, response)
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@ -160,7 +190,7 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, uri, args={}):
def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@ -169,6 +199,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@ -177,13 +209,13 @@ class SimpleHttpClient(object):
error message.
"""
try:
body = yield self.get_raw(uri, args)
body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body))
except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@ -193,6 +225,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@ -205,17 +239,21 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body)
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request(
"PUT",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
"Content-Type": ["application/json"]
}),
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield preserve_context_over_fn(readBody, response)
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@ -226,7 +264,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def get_raw(self, uri, args={}):
def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@ -235,6 +273,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
@ -246,15 +286,19 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
headers=Headers(actual_headers),
)
body = yield preserve_context_over_fn(readBody, response)
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(body)
@ -274,27 +318,33 @@ class SimpleHttpClient(object):
# The two should be factored out.
@defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None):
def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request(
"GET",
url.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
headers=Headers(actual_headers),
)
headers = dict(response.headers.getAllRawHeaders())
resp_headers = dict(response.headers.getAllRawHeaders())
if 'Content-Length' in headers and headers['Content-Length'] > max_size:
if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@ -315,10 +365,9 @@ class SimpleHttpClient(object):
# straight back in again
try:
length = yield preserve_context_over_fn(
_readBodyToFile,
response, output_stream, max_size
)
length = yield make_deferred_yieldable(_readBodyToFile(
response, output_stream, max_size,
))
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
@ -327,7 +376,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN,
)
defer.returnValue((length, headers, response.request.absoluteURI, response.code))
defer.returnValue(
(length, resp_headers, response.request.absoluteURI, response.code),
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
body = yield preserve_context_over_fn(readBody, response)
body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.

View file

@ -167,7 +167,8 @@ def parse_json_value_from_request(request):
try:
content = simplejson.loads(content_bytes)
except simplejson.JSONDecodeError:
except Exception as e:
logger.warn("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content

View file

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.types import UserID
class ModuleApi(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, auth_handler):
self.hs = hs
self._store = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
def get_user_by_req(self, req, allow_guest=False):
"""Check the access_token provided for a request
Args:
req (twisted.web.server.Request): Incoming HTTP request
allow_guest (bool): True if guest users should be allowed. If this
is False, and the access token is for a guest user, an
AuthError will be thrown
Returns:
twisted.internet.defer.Deferred[synapse.types.Requester]:
the requester for this request
Raises:
synapse.api.errors.AuthError: if no user by that token exists,
or the token is invalid.
"""
return self._auth.get_user_by_req(req, allow_guest)
def get_qualified_user_id(self, username):
"""Qualify a user id, if necessary
Takes a user id provided by the user and adds the @ and :domain to
qualify it, if necessary
Args:
username (str): provided user id
Returns:
str: qualified @user:id
"""
if username.startswith('@'):
return username
return UserID(username, self.hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
Args:
user_id (str): Complete @user:id
Returns:
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return self._auth_handler.check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)
def invalidate_access_token(self, access_token):
"""Invalidate an access token for a user
Args:
access_token(str): access token
Returns:
twisted.internet.defer.Deferred - resolves once the access token
has been removed.
Raises:
synapse.api.errors.AuthError: the access token is invalid
"""
return self._auth_handler.delete_access_token(access_token)
def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection
Args:
desc (str): description for the transaction, for metrics etc
func (func): function to be run. Passed a database cursor object
as well as *args and **kwargs
*args: positional args to be passed to func
**kwargs: named args to be passed to func
Returns:
Deferred[object]: result of func
"""
return self._store.runInteraction(desc, func, *args, **kwargs)

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
super(BaseSlavedStore, self).__init__(db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",

View file

@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs)
@defer.inlineCallbacks
@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(target_user_id)
yield self.store.user_delete_threepids(target_user_id)
yield self.store.user_set_password_hash(target_user_id, None)
yield self._auth_handler.deactivate_account(target_user_id)
defer.returnValue((200, {}))

View file

@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE})
flows.extend((
{"type": t} for t in self.auth_handler.get_supported_login_types()
))
return (200, {"flows": flows})
@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
if not self.password_enabled:
raise SynapseError(400, "Password login has been disabled.")
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE):
if self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
result = yield self._do_other_login(login_submission)
defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if "password" not in login_submission:
raise SynapseError(400, "Missing parameter: password")
def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
(int, object): HTTP code/response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
# field)
logger.info(
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
login_submission.get('identifier'),
login_submission.get('medium'),
login_submission.get('address'),
login_submission.get('user'),
)
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
user_id = identifier["user"]
if not user_id.startswith('@'):
user_id = UserID(
user_id, self.hs.hostname
).to_string()
auth_handler = self.auth_handler
user_id = yield auth_handler.validate_password_login(
user_id=user_id,
password=login_submission["password"],
canonical_user_id, callback = yield auth_handler.validate_login(
identifier["user"],
login_submission,
)
device_id = yield self._register_device(
canonical_user_id, login_submission,
)
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name"),
canonical_user_id, device_id,
)
result = {
"user_id": user_id, # may have changed
"user_id": canonical_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
if callback is not None:
yield callback(result)
defer.returnValue((200, result))
@defer.inlineCallbacks
@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet):
)
access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id,
login_submission.get("initial_device_display_name"),
)
result = {

View file

@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request):
return (200, {})
@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
access_token = get_access_token_from_request(request)
yield self.store.delete_access_token(access_token)
yield self._auth_handler.delete_access_token(access_token)
defer.returnValue((200, {}))
@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request):
return (200, {})
@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
yield self.store.user_delete_access_tokens(user_id)
yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))

View file

@ -359,7 +359,7 @@ class RegisterRestServlet(ClientV1RestServlet):
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user,
localpart=user.lower(),
password=password,
admin=bool(admin),
)

View file

@ -13,22 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request
RestServlet, assert_params_in_request,
parse_json_object_from_request,
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
@ -163,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
@ -172,6 +170,20 @@ class DeactivateAccountRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
# if the caller provides an access token, it ought to be valid.
requester = None
if has_access_token(request):
requester = yield self.auth.get_user_by_req(
request,
) # type: synapse.types.Requester
# allow ASes to dectivate their own users
if requester and requester.app_service:
yield self.auth_handler.deactivate_account(
requester.user.to_string()
)
defer.returnValue((200, {}))
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
@ -179,25 +191,22 @@ class DeactivateAccountRestServlet(RestServlet):
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
if user_id != result[LoginType.PASSWORD]:
if requester is None:
raise SynapseError(
400,
"Deactivate account requires an access_token",
errcode=Codes.MISSING_TOKEN
)
if requester.user.to_string() != user_id:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
yield self.auth_handler.deactivate_account(user_id)
defer.returnValue((200, {}))
@ -373,6 +382,20 @@ class ThreepidDeleteRestServlet(RestServlet):
defer.returnValue((200, {}))
class WhoamiRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/whoami$")
def __init__(self, hs):
super(WhoamiRestServlet, self).__init__()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
defer.returnValue((200, {'user_id': requester.user.to_string()}))
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
@ -382,3 +405,4 @@ def register_servlets(hs, http_server):
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
"""
@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
"""
@ -118,6 +117,8 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
try:
body = servlet.parse_json_object_from_request(request)
@ -136,11 +137,12 @@ class DeviceRestServlet(servlet.RestServlet):
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
# check that the UI auth matched the access token
user_id = result[constants.LoginType.PASSWORD]
if user_id != requester.user.to_string():
raise errors.AuthError(403, "Invalid auth")
yield self.device_handler.delete_device(user_id, device_id)
defer.returnValue((200, {}))
@defer.inlineCallbacks

View file

@ -39,20 +39,23 @@ class GroupServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
group_description = yield self.groups_handler.get_group_profile(
group_id,
requester_user_id,
)
defer.returnValue((200, group_description))
@defer.inlineCallbacks
def on_POST(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
yield self.groups_handler.update_group_profile(
group_id, user_id, content,
group_id, requester_user_id, content,
)
defer.returnValue((200, {}))
@ -72,9 +75,12 @@ class GroupSummaryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
get_group_summary = yield self.groups_handler.get_group_summary(
group_id,
requester_user_id,
)
defer.returnValue((200, get_group_summary))
@ -101,11 +107,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_room(
group_id, user_id,
group_id, requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
@ -116,10 +122,10 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_room(
group_id, user_id,
group_id, requester_user_id,
room_id=room_id,
category_id=category_id,
)
@ -143,10 +149,10 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category(
group_id, user_id,
group_id, requester_user_id,
category_id=category_id,
)
@ -155,11 +161,11 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_category(
group_id, user_id,
group_id, requester_user_id,
category_id=category_id,
content=content,
)
@ -169,10 +175,10 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_category(
group_id, user_id,
group_id, requester_user_id,
category_id=category_id,
)
@ -195,10 +201,10 @@ class GroupCategoriesServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories(
group_id, user_id,
group_id, requester_user_id,
)
defer.returnValue((200, category))
@ -220,10 +226,10 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role(
group_id, user_id,
group_id, requester_user_id,
role_id=role_id,
)
@ -232,11 +238,11 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_role(
group_id, user_id,
group_id, requester_user_id,
role_id=role_id,
content=content,
)
@ -246,10 +252,10 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_role(
group_id, user_id,
group_id, requester_user_id,
role_id=role_id,
)
@ -272,10 +278,10 @@ class GroupRolesServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles(
group_id, user_id,
group_id, requester_user_id,
)
defer.returnValue((200, category))
@ -343,9 +349,9 @@ class GroupRoomServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
defer.returnValue((200, result))
@ -364,9 +370,9 @@ class GroupUsersServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_users_in_group(group_id, user_id)
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
defer.returnValue((200, result))
@ -385,9 +391,12 @@ class GroupInvitedUsersServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
result = yield self.groups_handler.get_invited_users_in_group(
group_id,
requester_user_id,
)
defer.returnValue((200, result))
@ -407,14 +416,18 @@ class GroupCreateServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
# TODO: Create group on remote server
content = parse_json_object_from_request(request)
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
result = yield self.groups_handler.create_group(group_id, user_id, content)
result = yield self.groups_handler.create_group(
group_id,
requester_user_id,
content,
)
defer.returnValue((200, result))
@ -435,11 +448,11 @@ class GroupAdminRoomsServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.add_room_to_group(
group_id, user_id, room_id, content,
group_id, requester_user_id, room_id, content,
)
defer.returnValue((200, result))
@ -447,10 +460,37 @@ class GroupAdminRoomsServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, room_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.remove_room_from_group(
group_id, user_id, room_id,
group_id, requester_user_id, room_id,
)
defer.returnValue((200, result))
class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group
"""
PATTERNS = client_v2_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
def __init__(self, hs):
super(GroupAdminRoomsConfigServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_PUT(self, request, group_id, room_id, config_key):
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content,
)
defer.returnValue((200, result))
@ -685,9 +725,9 @@ class GroupsForUserServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_joined_groups(user_id)
result = yield self.groups_handler.get_joined_groups(requester_user_id)
defer.returnValue((200, result))
@ -700,6 +740,7 @@ def register_servlets(hs, http_server):
GroupRoomServlet(hs).register(http_server)
GroupCreateServlet(hs).register(http_server)
GroupAdminRoomsServlet(hs).register(http_server)
GroupAdminRoomsConfigServlet(hs).register(http_server)
GroupAdminUsersInviteServlet(hs).register(http_server)
GroupAdminUsersKickServlet(hs).register(http_server)
GroupSelfLeaveServlet(hs).register(http_server)

View file

@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
releases=())
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
PATTERNS = client_v2_patterns(
"/keys/query$",
releases=()
)
PATTERNS = client_v2_patterns("/keys/query$")
def __init__(self, hs):
"""
@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
PATTERNS = client_v2_patterns(
"/keys/changes$",
releases=()
)
PATTERNS = client_v2_patterns("/keys/changes$")
def __init__(self, hs):
"""
@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
PATTERNS = client_v2_patterns(
"/keys/claim$",
releases=()
)
PATTERNS = client_v2_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()

View file

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$", releases=())
PATTERNS = client_v2_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()

View file

@ -224,6 +224,12 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
# because the IRC bridges rely on being able to register stupid
# IDs.
access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring):
@ -233,6 +239,15 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((200, result)) # we throw for non 200 responses
return
# for either shared secret or regular registration, downcase the
# provided username before attempting to register it. This should mean
# that people who try to register with upper-case in their usernames
# don't get a nasty surprise. (Note that we treat username
# case-insenstively in login, so they are free to carry on imagining
# that their username is CrAzYh4cKeR if that keeps them happy)
if desired_username is not None:
desired_username = desired_username.lower()
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
@ -336,6 +351,9 @@ class RegisterRestServlet(RestServlet):
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
desired_username = desired_username.lower()
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
@ -417,13 +435,22 @@ class RegisterRestServlet(RestServlet):
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
if not username:
raise SynapseError(
400, "username must be specified", errcode=Codes.BAD_JSON,
)
user = username.encode("utf-8")
# use the username from the original request rather than the
# downcased one in `username` for the mac calculation
user = body["username"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(body["mac"])
# FIXME this is different to the /v1/register endpoint, which
# includes the password and admin flag in the hashed text. Why are
# these different?
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
@ -557,25 +584,28 @@ class RegisterRestServlet(RestServlet):
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
device_id, initial_device_name and inhibit_login
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
device_id = yield self._register_device(user_id, params)
access_token = (
yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
)
defer.returnValue({
result = {
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
})
}
if not params.get("inhibit_login", False):
device_id = yield self._register_device(user_id, params)
access_token = (
yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
)
result.update({
"access_token": access_token,
"device_id": device_id,
})
defer.returnValue(result)
def _register_device(self, user_id, params):
"""Register a device for a user.

View file

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
releases=[], v2_alpha=False
v2_alpha=False
)
def __init__(self, hs):

View file

@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
PATTERNS = client_v2_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
releases=())
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=())
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
releases=())
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()

View file

@ -20,6 +20,7 @@ from twisted.web.resource import Resource
from synapse.api.errors import (
SynapseError, Codes,
)
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient
@ -63,16 +64,15 @@ class PreviewUrlResource(Resource):
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
# simple memory cache mapping urls to OG metadata
self.cache = ExpiringCache(
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
self._cache = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=60 * 60 * 1000,
)
self.cache.start()
self.downloads = {}
self._cache.start()
self._cleaner_loop = self.clock.looping_call(
self._expire_url_cache_data, 10 * 1000
@ -94,6 +94,7 @@ class PreviewUrlResource(Resource):
else:
ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted.
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
@ -126,14 +127,42 @@ class PreviewUrlResource(Resource):
Codes.UNKNOWN
)
# first check the memory cache - good to handle all the clients on this
# HS thundering away to preview the same URL at the same time.
og = self.cache.get(url)
if og:
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
return
# the in-memory cache:
# * ensures that only one request is active at a time
# * takes load off the DB for the thundering herds
# * also caches any failures (unlike the DB) so we don't keep
# requesting the same endpoint
# then check the URL cache in the DB (which will also provide us with
observable = self._cache.get(url)
if not observable:
download = preserve_fn(self._do_preview)(
url, requester.user, ts,
)
observable = ObservableDeferred(
download,
consumeErrors=True
)
self._cache[url] = observable
else:
logger.info("Returning cached response")
og = yield make_deferred_yieldable(observable.observe())
respond_with_json_bytes(request, 200, og, send_cors=True)
@defer.inlineCallbacks
def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
url (str):
user (str):
ts (int):
Returns:
Deferred[str]: json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
if (
@ -141,32 +170,10 @@ class PreviewUrlResource(Resource):
cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2
):
respond_with_json_bytes(
request, 200, cache_result["og"].encode('utf-8'),
send_cors=True
)
defer.returnValue(cache_result["og"])
return
# Ensure only one download for a given URL is active at a time
download = self.downloads.get(url)
if download is None:
download = self._download_url(url, requester.user)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[url] = download
@download.addBoth
def callback(media_info):
del self.downloads[url]
return media_info
media_info = yield download.observe()
# FIXME: we should probably update our cache now anyway, so that
# even if the OG calculation raises, we don't keep hammering on the
# remote server. For now, leave it uncached to aid debugging OG
# calculation problems
media_info = yield self._download_url(url, user)
logger.debug("got media_info of '%s'" % media_info)
@ -212,7 +219,7 @@ class PreviewUrlResource(Resource):
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
_rebase_url(og['og:image'], media_info['uri']), requester.user
_rebase_url(og['og:image'], media_info['uri']), user
)
if _is_media(image_info['media_type']):
@ -239,8 +246,7 @@ class PreviewUrlResource(Resource):
logger.debug("Calculated OG for %s as %s" % (url, og))
# store OG in ephemeral in-memory cache
self.cache[url] = og
jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
@ -248,12 +254,12 @@ class PreviewUrlResource(Resource):
media_info["response_code"],
media_info["etag"],
media_info["expires"] + media_info["created_ts"],
json.dumps(og),
jsonog,
media_info["filesystem_id"],
media_info["created_ts"],
)
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
defer.returnValue(jsonog)
@defer.inlineCallbacks
def _download_url(self, url, user):
@ -520,7 +526,14 @@ def _calc_og(tree, media_uri):
from lxml import etree
TAGS_TO_REMOVE = (
"header", "nav", "aside", "footer", "script", "style", etree.Comment
"header",
"nav",
"aside",
"footer",
"script",
"noscript",
"style",
etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new

View file

@ -268,7 +268,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
super(DataStore, self).__init__(hs)
super(DataStore, self).__init__(db_conn, hs)
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup

View file

@ -162,7 +162,7 @@ class PerformanceCounters(object):
class SQLBaseStore(object):
_TXN_ID = 0
def __init__(self, hs):
def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self._db_pool = hs.get_db_pool()

View file

@ -63,7 +63,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
@cachedInlineCallbacks(num_args=2)
@cachedInlineCallbacks(num_args=2, max_entries=5000)
def get_global_account_data_by_type_for_user(self, data_type, user_id):
"""
Returns:

View file

@ -48,8 +48,8 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(ApplicationServiceStore, self).__init__(db_conn, hs)
self.hostname = hs.hostname
self.services_cache = load_appservices(
hs.hostname,
@ -173,8 +173,8 @@ class ApplicationServiceStore(SQLBaseStore):
class ApplicationServiceTransactionStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceTransactionStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
@defer.inlineCallbacks
def get_appservices_by_state(self, state):

View file

@ -80,8 +80,8 @@ class BackgroundUpdateStore(SQLBaseStore):
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs):
super(BackgroundUpdateStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(BackgroundUpdateStore, self).__init__(db_conn, hs)
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}

View file

@ -32,14 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
max_entries=50000 * CACHE_SIZE_FACTOR,
)
super(ClientIpStore, self).__init__(hs)
super(ClientIpStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"user_ips_device_index",

View file

@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, hs):
super(DeviceInboxStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_inbox_stream_index",

View file

@ -26,8 +26,8 @@ logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
def __init__(self, hs):
super(DeviceStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.

View file

@ -39,8 +39,8 @@ class EventFederationStore(SQLBaseStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, hs):
super(EventFederationStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(EventFederationStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY,

View file

@ -65,8 +65,8 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsStore(SQLBaseStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, hs):
super(EventPushActionsStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(EventPushActionsStore, self).__init__(db_conn, hs)
self.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,

View file

@ -197,8 +197,8 @@ class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
self._clock = hs.get_clock()
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts

View file

@ -35,7 +35,9 @@ class GroupServerStore(SQLBaseStore):
keyvalues={
"group_id": group_id,
},
retcols=("name", "short_description", "long_description", "avatar_url",),
retcols=(
"name", "short_description", "long_description", "avatar_url", "is_public"
),
allow_none=True,
desc="is_user_in_group",
)
@ -52,7 +54,7 @@ class GroupServerStore(SQLBaseStore):
return self._simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public",),
retcols=("user_id", "is_public", "is_admin",),
desc="get_users_in_group",
)
@ -855,6 +857,19 @@ class GroupServerStore(SQLBaseStore):
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
return self._simple_update(
table="group_rooms",
keyvalues={
"group_id": group_id,
"room_id": room_id,
},
updatevalues={
"is_public": is_public,
},
desc="update_room_in_group_visibility",
)
def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn):
self._simple_delete_txn(
@ -1026,6 +1041,7 @@ class GroupServerStore(SQLBaseStore):
"avatar_url": avatar_url,
"short_description": short_description,
"long_description": long_description,
"is_public": True,
},
desc="create_group",
)
@ -1086,6 +1102,24 @@ class GroupServerStore(SQLBaseStore):
desc="update_remote_attestion",
)
def remove_attestation_renewal(self, group_id, user_id):
"""Remove an attestation that we thought we should renew, but actually
shouldn't. Ideally this would never get called as we would never
incorrectly try and do attestations for local users on local groups.
Args:
group_id (str)
user_id (str)
"""
return self._simple_delete(
table="group_attestations_renewals",
keyvalues={
"group_id": group_id,
"user_id": user_id,
},
desc="remove_attestation_renewal",
)
@defer.inlineCallbacks
def get_remote_attestation(self, group_id, user_id):
"""Get the attestation that proves the remote agrees that the user is

View file

@ -254,6 +254,9 @@ class MediaRepositoryStore(SQLBaseStore):
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
return
sql = (
"DELETE FROM local_media_repository_url_cache"
" WHERE media_id = ?"
@ -281,6 +284,9 @@ class MediaRepositoryStore(SQLBaseStore):
)
def delete_url_cache_media(self, media_ids):
if len(media_ids) == 0:
return
def _delete_url_cache_media_txn(txn):
sql = (
"DELETE FROM local_media_repository"

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 45
SCHEMA_VERSION = 46
dir_path = os.path.abspath(os.path.dirname(__file__))
@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
Args:
db_conn:
database_engine:
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
"""
try:
cur = db_conn.cursor()
@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
else:
_setup_new_database(cur, database_engine)
# check if any of our configured dynamic modules want a database
if config is not None:
_apply_module_schemas(cur, database_engine, config)
cur.close()
db_conn.commit()
except Exception:
@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
)
def _apply_module_schemas(txn, database_engine, config):
"""Apply the module schemas for the dynamic modules, if any
Args:
cur: database cursor
database_engine: synapse database engine class
config (synapse.config.homeserver.HomeServerConfig):
application config
"""
for (mod, _config) in config.password_providers:
if not hasattr(mod, 'get_db_schema_files'):
continue
modname = ".".join((mod.__module__, mod.__name__))
_apply_module_schema_files(
txn, database_engine, modname, mod.get_db_schema_files(),
)
def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
"""Apply the module schemas for a single module
Args:
cur: database cursor
database_engine: synapse database engine class
modname (str): fully qualified name of the module
names_and_streams (Iterable[(str, file)]): the names and streams of
schemas to be applied
"""
cur.execute(
database_engine.convert_param_style(
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
),
(modname,)
)
applied_deltas = set(d for d, in cur)
for (name, stream) in names_and_streams:
if name in applied_deltas:
continue
root_name, ext = os.path.splitext(name)
if ext != '.sql':
raise PrepareDatabaseException(
"only .sql files are currently supported for module schemas",
)
logger.info("applying schema %s for %s", name, modname)
for statement in get_statements(stream):
cur.execute(statement)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
"INSERT INTO applied_module_schemas (module_name, file)"
" VALUES (?,?)",
),
(modname, name)
)
def get_statements(f):
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment

View file

@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(ReceiptsStore, self).__init__(db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()

View file

@ -24,8 +24,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
@ -36,12 +36,15 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id"],
)
self.register_background_index_update(
"refresh_tokens_device_index",
index_name="refresh_tokens_device_id",
table="refresh_tokens",
columns=["user_id", "device_id"],
)
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update("refresh_tokens_device_index")
defer.returnValue(1)
self.register_background_update_handler(
"refresh_tokens_device_index", noop_update)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
@ -177,9 +180,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
)
if create_profile_with_localpart:
# set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames
txn.execute(
"INSERT INTO profiles(user_id) VALUES (?)",
(create_profile_with_localpart,)
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
(create_profile_with_localpart, create_profile_with_localpart)
)
self._invalidate_cache_and_stream(
@ -236,12 +241,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
"user_set_password_hash", user_set_password_hash_txn
)
@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None,
delete_refresh_tokens=False):
device_id=None):
"""
Invalidate access/refresh tokens belonging to a user
Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
@ -250,10 +253,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
delete_refresh_tokens (bool): True to delete refresh tokens as
well as access tokens.
Returns:
defer.Deferred:
defer.Deferred[list[str, str|None]]: a list of the deleted tokens
and device IDs
"""
def f(txn):
keyvalues = {
@ -262,13 +264,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
if delete_refresh_tokens:
self._simple_delete_txn(
txn,
table="refresh_tokens",
keyvalues=keyvalues,
)
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
@ -277,14 +272,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values.append(except_token_id)
txn.execute(
"SELECT token FROM access_tokens WHERE %s" % where_clause,
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
values
)
rows = self.cursor_to_dict(txn)
tokens_and_devices = [(r[0], r[1]) for r in txn]
for row in rows:
for token, _ in tokens_and_devices:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
txn, self.get_user_by_access_token, (token,)
)
txn.execute(
@ -292,7 +287,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values
)
yield self.runInteraction(
return tokens_and_devices
return self.runInteraction(
"user_delete_access_tokens", f,
)

View file

@ -49,8 +49,8 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
class RoomMemberStore(SQLBaseStore):
def __init__(self, hs):
super(RoomMemberStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)

View file

@ -1,17 +0,0 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('refresh_tokens_device_index', '{}');

View file

@ -29,5 +29,5 @@ CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
-- Make sure that we popualte the table initially
-- Make sure that we populate the table initially
UPDATE user_directory_stream_pos SET stream_id = NULL;

View file

@ -1,4 +1,4 @@
/* Copyright 2016 OpenMarket Ltd
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,4 +13,5 @@
* limitations under the License.
*/
ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
/* we no longer use (or create) the refresh_tokens table */
DROP TABLE IF EXISTS refresh_tokens;

View file

@ -0,0 +1,32 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE groups_new (
group_id TEXT NOT NULL,
name TEXT, -- the display name of the room
avatar_url TEXT,
short_description TEXT,
long_description TEXT,
is_public BOOL NOT NULL -- whether non-members can access group APIs
);
-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
INSERT INTO groups_new
SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
DROP TABLE groups;
ALTER TABLE groups_new RENAME TO groups;
CREATE UNIQUE INDEX groups_idx ON groups(group_id);

View file

@ -1,4 +1,4 @@
/* Copyright 2015, 2016 OpenMarket Ltd
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,9 +13,12 @@
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS refresh_tokens(
id INTEGER PRIMARY KEY,
token TEXT NOT NULL,
user_id TEXT NOT NULL,
UNIQUE (token)
);
-- this is just embarassing :|
ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
-- so is hopefully not going to block anyone else for that long...
CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
DROP INDEX users_in_pubic_room_room_idx;
DROP INDEX users_in_pubic_room_user_idx;

View file

@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
file TEXT NOT NULL,
UNIQUE(version, file)
);
-- a list of schema files we have loaded on behalf of dynamic modules
CREATE TABLE IF NOT EXISTS applied_module_schemas(
module_name TEXT NOT NULL,
file TEXT NOT NULL,
UNIQUE(module_name, file)
);

View file

@ -33,8 +33,8 @@ class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
def __init__(self, hs):
super(SearchStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)

View file

@ -63,8 +63,8 @@ class StateStore(SQLBaseStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, hs):
super(StateStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,

View file

@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
def __init__(self, hs):
super(TransactionStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(TransactionStore, self).__init__(db_conn, hs)
self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)

View file

@ -63,7 +63,7 @@ class UserDirectoryStore(SQLBaseStore):
user_ids (list(str)): Users to add
"""
yield self._simple_insert_many(
table="users_in_pubic_room",
table="users_in_public_rooms",
values=[
{
"user_id": user_id,
@ -219,7 +219,7 @@ class UserDirectoryStore(SQLBaseStore):
@defer.inlineCallbacks
def update_user_in_public_user_list(self, user_id, room_id):
yield self._simple_update_one(
table="users_in_pubic_room",
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
updatevalues={"room_id": room_id},
desc="update_user_in_public_user_list",
@ -240,7 +240,7 @@ class UserDirectoryStore(SQLBaseStore):
)
self._simple_delete_txn(
txn,
table="users_in_pubic_room",
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
)
txn.call_after(
@ -256,7 +256,7 @@ class UserDirectoryStore(SQLBaseStore):
@defer.inlineCallbacks
def remove_from_user_in_public_room(self, user_id):
yield self._simple_delete(
table="users_in_pubic_room",
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
desc="remove_from_user_in_public_room",
)
@ -267,7 +267,7 @@ class UserDirectoryStore(SQLBaseStore):
in the given room_id
"""
return self._simple_select_onecol(
table="users_in_pubic_room",
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_public_due_to_room",
@ -286,7 +286,7 @@ class UserDirectoryStore(SQLBaseStore):
)
user_ids_pub = yield self._simple_select_onecol(
table="users_in_pubic_room",
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
@ -514,7 +514,7 @@ class UserDirectoryStore(SQLBaseStore):
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
txn.execute("DELETE FROM users_in_pubic_room")
txn.execute("DELETE FROM users_in_public_rooms")
txn.execute("DELETE FROM users_who_share_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
txn.call_after(self.get_user_in_public_room.invalidate_all)
@ -537,7 +537,7 @@ class UserDirectoryStore(SQLBaseStore):
@cached()
def get_user_in_public_room(self, user_id):
return self._simple_select_one(
table="users_in_pubic_room",
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcols=("room_id",),
allow_none=True,
@ -641,7 +641,7 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_pubic_room AS p USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
@ -680,7 +680,7 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_pubic_room AS p USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private

View file

@ -278,8 +278,13 @@ class Limiter(object):
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1].append(new_defer)
logger.info("Waiting to acquire limiter lock for key %r", key)
with PreserveLoggingContext():
yield new_defer
logger.info("Acquired limiter lock for key %r", key)
else:
logger.info("Acquired uncontended limiter lock for key %r", key)
entry[0] += 1
@ -288,16 +293,21 @@ class Limiter(object):
try:
yield
finally:
logger.info("Releasing limiter lock for key %r", key)
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
entry[0] -= 1
try:
entry[1].pop(0).callback(None)
except IndexError:
# If nothing else is executing for this key then remove it
# from the map
if entry[0] == 0:
self.key_to_defer.pop(key, None)
if entry[1]:
next_def = entry[1].pop(0)
with PreserveLoggingContext():
next_def.callback(None)
elif entry[0] == 0:
# We were the last thing for this key: remove it from the
# map.
del self.key_to_defer[key]
defer.returnValue(_ctx_manager())

View file

@ -53,7 +53,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
type="m.room.message",
room_id="!foo:bar"
)
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]),
(0, [])
]
self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(0)
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
@ -75,7 +78,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]),
(0, [])
]
yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id
@ -98,7 +104,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
self.mock_store.get_new_events_for_appservice.return_value = (0, [event])
self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]),
(0, [])
]
yield self.handler.notify_interested_services(0)
self.assertFalse(
self.mock_as_api.query_user.called,

View file

@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
self.store = ApplicationServiceStore(hs)
self.store = ApplicationServiceStore(None, hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@ -150,7 +150,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
self.store = TestTransactionStore(hs)
self.store = TestTransactionStore(None, hs)
def _add_service(self, url, as_token, id):
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
@ -420,8 +420,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
class TestTransactionStore(ApplicationServiceTransactionStore,
ApplicationServiceStore):
def __init__(self, hs):
super(TestTransactionStore, self).__init__(hs)
def __init__(self, db_conn, hs):
super(TestTransactionStore, self).__init__(db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@ -458,7 +458,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
replication_layer=Mock(),
)
ApplicationServiceStore(hs)
ApplicationServiceStore(None, hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
@ -477,7 +477,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
ApplicationServiceStore(None, hs)
e = cm.exception
self.assertIn(f1, e.message)
@ -501,7 +501,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
)
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(hs)
ApplicationServiceStore(None, hs)
e = cm.exception
self.assertIn(f1, e.message)

View file

@ -56,7 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
database_engine=create_engine(config.database_config),
)
self.datastore = SQLBaseStore(hs)
self.datastore = SQLBaseStore(None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):

View file

@ -29,7 +29,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver()
self.store = DirectoryStore(hs)
self.store = DirectoryStore(None, hs)
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")

View file

@ -29,7 +29,7 @@ class PresenceStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(clock=MockClock())
self.store = PresenceStore(hs)
self.store = PresenceStore(None, hs)
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")

View file

@ -29,7 +29,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver()
self.store = ProfileStore(hs)
self.store = ProfileStore(None, hs)
self.u_frank = UserID.from_string("@frank:test")

View file

@ -86,7 +86,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
# now delete some
yield self.store.user_delete_access_tokens(
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
self.user_id, device_id=self.device_id,
)
# check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1])
@ -97,8 +98,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertEqual(self.user_id, user["name"])
# now delete the rest
yield self.store.user_delete_access_tokens(
self.user_id, delete_refresh_tokens=True)
yield self.store.user_delete_access_tokens(self.user_id)
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user,

View file

@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
)
self.config = Mock()
self.config.password_providers = []
self.config.database_config = {"name": "sqlite3"}
def prepare(self):