forked from MirrorHub/synapse
Merge branch 'develop' into matthew/sync_deleted_devices
This commit is contained in:
commit
9b34f3ea3a
103 changed files with 5421 additions and 4713 deletions
|
@ -23,6 +23,9 @@ matrix:
|
|||
- python: 3.6
|
||||
env: TOX_ENV=py36
|
||||
|
||||
- python: 3.6
|
||||
env: TOX_ENV=check_isort
|
||||
|
||||
- python: 3.6
|
||||
env: TOX_ENV=check-newsfragment
|
||||
|
||||
|
|
2470
CHANGES.md
Normal file
2470
CHANGES.md
Normal file
File diff suppressed because it is too large
Load diff
2839
CHANGES.rst
2839
CHANGES.rst
File diff suppressed because it is too large
Load diff
|
@ -2,6 +2,7 @@ include synctl
|
|||
include LICENSE
|
||||
include VERSION
|
||||
include *.rst
|
||||
include *.md
|
||||
include demo/README
|
||||
include demo/demo.tls.dh
|
||||
include demo/*.py
|
||||
|
|
16
README.rst
16
README.rst
|
@ -71,7 +71,7 @@ We'd like to invite you to join #matrix:matrix.org (via
|
|||
https://matrix.org/docs/projects/try-matrix-now.html), run a homeserver, take a look
|
||||
at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the
|
||||
`APIs <https://matrix.org/docs/api>`_ and `Client SDKs
|
||||
<http://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_.
|
||||
<https://matrix.org/docs/projects/try-matrix-now.html#client-sdks>`_.
|
||||
|
||||
Thanks for using Matrix!
|
||||
|
||||
|
@ -283,7 +283,7 @@ Connecting to Synapse from a client
|
|||
|
||||
The easiest way to try out your new Synapse installation is by connecting to it
|
||||
from a web client. The easiest option is probably the one at
|
||||
http://riot.im/app. You will need to specify a "Custom server" when you log on
|
||||
https://riot.im/app. You will need to specify a "Custom server" when you log on
|
||||
or register: set this to ``https://domain.tld`` if you setup a reverse proxy
|
||||
following the recommended setup, or ``https://localhost:8448`` - remember to specify the
|
||||
port (``:8448``) if not ``:443`` unless you changed the configuration. (Leave the identity
|
||||
|
@ -329,7 +329,7 @@ Security Note
|
|||
=============
|
||||
|
||||
Matrix serves raw user generated data in some APIs - specifically the `content
|
||||
repository endpoints <http://matrix.org/docs/spec/client_server/latest.html#get-matrix-media-r0-download-servername-mediaid>`_.
|
||||
repository endpoints <https://matrix.org/docs/spec/client_server/latest.html#get-matrix-media-r0-download-servername-mediaid>`_.
|
||||
|
||||
Whilst we have tried to mitigate against possible XSS attacks (e.g.
|
||||
https://github.com/matrix-org/synapse/pull/1021) we recommend running
|
||||
|
@ -348,7 +348,7 @@ Platform-Specific Instructions
|
|||
Debian
|
||||
------
|
||||
|
||||
Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/.
|
||||
Matrix provides official Debian packages via apt from https://matrix.org/packages/debian/.
|
||||
Note that these packages do not include a client - choose one from
|
||||
https://matrix.org/docs/projects/try-matrix-now.html (or build your own with one of our SDKs :)
|
||||
|
||||
|
@ -524,7 +524,7 @@ Troubleshooting Running
|
|||
-----------------------
|
||||
|
||||
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
|
||||
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
|
||||
to manually upgrade PyNaCL, as synapse uses NaCl (https://nacl.cr.yp.to/) for
|
||||
encryption and digital signatures.
|
||||
Unfortunately PyNACL currently has a few issues
|
||||
(https://github.com/pyca/pynacl/issues/53) and
|
||||
|
@ -672,8 +672,8 @@ useful just for development purposes. See `<demo/README>`_.
|
|||
Using PostgreSQL
|
||||
================
|
||||
|
||||
As of Synapse 0.9, `PostgreSQL <http://www.postgresql.org>`_ is supported as an
|
||||
alternative to the `SQLite <http://sqlite.org/>`_ database that Synapse has
|
||||
As of Synapse 0.9, `PostgreSQL <https://www.postgresql.org>`_ is supported as an
|
||||
alternative to the `SQLite <https://sqlite.org/>`_ database that Synapse has
|
||||
traditionally used for convenience and simplicity.
|
||||
|
||||
The advantages of Postgres include:
|
||||
|
@ -697,7 +697,7 @@ Using a reverse proxy with Synapse
|
|||
It is recommended to put a reverse proxy such as
|
||||
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
|
||||
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or
|
||||
`HAProxy <http://www.haproxy.org/>`_ in front of Synapse. One advantage of
|
||||
`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
|
||||
doing so is that it means that you can expose the default https port (443) to
|
||||
Matrix clients without needing to run Synapse with root privileges.
|
||||
|
||||
|
|
1
changelog.d/3367.misc
Normal file
1
changelog.d/3367.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Remove unnecessary event re-signing hacks
|
|
@ -1 +0,0 @@
|
|||
Include CPU time from database threads in request/block metrics.
|
|
@ -1 +0,0 @@
|
|||
Add CPU metrics for _fetch_event_list
|
1
changelog.d/3514.bugfix
Normal file
1
changelog.d/3514.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Don't generate TURN credentials if no TURN config options are set
|
1
changelog.d/3548.bugfix
Normal file
1
changelog.d/3548.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Catch failures saving metrics captured by Measure, and instead log the faulty metrics information for further analysis.
|
1
changelog.d/3552.misc
Normal file
1
changelog.d/3552.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Release notes are now in the Markdown format.
|
1
changelog.d/3553.feature
Normal file
1
changelog.d/3553.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add metrics to track resource usage by background processes
|
1
changelog.d/3554.feature
Normal file
1
changelog.d/3554.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add `code` label to `synapse_http_server_response_time_seconds` prometheus metric
|
1
changelog.d/3556.feature
Normal file
1
changelog.d/3556.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add metrics to track resource usage by background processes
|
1
changelog.d/3559.misc
Normal file
1
changelog.d/3559.misc
Normal file
|
@ -0,0 +1 @@
|
|||
add config for pep8
|
1
changelog.d/3570.bugfix
Normal file
1
changelog.d/3570.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Fix potential stack overflow and deadlock under heavy load
|
1
changelog.d/3571.misc
Normal file
1
changelog.d/3571.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Merge Linearizer and Limiter
|
1
changelog.d/3572.misc
Normal file
1
changelog.d/3572.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Merge Linearizer and Limiter
|
63
docs/admin_api/register_api.rst
Normal file
63
docs/admin_api/register_api.rst
Normal file
|
@ -0,0 +1,63 @@
|
|||
Shared-Secret Registration
|
||||
==========================
|
||||
|
||||
This API allows for the creation of users in an administrative and
|
||||
non-interactive way. This is generally used for bootstrapping a Synapse
|
||||
instance with administrator accounts.
|
||||
|
||||
To authenticate yourself to the server, you will need both the shared secret
|
||||
(``registration_shared_secret`` in the homeserver configuration), and a
|
||||
one-time nonce. If the registration shared secret is not configured, this API
|
||||
is not enabled.
|
||||
|
||||
To fetch the nonce, you need to request one from the API::
|
||||
|
||||
> GET /_matrix/client/r0/admin/register
|
||||
|
||||
< {"nonce": "thisisanonce"}
|
||||
|
||||
Once you have the nonce, you can make a ``POST`` to the same URL with a JSON
|
||||
body containing the nonce, username, password, whether they are an admin
|
||||
(optional, False by default), and a HMAC digest of the content.
|
||||
|
||||
As an example::
|
||||
|
||||
> POST /_matrix/client/r0/admin/register
|
||||
> {
|
||||
"nonce": "thisisanonce",
|
||||
"username": "pepper_roni",
|
||||
"password": "pizza",
|
||||
"admin": true,
|
||||
"mac": "mac_digest_here"
|
||||
}
|
||||
|
||||
< {
|
||||
"access_token": "token_here",
|
||||
"user_id": "@pepper_roni@test",
|
||||
"home_server": "test",
|
||||
"device_id": "device_id_here"
|
||||
}
|
||||
|
||||
The MAC is the hex digest output of the HMAC-SHA1 algorithm, with the key being
|
||||
the shared secret and the content being the nonce, user, password, and either
|
||||
the string "admin" or "notadmin", each separated by NULs. For an example of
|
||||
generation in Python::
|
||||
|
||||
import hmac, hashlib
|
||||
|
||||
def generate_mac(nonce, user, password, admin=False):
|
||||
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
digestmod=hashlib.sha1,
|
||||
)
|
||||
|
||||
mac.update(nonce.encode('utf8'))
|
||||
mac.update(b"\x00")
|
||||
mac.update(user.encode('utf8'))
|
||||
mac.update(b"\x00")
|
||||
mac.update(password.encode('utf8'))
|
||||
mac.update(b"\x00")
|
||||
mac.update(b"admin" if admin else b"notadmin")
|
||||
|
||||
return mac.hexdigest()
|
|
@ -1,5 +1,30 @@
|
|||
[tool.towncrier]
|
||||
package = "synapse"
|
||||
filename = "CHANGES.rst"
|
||||
filename = "CHANGES.md"
|
||||
directory = "changelog.d"
|
||||
issue_format = "`#{issue} <https://github.com/matrix-org/synapse/issues/{issue}>`_"
|
||||
issue_format = "[\\#{issue}](https://github.com/matrix-org/synapse/issues/{issue}>)"
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "feature"
|
||||
name = "Features"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "bugfix"
|
||||
name = "Bugfixes"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "doc"
|
||||
name = "Improved Documentation"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "removal"
|
||||
name = "Deprecations and Removals"
|
||||
showcontent = true
|
||||
|
||||
[[tool.towncrier.type]]
|
||||
directory = "misc"
|
||||
name = "Internal Changes"
|
||||
showcontent = true
|
||||
|
|
|
@ -26,11 +26,37 @@ import yaml
|
|||
|
||||
|
||||
def request_registration(user, password, server_location, shared_secret, admin=False):
|
||||
req = urllib2.Request(
|
||||
"%s/_matrix/client/r0/admin/register" % (server_location,),
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
|
||||
try:
|
||||
if sys.version_info[:3] >= (2, 7, 9):
|
||||
# As of version 2.7.9, urllib2 now checks SSL certs
|
||||
import ssl
|
||||
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
|
||||
else:
|
||||
f = urllib2.urlopen(req)
|
||||
body = f.read()
|
||||
f.close()
|
||||
nonce = json.loads(body)["nonce"]
|
||||
except urllib2.HTTPError as e:
|
||||
print "ERROR! Received %d %s" % (e.code, e.reason,)
|
||||
if 400 <= e.code < 500:
|
||||
if e.info().type == "application/json":
|
||||
resp = json.load(e)
|
||||
if "error" in resp:
|
||||
print resp["error"]
|
||||
sys.exit(1)
|
||||
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
digestmod=hashlib.sha1,
|
||||
)
|
||||
|
||||
mac.update(nonce)
|
||||
mac.update("\x00")
|
||||
mac.update(user)
|
||||
mac.update("\x00")
|
||||
mac.update(password)
|
||||
|
@ -40,10 +66,10 @@ def request_registration(user, password, server_location, shared_secret, admin=F
|
|||
mac = mac.hexdigest()
|
||||
|
||||
data = {
|
||||
"user": user,
|
||||
"nonce": nonce,
|
||||
"username": user,
|
||||
"password": password,
|
||||
"mac": mac,
|
||||
"type": "org.matrix.login.shared_secret",
|
||||
"admin": admin,
|
||||
}
|
||||
|
||||
|
@ -52,7 +78,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F
|
|||
print "Sending registration request..."
|
||||
|
||||
req = urllib2.Request(
|
||||
"%s/_matrix/client/api/v1/register" % (server_location,),
|
||||
"%s/_matrix/client/r0/admin/register" % (server_location,),
|
||||
data=json.dumps(data),
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
|
|
12
setup.cfg
12
setup.cfg
|
@ -14,12 +14,17 @@ ignore =
|
|||
pylint.cfg
|
||||
tox.ini
|
||||
|
||||
[flake8]
|
||||
[pep8]
|
||||
max-line-length = 90
|
||||
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||
# E203 is contrary to PEP8.
|
||||
# W503 requires that binary operators be at the end, not start, of lines. Erik
|
||||
# doesn't like it. E203 is contrary to PEP8.
|
||||
ignore = W503,E203
|
||||
|
||||
[flake8]
|
||||
# note that flake8 inherits the "ignore" settings from "pep8" (because it uses
|
||||
# pep8 to do those checks), but not the "max-line-length" setting
|
||||
max-line-length = 90
|
||||
|
||||
[isort]
|
||||
line_length = 89
|
||||
not_skip = __init__.py
|
||||
|
@ -31,3 +36,4 @@ known_compat = mock,six
|
|||
known_twisted=twisted,OpenSSL
|
||||
multi_line_output=3
|
||||
include_trailing_comma=true
|
||||
combine_as_imports=true
|
||||
|
|
|
@ -17,4 +17,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.32.2"
|
||||
__version__ = "0.33.0"
|
||||
|
|
|
@ -193,7 +193,7 @@ class Auth(object):
|
|||
synapse.types.create_requester(user_id, app_service=app_service)
|
||||
)
|
||||
|
||||
access_token = get_access_token_from_request(
|
||||
access_token = self.get_access_token_from_request(
|
||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||
)
|
||||
|
||||
|
@ -239,7 +239,7 @@ class Auth(object):
|
|||
@defer.inlineCallbacks
|
||||
def _get_appservice_user_id(self, request):
|
||||
app_service = self.store.get_app_service_by_token(
|
||||
get_access_token_from_request(
|
||||
self.get_access_token_from_request(
|
||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||
)
|
||||
)
|
||||
|
@ -513,7 +513,7 @@ class Auth(object):
|
|||
|
||||
def get_appservice_by_req(self, request):
|
||||
try:
|
||||
token = get_access_token_from_request(
|
||||
token = self.get_access_token_from_request(
|
||||
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
|
||||
)
|
||||
service = self.store.get_app_service_by_token(token)
|
||||
|
@ -673,67 +673,67 @@ class Auth(object):
|
|||
" edit its room list entry"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def has_access_token(request):
|
||||
"""Checks if the request has an access_token.
|
||||
|
||||
def has_access_token(request):
|
||||
"""Checks if the request has an access_token.
|
||||
Returns:
|
||||
bool: False if no access_token was given, True otherwise.
|
||||
"""
|
||||
query_params = request.args.get("access_token")
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
return bool(query_params) or bool(auth_headers)
|
||||
|
||||
Returns:
|
||||
bool: False if no access_token was given, True otherwise.
|
||||
"""
|
||||
query_params = request.args.get("access_token")
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
return bool(query_params) or bool(auth_headers)
|
||||
@staticmethod
|
||||
def get_access_token_from_request(request, token_not_found_http_status=401):
|
||||
"""Extracts the access_token from the request.
|
||||
|
||||
Args:
|
||||
request: The http request.
|
||||
token_not_found_http_status(int): The HTTP status code to set in the
|
||||
AuthError if the token isn't found. This is used in some of the
|
||||
legacy APIs to change the status code to 403 from the default of
|
||||
401 since some of the old clients depended on auth errors returning
|
||||
403.
|
||||
Returns:
|
||||
str: The access_token
|
||||
Raises:
|
||||
AuthError: If there isn't an access_token in the request.
|
||||
"""
|
||||
|
||||
def get_access_token_from_request(request, token_not_found_http_status=401):
|
||||
"""Extracts the access_token from the request.
|
||||
|
||||
Args:
|
||||
request: The http request.
|
||||
token_not_found_http_status(int): The HTTP status code to set in the
|
||||
AuthError if the token isn't found. This is used in some of the
|
||||
legacy APIs to change the status code to 403 from the default of
|
||||
401 since some of the old clients depended on auth errors returning
|
||||
403.
|
||||
Returns:
|
||||
str: The access_token
|
||||
Raises:
|
||||
AuthError: If there isn't an access_token in the request.
|
||||
"""
|
||||
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
query_params = request.args.get(b"access_token")
|
||||
if auth_headers:
|
||||
# Try the get the access_token from a "Authorization: Bearer"
|
||||
# header
|
||||
if query_params is not None:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Mixing Authorization headers and access_token query parameters.",
|
||||
errcode=Codes.MISSING_TOKEN,
|
||||
)
|
||||
if len(auth_headers) > 1:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Too many Authorization headers.",
|
||||
errcode=Codes.MISSING_TOKEN,
|
||||
)
|
||||
parts = auth_headers[0].split(" ")
|
||||
if parts[0] == "Bearer" and len(parts) == 2:
|
||||
return parts[1]
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
query_params = request.args.get(b"access_token")
|
||||
if auth_headers:
|
||||
# Try the get the access_token from a "Authorization: Bearer"
|
||||
# header
|
||||
if query_params is not None:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Mixing Authorization headers and access_token query parameters.",
|
||||
errcode=Codes.MISSING_TOKEN,
|
||||
)
|
||||
if len(auth_headers) > 1:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Too many Authorization headers.",
|
||||
errcode=Codes.MISSING_TOKEN,
|
||||
)
|
||||
parts = auth_headers[0].split(" ")
|
||||
if parts[0] == "Bearer" and len(parts) == 2:
|
||||
return parts[1]
|
||||
else:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Invalid Authorization header.",
|
||||
errcode=Codes.MISSING_TOKEN,
|
||||
)
|
||||
else:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Invalid Authorization header.",
|
||||
errcode=Codes.MISSING_TOKEN,
|
||||
)
|
||||
else:
|
||||
# Try to get the access_token from the query params.
|
||||
if not query_params:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Missing access token.",
|
||||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
# Try to get the access_token from the query params.
|
||||
if not query_params:
|
||||
raise AuthError(
|
||||
token_not_found_http_status,
|
||||
"Missing access token.",
|
||||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
|
||||
return query_params[0]
|
||||
return query_params[0]
|
||||
|
|
|
@ -18,6 +18,8 @@ import logging
|
|||
import os
|
||||
import sys
|
||||
|
||||
from six import iteritems
|
||||
|
||||
from twisted.application import service
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||
|
@ -442,7 +444,7 @@ def run(hs):
|
|||
stats["total_nonbridged_users"] = total_nonbridged_users
|
||||
|
||||
daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
|
||||
for name, count in daily_user_type_results.iteritems():
|
||||
for name, count in iteritems(daily_user_type_results):
|
||||
stats["daily_user_type_" + name] = count
|
||||
|
||||
room_count = yield hs.get_datastore().get_room_count()
|
||||
|
@ -453,7 +455,7 @@ def run(hs):
|
|||
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
|
||||
|
||||
r30_results = yield hs.get_datastore().count_r30_users()
|
||||
for name, count in r30_results.iteritems():
|
||||
for name, count in iteritems(r30_results):
|
||||
stats["r30_users_" + name] = count
|
||||
|
||||
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
|
||||
|
|
|
@ -25,6 +25,8 @@ import subprocess
|
|||
import sys
|
||||
import time
|
||||
|
||||
from six import iteritems
|
||||
|
||||
import yaml
|
||||
|
||||
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
||||
|
@ -173,7 +175,7 @@ def main():
|
|||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
|
||||
cache_factors = config.get("synctl_cache_factors", {})
|
||||
for cache_name, factor in cache_factors.iteritems():
|
||||
for cache_name, factor in iteritems(cache_factors):
|
||||
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
|
||||
|
||||
worker_configfiles = []
|
||||
|
|
|
@ -30,10 +30,10 @@ class VoipConfig(Config):
|
|||
## Turn ##
|
||||
|
||||
# The public URIs of the TURN server to give to clients
|
||||
turn_uris: []
|
||||
#turn_uris: []
|
||||
|
||||
# The shared secret used to compute passwords for the TURN server
|
||||
turn_shared_secret: "YOUR_SHARED_SECRET"
|
||||
#turn_shared_secret: "YOUR_SHARED_SECRET"
|
||||
|
||||
# The Username and password if the TURN server needs them and
|
||||
# does not use a token
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from six import iteritems
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
@ -159,7 +161,7 @@ def _encode_state_dict(state_dict):
|
|||
|
||||
return [
|
||||
(etype, state_key, v)
|
||||
for (etype, state_key), v in state_dict.iteritems()
|
||||
for (etype, state_key), v in iteritems(state_dict)
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from synapse.api.errors import Codes, SynapseError
|
|||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.http.servlet import assert_params_in_request
|
||||
from synapse.http.servlet import assert_params_in_dict
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -199,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
|
|||
"""
|
||||
# we could probably enforce a bunch of other fields here (room_id, sender,
|
||||
# origin, etc etc)
|
||||
assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
|
||||
assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
|
||||
|
||||
depth = pdu_json['depth']
|
||||
if not isinstance(depth, six.integer_types):
|
||||
|
|
|
@ -30,7 +30,8 @@ from synapse.metrics import (
|
|||
sent_edus_counter,
|
||||
sent_transactions_counter,
|
||||
)
|
||||
from synapse.util import PreserveLoggingContext, logcontext
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
|
||||
|
@ -165,10 +166,11 @@ class TransactionQueue(object):
|
|||
if self._is_processing:
|
||||
return
|
||||
|
||||
# fire off a processing loop in the background. It's likely it will
|
||||
# outlast the current request, so run it in the sentinel logcontext.
|
||||
with PreserveLoggingContext():
|
||||
self._process_event_queue_loop()
|
||||
# fire off a processing loop in the background
|
||||
run_as_background_process(
|
||||
"process_event_queue_for_federation",
|
||||
self._process_event_queue_loop,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_event_queue_loop(self):
|
||||
|
@ -432,14 +434,11 @@ class TransactionQueue(object):
|
|||
|
||||
logger.debug("TX [%s] Starting transaction loop", destination)
|
||||
|
||||
# Drop the logcontext before starting the transaction. It doesn't
|
||||
# really make sense to log all the outbound transactions against
|
||||
# whatever path led us to this point: that's pretty arbitrary really.
|
||||
#
|
||||
# (this also means we can fire off _perform_transaction without
|
||||
# yielding)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
self._transaction_transmission_loop(destination)
|
||||
run_as_background_process(
|
||||
"federation_transaction_transmission_loop",
|
||||
self._transaction_transmission_loop,
|
||||
destination,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _transaction_transmission_loop(self, destination):
|
||||
|
|
|
@ -21,8 +21,8 @@ import logging
|
|||
import sys
|
||||
|
||||
import six
|
||||
from six import iteritems
|
||||
from six.moves import http_client
|
||||
from six import iteritems, itervalues
|
||||
from six.moves import http_client, zip
|
||||
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
|
@ -43,7 +43,6 @@ from synapse.crypto.event_signing import (
|
|||
add_hashes_and_signatures,
|
||||
compute_event_signature,
|
||||
)
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.state import resolve_events_with_factory
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
|
@ -52,8 +51,8 @@ from synapse.util.async import Linearizer
|
|||
from synapse.util.distributor import user_joined_room
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
from synapse.visibility import filter_events_for_server
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -501,137 +500,6 @@ class FederationHandler(BaseHandler):
|
|||
user = UserID.from_string(event.state_key)
|
||||
yield user_joined_room(self.distributor, user, event.room_id)
|
||||
|
||||
@measure_func("_filter_events_for_server")
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_server(self, server_name, room_id, events):
|
||||
"""Filter the given events for the given server, redacting those the
|
||||
server can't see.
|
||||
|
||||
Assumes the server is currently in the room.
|
||||
|
||||
Returns
|
||||
list[FrozenEvent]
|
||||
"""
|
||||
# First lets check to see if all the events have a history visibility
|
||||
# of "shared" or "world_readable". If thats the case then we don't
|
||||
# need to check membership (as we know the server is in the room).
|
||||
event_to_state_ids = yield self.store.get_state_ids_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
)
|
||||
)
|
||||
|
||||
visibility_ids = set()
|
||||
for sids in event_to_state_ids.itervalues():
|
||||
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
|
||||
if hist:
|
||||
visibility_ids.add(hist)
|
||||
|
||||
# If we failed to find any history visibility events then the default
|
||||
# is "shared" visiblity.
|
||||
if not visibility_ids:
|
||||
defer.returnValue(events)
|
||||
|
||||
event_map = yield self.store.get_events(visibility_ids)
|
||||
all_open = all(
|
||||
e.content.get("history_visibility") in (None, "shared", "world_readable")
|
||||
for e in event_map.itervalues()
|
||||
)
|
||||
|
||||
if all_open:
|
||||
defer.returnValue(events)
|
||||
|
||||
# Ok, so we're dealing with events that have non-trivial visibility
|
||||
# rules, so we need to also get the memberships of the room.
|
||||
|
||||
event_to_state_ids = yield self.store.get_state_ids_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.Member, None),
|
||||
)
|
||||
)
|
||||
|
||||
# We only want to pull out member events that correspond to the
|
||||
# server's domain.
|
||||
|
||||
def check_match(id):
|
||||
try:
|
||||
return server_name == get_domain_from_id(id)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Parses mapping `event_id -> (type, state_key) -> state event_id`
|
||||
# to get all state ids that we're interested in.
|
||||
event_map = yield self.store.get_events([
|
||||
e_id
|
||||
for key_to_eid in list(event_to_state_ids.values())
|
||||
for key, e_id in key_to_eid.items()
|
||||
if key[0] != EventTypes.Member or check_match(key[1])
|
||||
])
|
||||
|
||||
event_to_state = {
|
||||
e_id: {
|
||||
key: event_map[inner_e_id]
|
||||
for key, inner_e_id in key_to_eid.iteritems()
|
||||
if inner_e_id in event_map
|
||||
}
|
||||
for e_id, key_to_eid in event_to_state_ids.iteritems()
|
||||
}
|
||||
|
||||
erased_senders = yield self.store.are_users_erased(
|
||||
e.sender for e in events,
|
||||
)
|
||||
|
||||
def redact_disallowed(event, state):
|
||||
# if the sender has been gdpr17ed, always return a redacted
|
||||
# copy of the event.
|
||||
if erased_senders[event.sender]:
|
||||
logger.info(
|
||||
"Sender of %s has been erased, redacting",
|
||||
event.event_id,
|
||||
)
|
||||
return prune_event(event)
|
||||
|
||||
if not state:
|
||||
return event
|
||||
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history:
|
||||
visibility = history.content.get("history_visibility", "shared")
|
||||
if visibility in ["invited", "joined"]:
|
||||
# We now loop through all state events looking for
|
||||
# membership states for the requesting server to determine
|
||||
# if the server is either in the room or has been invited
|
||||
# into the room.
|
||||
for ev in state.itervalues():
|
||||
if ev.type != EventTypes.Member:
|
||||
continue
|
||||
try:
|
||||
domain = get_domain_from_id(ev.state_key)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if domain != server_name:
|
||||
continue
|
||||
|
||||
memtype = ev.membership
|
||||
if memtype == Membership.JOIN:
|
||||
return event
|
||||
elif memtype == Membership.INVITE:
|
||||
if visibility == "invited":
|
||||
return event
|
||||
else:
|
||||
return prune_event(event)
|
||||
|
||||
return event
|
||||
|
||||
defer.returnValue([
|
||||
redact_disallowed(e, event_to_state[e.event_id])
|
||||
for e in events
|
||||
])
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, dest, room_id, limit, extremities):
|
||||
|
@ -863,7 +731,7 @@ class FederationHandler(BaseHandler):
|
|||
"""
|
||||
joined_users = [
|
||||
(state_key, int(event.depth))
|
||||
for (e_type, state_key), event in state.iteritems()
|
||||
for (e_type, state_key), event in iteritems(state)
|
||||
if e_type == EventTypes.Member
|
||||
and event.membership == Membership.JOIN
|
||||
]
|
||||
|
@ -880,7 +748,7 @@ class FederationHandler(BaseHandler):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
return sorted(joined_domains.iteritems(), key=lambda d: d[1])
|
||||
return sorted(joined_domains.items(), key=lambda d: d[1])
|
||||
|
||||
curr_domains = get_domains_from_state(curr_state)
|
||||
|
||||
|
@ -943,7 +811,7 @@ class FederationHandler(BaseHandler):
|
|||
tried_domains = set(likely_domains)
|
||||
tried_domains.add(self.server_name)
|
||||
|
||||
event_ids = list(extremities.iterkeys())
|
||||
event_ids = list(extremities.keys())
|
||||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
resolve = logcontext.preserve_fn(
|
||||
|
@ -959,15 +827,15 @@ class FederationHandler(BaseHandler):
|
|||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
state_map = yield self.store.get_events(
|
||||
[e_id for ids in states.itervalues() for e_id in ids.itervalues()],
|
||||
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
|
||||
get_prev_content=False
|
||||
)
|
||||
states = {
|
||||
key: {
|
||||
k: state_map[e_id]
|
||||
for k, e_id in state_dict.iteritems()
|
||||
for k, e_id in iteritems(state_dict)
|
||||
if e_id in state_map
|
||||
} for key, state_dict in states.iteritems()
|
||||
} for key, state_dict in iteritems(states)
|
||||
}
|
||||
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
|
@ -1038,16 +906,6 @@ class FederationHandler(BaseHandler):
|
|||
[auth_id for auth_id, _ in event.auth_events],
|
||||
include_given=True
|
||||
)
|
||||
|
||||
for event in auth:
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue([e for e in auth])
|
||||
|
||||
@log_function
|
||||
|
@ -1503,18 +1361,6 @@ class FederationHandler(BaseHandler):
|
|||
del results[(event.type, event.state_key)]
|
||||
|
||||
res = list(results.values())
|
||||
for event in res:
|
||||
# We sign these again because there was a bug where we
|
||||
# incorrectly signed things the first time round
|
||||
if self.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
else:
|
||||
defer.returnValue([])
|
||||
|
@ -1558,7 +1404,7 @@ class FederationHandler(BaseHandler):
|
|||
limit
|
||||
)
|
||||
|
||||
events = yield self._filter_events_for_server(origin, room_id, events)
|
||||
events = yield filter_events_for_server(self.store, origin, events)
|
||||
|
||||
defer.returnValue(events)
|
||||
|
||||
|
@ -1586,18 +1432,6 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if event:
|
||||
if self.is_mine_id(event.event_id):
|
||||
# FIXME: This is a temporary work around where we occasionally
|
||||
# return events slightly differently than when they were
|
||||
# originally signed
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
|
||||
in_room = yield self.auth.check_host_in_room(
|
||||
event.room_id,
|
||||
origin
|
||||
|
@ -1605,8 +1439,8 @@ class FederationHandler(BaseHandler):
|
|||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
events = yield self._filter_events_for_server(
|
||||
origin, event.room_id, [event]
|
||||
events = yield filter_events_for_server(
|
||||
self.store, origin, [event],
|
||||
)
|
||||
event = events[0]
|
||||
defer.returnValue(event)
|
||||
|
@ -1681,7 +1515,7 @@ class FederationHandler(BaseHandler):
|
|||
yield self.store.persist_events(
|
||||
[
|
||||
(ev_info["event"], context)
|
||||
for ev_info, context in itertools.izip(event_infos, contexts)
|
||||
for ev_info, context in zip(event_infos, contexts)
|
||||
],
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
@ -1862,15 +1696,6 @@ class FederationHandler(BaseHandler):
|
|||
local_auth_chain, remote_auth_chain
|
||||
)
|
||||
|
||||
for event in ret["auth_chain"]:
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("on_query_auth returning: %s", ret)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
@ -1896,8 +1721,8 @@ class FederationHandler(BaseHandler):
|
|||
min_depth=min_depth,
|
||||
)
|
||||
|
||||
missing_events = yield self._filter_events_for_server(
|
||||
origin, room_id, missing_events,
|
||||
missing_events = yield filter_events_for_server(
|
||||
self.store, origin, missing_events,
|
||||
)
|
||||
|
||||
defer.returnValue(missing_events)
|
||||
|
|
|
@ -33,7 +33,7 @@ from synapse.events.utils import serialize_event
|
|||
from synapse.events.validator import EventValidator
|
||||
from synapse.replication.http.send_event import send_event_to_master
|
||||
from synapse.types import RoomAlias, RoomStreamToken, UserID
|
||||
from synapse.util.async import Limiter, ReadWriteLock
|
||||
from synapse.util.async import Linearizer, ReadWriteLock
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.logcontext import run_in_background
|
||||
from synapse.util.metrics import measure_func
|
||||
|
@ -427,7 +427,7 @@ class EventCreationHandler(object):
|
|||
|
||||
# We arbitrarily limit concurrent event creation for a room to 5.
|
||||
# This is to stop us from diverging history *too* much.
|
||||
self.limiter = Limiter(max_count=5)
|
||||
self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
|
||||
|
||||
self.action_generator = hs.get_action_generator()
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
|||
|
||||
|
||||
# This is used to indicate we should only return rooms published to the main list.
|
||||
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||
|
||||
|
||||
class RoomListHandler(BaseHandler):
|
||||
|
@ -50,7 +50,7 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
def get_local_public_room_list(self, limit=None, since_token=None,
|
||||
search_filter=None,
|
||||
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||
network_tuple=EMPTY_THIRD_PARTY_ID,):
|
||||
"""Generate a local public room list.
|
||||
|
||||
There are multiple different lists: the main one plus one per third
|
||||
|
@ -87,7 +87,7 @@ class RoomListHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
def _get_public_room_list(self, limit=None, since_token=None,
|
||||
search_filter=None,
|
||||
network_tuple=EMTPY_THIRD_PARTY_ID,):
|
||||
network_tuple=EMPTY_THIRD_PARTY_ID,):
|
||||
if since_token and since_token != "END":
|
||||
since_token = RoomListNextBatch.from_token(since_token)
|
||||
else:
|
||||
|
|
|
@ -26,9 +26,11 @@ from OpenSSL.SSL import VERIFY_NONE
|
|||
from twisted.internet import defer, protocol, reactor, ssl, task
|
||||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.web._newclient import ResponseDone
|
||||
from twisted.web.client import Agent, BrowserLikeRedirectAgent, ContentDecoderAgent
|
||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||
from twisted.web.client import (
|
||||
Agent,
|
||||
BrowserLikeRedirectAgent,
|
||||
ContentDecoderAgent,
|
||||
FileBodyProducer as TwistedFileBodyProducer,
|
||||
GzipDecoder,
|
||||
HTTPConnectionPool,
|
||||
PartialDownloadError,
|
||||
|
|
|
@ -38,7 +38,8 @@ outgoing_responses_counter = Counter(
|
|||
)
|
||||
|
||||
response_timer = Histogram(
|
||||
"synapse_http_server_response_time_seconds", "sec", ["method", "servlet", "tag"]
|
||||
"synapse_http_server_response_time_seconds", "sec",
|
||||
["method", "servlet", "tag", "code"],
|
||||
)
|
||||
|
||||
response_ru_utime = Counter(
|
||||
|
@ -171,11 +172,13 @@ class RequestMetrics(object):
|
|||
)
|
||||
return
|
||||
|
||||
outgoing_responses_counter.labels(request.method, str(request.code)).inc()
|
||||
response_code = str(request.code)
|
||||
|
||||
outgoing_responses_counter.labels(request.method, response_code).inc()
|
||||
|
||||
response_count.labels(request.method, self.name, tag).inc()
|
||||
|
||||
response_timer.labels(request.method, self.name, tag).observe(
|
||||
response_timer.labels(request.method, self.name, tag, response_code).observe(
|
||||
time_sec - self.start
|
||||
)
|
||||
|
||||
|
|
|
@ -206,7 +206,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
|
|||
return content
|
||||
|
||||
|
||||
def assert_params_in_request(body, required):
|
||||
def assert_params_in_dict(body, required):
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
|
|
|
@ -20,7 +20,7 @@ from twisted.web.server import Request, Site
|
|||
|
||||
from synapse.http import redact_uri
|
||||
from synapse.http.request_metrics import RequestMetrics
|
||||
from synapse.util.logcontext import LoggingContext, ContextResourceUsage
|
||||
from synapse.util.logcontext import ContextResourceUsage, LoggingContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -42,9 +42,10 @@ class SynapseRequest(Request):
|
|||
which is handling the request, and returns a context manager.
|
||||
|
||||
"""
|
||||
def __init__(self, site, *args, **kw):
|
||||
Request.__init__(self, *args, **kw)
|
||||
def __init__(self, site, channel, *args, **kw):
|
||||
Request.__init__(self, channel, *args, **kw)
|
||||
self.site = site
|
||||
self._channel = channel
|
||||
self.authenticated_entity = None
|
||||
self.start_time = 0
|
||||
|
||||
|
|
179
synapse/metrics/background_process_metrics.py
Normal file
179
synapse/metrics/background_process_metrics.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import six
|
||||
|
||||
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
|
||||
_background_process_start_count = Counter(
|
||||
"synapse_background_process_start_count",
|
||||
"Number of background processes started",
|
||||
["name"],
|
||||
)
|
||||
|
||||
# we set registry=None in all of these to stop them getting registered with
|
||||
# the default registry. Instead we collect them all via the CustomCollector,
|
||||
# which ensures that we can update them before they are collected.
|
||||
#
|
||||
_background_process_ru_utime = Counter(
|
||||
"synapse_background_process_ru_utime_seconds",
|
||||
"User CPU time used by background processes, in seconds",
|
||||
["name"],
|
||||
registry=None,
|
||||
)
|
||||
|
||||
_background_process_ru_stime = Counter(
|
||||
"synapse_background_process_ru_stime_seconds",
|
||||
"System CPU time used by background processes, in seconds",
|
||||
["name"],
|
||||
registry=None,
|
||||
)
|
||||
|
||||
_background_process_db_txn_count = Counter(
|
||||
"synapse_background_process_db_txn_count",
|
||||
"Number of database transactions done by background processes",
|
||||
["name"],
|
||||
registry=None,
|
||||
)
|
||||
|
||||
_background_process_db_txn_duration = Counter(
|
||||
"synapse_background_process_db_txn_duration_seconds",
|
||||
("Seconds spent by background processes waiting for database "
|
||||
"transactions, excluding scheduling time"),
|
||||
["name"],
|
||||
registry=None,
|
||||
)
|
||||
|
||||
_background_process_db_sched_duration = Counter(
|
||||
"synapse_background_process_db_sched_duration_seconds",
|
||||
"Seconds spent by background processes waiting for database connections",
|
||||
["name"],
|
||||
registry=None,
|
||||
)
|
||||
|
||||
# map from description to a counter, so that we can name our logcontexts
|
||||
# incrementally. (It actually duplicates _background_process_start_count, but
|
||||
# it's much simpler to do so than to try to combine them.)
|
||||
_background_process_counts = dict() # type: dict[str, int]
|
||||
|
||||
# map from description to the currently running background processes.
|
||||
#
|
||||
# it's kept as a dict of sets rather than a big set so that we can keep track
|
||||
# of process descriptions that no longer have any active processes.
|
||||
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
|
||||
|
||||
|
||||
class _Collector(object):
|
||||
"""A custom metrics collector for the background process metrics.
|
||||
|
||||
Ensures that all of the metrics are up-to-date with any in-flight processes
|
||||
before they are returned.
|
||||
"""
|
||||
def collect(self):
|
||||
background_process_in_flight_count = GaugeMetricFamily(
|
||||
"synapse_background_process_in_flight_count",
|
||||
"Number of background processes in flight",
|
||||
labels=["name"],
|
||||
)
|
||||
|
||||
for desc, processes in six.iteritems(_background_processes):
|
||||
background_process_in_flight_count.add_metric(
|
||||
(desc,), len(processes),
|
||||
)
|
||||
for process in processes:
|
||||
process.update_metrics()
|
||||
|
||||
yield background_process_in_flight_count
|
||||
|
||||
# now we need to run collect() over each of the static Counters, and
|
||||
# yield each metric they return.
|
||||
for m in (
|
||||
_background_process_ru_utime,
|
||||
_background_process_ru_stime,
|
||||
_background_process_db_txn_count,
|
||||
_background_process_db_txn_duration,
|
||||
_background_process_db_sched_duration,
|
||||
):
|
||||
for r in m.collect():
|
||||
yield r
|
||||
|
||||
|
||||
REGISTRY.register(_Collector())
|
||||
|
||||
|
||||
class _BackgroundProcess(object):
|
||||
def __init__(self, desc, ctx):
|
||||
self.desc = desc
|
||||
self._context = ctx
|
||||
self._reported_stats = None
|
||||
|
||||
def update_metrics(self):
|
||||
"""Updates the metrics with values from this process."""
|
||||
new_stats = self._context.get_resource_usage()
|
||||
if self._reported_stats is None:
|
||||
diff = new_stats
|
||||
else:
|
||||
diff = new_stats - self._reported_stats
|
||||
self._reported_stats = new_stats
|
||||
|
||||
_background_process_ru_utime.labels(self.desc).inc(diff.ru_utime)
|
||||
_background_process_ru_stime.labels(self.desc).inc(diff.ru_stime)
|
||||
_background_process_db_txn_count.labels(self.desc).inc(
|
||||
diff.db_txn_count,
|
||||
)
|
||||
_background_process_db_txn_duration.labels(self.desc).inc(
|
||||
diff.db_txn_duration_sec,
|
||||
)
|
||||
_background_process_db_sched_duration.labels(self.desc).inc(
|
||||
diff.db_sched_duration_sec,
|
||||
)
|
||||
|
||||
|
||||
def run_as_background_process(desc, func, *args, **kwargs):
|
||||
"""Run the given function in its own logcontext, with resource metrics
|
||||
|
||||
This should be used to wrap processes which are fired off to run in the
|
||||
background, instead of being associated with a particular request.
|
||||
|
||||
Args:
|
||||
desc (str): a description for this background process type
|
||||
func: a function, which may return a Deferred
|
||||
args: positional args for func
|
||||
kwargs: keyword args for func
|
||||
|
||||
Returns: None
|
||||
"""
|
||||
@defer.inlineCallbacks
|
||||
def run():
|
||||
count = _background_process_counts.get(desc, 0)
|
||||
_background_process_counts[desc] = count + 1
|
||||
_background_process_start_count.labels(desc).inc()
|
||||
|
||||
with LoggingContext(desc) as context:
|
||||
context.request = "%s-%i" % (desc, count)
|
||||
proc = _BackgroundProcess(desc, context)
|
||||
_background_processes.setdefault(desc, set()).add(proc)
|
||||
try:
|
||||
yield func(*args, **kwargs)
|
||||
finally:
|
||||
proc.update_metrics()
|
||||
_background_processes[desc].remove(proc)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
run()
|
|
@ -49,7 +49,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
|||
|
||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||
self.get_last_receipt_event_id_for_user.invalidate(
|
||||
(user_id, room_id, receipt_type)
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,13 +14,24 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from six import PY3
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client import versions
|
||||
from synapse.rest.client.v1 import admin, directory, events, initial_sync
|
||||
from synapse.rest.client.v1 import login as v1_login
|
||||
from synapse.rest.client.v1 import logout, presence, profile, push_rule, pusher
|
||||
from synapse.rest.client.v1 import register as v1_register
|
||||
from synapse.rest.client.v1 import room, voip
|
||||
from synapse.rest.client.v1 import (
|
||||
admin,
|
||||
directory,
|
||||
events,
|
||||
initial_sync,
|
||||
login as v1_login,
|
||||
logout,
|
||||
presence,
|
||||
profile,
|
||||
push_rule,
|
||||
pusher,
|
||||
room,
|
||||
voip,
|
||||
)
|
||||
from synapse.rest.client.v2_alpha import (
|
||||
account,
|
||||
account_data,
|
||||
|
@ -42,6 +54,11 @@ from synapse.rest.client.v2_alpha import (
|
|||
user_directory,
|
||||
)
|
||||
|
||||
if not PY3:
|
||||
from synapse.rest.client.v1_only import (
|
||||
register as v1_register,
|
||||
)
|
||||
|
||||
|
||||
class ClientRestResource(JsonResource):
|
||||
"""A resource for version 1 of the matrix client API."""
|
||||
|
@ -54,14 +71,22 @@ class ClientRestResource(JsonResource):
|
|||
def register_servlets(client_resource, hs):
|
||||
versions.register_servlets(client_resource)
|
||||
|
||||
# "v1"
|
||||
room.register_servlets(hs, client_resource)
|
||||
if not PY3:
|
||||
# "v1" (Python 2 only)
|
||||
v1_register.register_servlets(hs, client_resource)
|
||||
|
||||
# Deprecated in r0
|
||||
initial_sync.register_servlets(hs, client_resource)
|
||||
room.register_deprecated_servlets(hs, client_resource)
|
||||
|
||||
# Partially deprecated in r0
|
||||
events.register_servlets(hs, client_resource)
|
||||
v1_register.register_servlets(hs, client_resource)
|
||||
|
||||
# "v1" + "r0"
|
||||
room.register_servlets(hs, client_resource)
|
||||
v1_login.register_servlets(hs, client_resource)
|
||||
profile.register_servlets(hs, client_resource)
|
||||
presence.register_servlets(hs, client_resource)
|
||||
initial_sync.register_servlets(hs, client_resource)
|
||||
directory.register_servlets(hs, client_resource)
|
||||
voip.register_servlets(hs, client_resource)
|
||||
admin.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -17,38 +17,20 @@
|
|||
to ensure idempotency when performing PUTs using the REST API."""
|
||||
import logging
|
||||
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_transaction_key(request):
|
||||
"""A helper function which returns a transaction key that can be used
|
||||
with TransactionCache for idempotent requests.
|
||||
|
||||
Idempotency is based on the returned key being the same for separate
|
||||
requests to the same endpoint. The key is formed from the HTTP request
|
||||
path and the access_token for the requesting user.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The incoming request. Must
|
||||
contain an access_token.
|
||||
Returns:
|
||||
str: A transaction key
|
||||
"""
|
||||
token = get_access_token_from_request(request)
|
||||
return request.path + "/" + token
|
||||
|
||||
|
||||
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
|
||||
|
||||
|
||||
class HttpTransactionCache(object):
|
||||
|
||||
def __init__(self, clock):
|
||||
self.clock = clock
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.auth = self.hs.get_auth()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.transactions = {
|
||||
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
|
||||
}
|
||||
|
@ -56,6 +38,23 @@ class HttpTransactionCache(object):
|
|||
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
|
||||
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
|
||||
|
||||
def _get_transaction_key(self, request):
|
||||
"""A helper function which returns a transaction key that can be used
|
||||
with TransactionCache for idempotent requests.
|
||||
|
||||
Idempotency is based on the returned key being the same for separate
|
||||
requests to the same endpoint. The key is formed from the HTTP request
|
||||
path and the access_token for the requesting user.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The incoming request. Must
|
||||
contain an access_token.
|
||||
Returns:
|
||||
str: A transaction key
|
||||
"""
|
||||
token = self.auth.get_access_token_from_request(request)
|
||||
return request.path + "/" + token
|
||||
|
||||
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
||||
"""A helper function for fetch_or_execute which extracts
|
||||
a transaction key from the given request.
|
||||
|
@ -64,7 +63,7 @@ class HttpTransactionCache(object):
|
|||
fetch_or_execute
|
||||
"""
|
||||
return self.fetch_or_execute(
|
||||
get_transaction_key(request), fn, *args, **kwargs
|
||||
self._get_transaction_key(request), fn, *args, **kwargs
|
||||
)
|
||||
|
||||
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
|
||||
from six.moves import http_client
|
||||
|
@ -22,7 +24,12 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.http.servlet import (
|
||||
assert_params_in_dict,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.types import UserID, create_requester
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
@ -58,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class UserRegisterServlet(ClientV1RestServlet):
|
||||
"""
|
||||
Attributes:
|
||||
NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
|
||||
nonces (dict[str, int]): The nonces that we will accept. A dict of
|
||||
nonce to the time it was generated, in int seconds.
|
||||
"""
|
||||
PATTERNS = client_path_patterns("/admin/register")
|
||||
NONCE_TIMEOUT = 60
|
||||
|
||||
def __init__(self, hs):
|
||||
super(UserRegisterServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.reactor = hs.get_reactor()
|
||||
self.nonces = {}
|
||||
self.hs = hs
|
||||
|
||||
def _clear_old_nonces(self):
|
||||
"""
|
||||
Clear out old nonces that are older than NONCE_TIMEOUT.
|
||||
"""
|
||||
now = int(self.reactor.seconds())
|
||||
|
||||
for k, v in list(self.nonces.items()):
|
||||
if now - v > self.NONCE_TIMEOUT:
|
||||
del self.nonces[k]
|
||||
|
||||
def on_GET(self, request):
|
||||
"""
|
||||
Generate a new nonce.
|
||||
"""
|
||||
self._clear_old_nonces()
|
||||
|
||||
nonce = self.hs.get_secrets().token_hex(64)
|
||||
self.nonces[nonce] = int(self.reactor.seconds())
|
||||
return (200, {"nonce": nonce.encode('ascii')})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
self._clear_old_nonces()
|
||||
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if "nonce" not in body:
|
||||
raise SynapseError(
|
||||
400, "nonce must be specified", errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
nonce = body["nonce"]
|
||||
|
||||
if nonce not in self.nonces:
|
||||
raise SynapseError(
|
||||
400, "unrecognised nonce",
|
||||
)
|
||||
|
||||
# Delete the nonce, so it can't be reused, even if it's invalid
|
||||
del self.nonces[nonce]
|
||||
|
||||
if "username" not in body:
|
||||
raise SynapseError(
|
||||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||
)
|
||||
else:
|
||||
if (not isinstance(body['username'], str) or len(body['username']) > 512):
|
||||
raise SynapseError(400, "Invalid username")
|
||||
|
||||
username = body["username"].encode("utf-8")
|
||||
if b"\x00" in username:
|
||||
raise SynapseError(400, "Invalid username")
|
||||
|
||||
if "password" not in body:
|
||||
raise SynapseError(
|
||||
400, "password must be specified", errcode=Codes.BAD_JSON,
|
||||
)
|
||||
else:
|
||||
if (not isinstance(body['password'], str) or len(body['password']) > 512):
|
||||
raise SynapseError(400, "Invalid password")
|
||||
|
||||
password = body["password"].encode("utf-8")
|
||||
if b"\x00" in password:
|
||||
raise SynapseError(400, "Invalid password")
|
||||
|
||||
admin = body.get("admin", None)
|
||||
got_mac = body["mac"]
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret.encode(),
|
||||
digestmod=hashlib.sha1,
|
||||
)
|
||||
want_mac.update(nonce)
|
||||
want_mac.update(b"\x00")
|
||||
want_mac.update(username)
|
||||
want_mac.update(b"\x00")
|
||||
want_mac.update(password)
|
||||
want_mac.update(b"\x00")
|
||||
want_mac.update(b"admin" if admin else b"notadmin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
if not hmac.compare_digest(want_mac, got_mac):
|
||||
raise SynapseError(
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
|
||||
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
register = RegisterRestServlet(self.hs)
|
||||
|
||||
(user_id, _) = yield register.registration_handler.register(
|
||||
localpart=username.lower(), password=password, admin=bool(admin),
|
||||
generate_token=False,
|
||||
)
|
||||
|
||||
result = yield register._create_registration_details(user_id, body)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class WhoisRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
|
||||
|
||||
|
@ -98,16 +224,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
|
|||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
before_ts = request.args.get("before_ts", None)
|
||||
if not before_ts:
|
||||
raise SynapseError(400, "Missing 'before_ts' arg")
|
||||
|
||||
logger.info("before_ts: %r", before_ts[0])
|
||||
|
||||
try:
|
||||
before_ts = int(before_ts[0])
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid 'before_ts' arg")
|
||||
before_ts = parse_integer(request, "before_ts", required=True)
|
||||
logger.info("before_ts: %r", before_ts)
|
||||
|
||||
ret = yield self.media_repository.delete_old_remote_media(before_ts)
|
||||
|
||||
|
@ -300,10 +418,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
|||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
new_room_user_id = content.get("new_room_user_id")
|
||||
if not new_room_user_id:
|
||||
raise SynapseError(400, "Please provide field `new_room_user_id`")
|
||||
assert_params_in_dict(content, ["new_room_user_id"])
|
||||
new_room_user_id = content["new_room_user_id"]
|
||||
|
||||
room_creator_requester = create_requester(new_room_user_id)
|
||||
|
||||
|
@ -464,9 +580,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
|
|||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
params = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(params, ["new_password"])
|
||||
new_password = params['new_password']
|
||||
if not new_password:
|
||||
raise SynapseError(400, "Missing 'new_password' arg")
|
||||
|
||||
logger.info("new_password: %r", new_password)
|
||||
|
||||
|
@ -514,12 +629,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
|
|||
raise SynapseError(400, "Can only users a local user")
|
||||
|
||||
order = "name" # order by name in user table
|
||||
start = request.args.get("start")[0]
|
||||
limit = request.args.get("limit")[0]
|
||||
if not limit:
|
||||
raise SynapseError(400, "Missing 'limit' arg")
|
||||
if not start:
|
||||
raise SynapseError(400, "Missing 'start' arg")
|
||||
start = parse_integer(request, "start", required=True)
|
||||
limit = parse_integer(request, "limit", required=True)
|
||||
|
||||
logger.info("limit: %s, start: %s", limit, start)
|
||||
|
||||
ret = yield self.handlers.admin_handler.get_users_paginate(
|
||||
|
@ -551,12 +663,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
|
|||
|
||||
order = "name" # order by name in user table
|
||||
params = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(params, ["limit", "start"])
|
||||
limit = params['limit']
|
||||
start = params['start']
|
||||
if not limit:
|
||||
raise SynapseError(400, "Missing 'limit' arg")
|
||||
if not start:
|
||||
raise SynapseError(400, "Missing 'start' arg")
|
||||
logger.info("limit: %s, start: %s", limit, start)
|
||||
|
||||
ret = yield self.handlers.admin_handler.get_users_paginate(
|
||||
|
@ -604,10 +713,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
|
|||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "Can only users a local user")
|
||||
|
||||
term = request.args.get("term")[0]
|
||||
if not term:
|
||||
raise SynapseError(400, "Missing 'term' arg")
|
||||
|
||||
term = parse_string(request, "term", required=True)
|
||||
logger.info("term: %s ", term)
|
||||
|
||||
ret = yield self.handlers.admin_handler.search_users(
|
||||
|
@ -629,3 +735,4 @@ def register_servlets(hs, http_server):
|
|||
ShutdownRoomRestServlet(hs).register(http_server)
|
||||
QuarantineMediaInRoom(hs).register(http_server)
|
||||
ListMediaInRoom(hs).register(http_server)
|
||||
UserRegisterServlet(hs).register(http_server)
|
||||
|
|
|
@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
|
|||
self.hs = hs
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
self.auth = hs.get_auth()
|
||||
self.txns = HttpTransactionCache(hs.get_clock())
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
|
|
|
@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_alias):
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
if "room_id" not in content:
|
||||
raise SynapseError(400, "Missing room_id key",
|
||||
raise SynapseError(400, 'Missing params: ["room_id"]',
|
||||
errcode=Codes.BAD_JSON)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
|
||||
logger.debug("Got room name: %s", room_alias.to_string())
|
||||
|
||||
room_id = content["room_id"]
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import parse_boolean
|
||||
from synapse.streams.config import PaginationConfig
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
as_client_event = "raw" not in request.args
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
include_archived = request.args.get("archived", None) == ["true"]
|
||||
include_archived = parse_boolean(request, "archived", default=False)
|
||||
content = yield self.initial_sync_handler.snapshot_all_rooms(
|
||||
user_id=requester.user.to_string(),
|
||||
pagin_config=pagination_config,
|
||||
|
|
|
@ -17,7 +17,6 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
from synapse.api.errors import AuthError
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
|
|||
if requester.device_id is None:
|
||||
# the acccess token wasn't associated with a device.
|
||||
# Just delete the access token
|
||||
access_token = get_access_token_from_request(request)
|
||||
access_token = self._auth.get_access_token_from_request(request)
|
||||
yield self._auth_handler.delete_access_token(access_token)
|
||||
else:
|
||||
yield self._device_handler.delete_device(
|
||||
|
|
|
@ -21,7 +21,7 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.http.servlet import parse_json_value_from_request
|
||||
from synapse.http.servlet import parse_json_value_from_request, parse_string
|
||||
from synapse.push.baserules import BASE_RULE_IDS
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
|
||||
|
@ -75,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
except InvalidRuleException as e:
|
||||
raise SynapseError(400, e.message)
|
||||
|
||||
before = request.args.get("before", None)
|
||||
before = parse_string(request, "before")
|
||||
if before:
|
||||
before = _namespaced_rule_id(spec, before[0])
|
||||
before = _namespaced_rule_id(spec, before)
|
||||
|
||||
after = request.args.get("after", None)
|
||||
after = parse_string(request, "after")
|
||||
if after:
|
||||
after = _namespaced_rule_id(spec, after[0])
|
||||
after = _namespaced_rule_id(spec, after)
|
||||
|
||||
try:
|
||||
yield self.store.add_push_rule(
|
||||
|
|
|
@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
|
|||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
|
@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
reqd = ['kind', 'app_id', 'app_display_name',
|
||||
'device_display_name', 'pushkey', 'lang', 'data']
|
||||
missing = []
|
||||
for i in reqd:
|
||||
if i not in content:
|
||||
missing.append(i)
|
||||
if len(missing):
|
||||
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
|
||||
errcode=Codes.MISSING_PARAM)
|
||||
assert_params_in_dict(
|
||||
content,
|
||||
['kind', 'app_id', 'app_display_name',
|
||||
'device_display_name', 'pushkey', 'lang', 'data']
|
||||
)
|
||||
|
||||
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
|
||||
logger.debug("Got pushers request with body: %r", content)
|
||||
|
@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RestServlet, self).__init__()
|
||||
super(PushersRemoveRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.notifier = hs.get_notifier()
|
||||
self.auth = hs.get_auth()
|
||||
|
|
|
@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
|
|||
from synapse.api.filtering import Filter
|
||||
from synapse.events.utils import format_event_for_client_v2, serialize_event
|
||||
from synapse.http.servlet import (
|
||||
assert_params_in_dict,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
|
@ -435,9 +436,9 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
request, default_limit=10,
|
||||
)
|
||||
as_client_event = "raw" not in request.args
|
||||
filter_bytes = request.args.get("filter", None)
|
||||
filter_bytes = parse_string(request, "filter")
|
||||
if filter_bytes:
|
||||
filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
|
||||
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
|
||||
event_filter = Filter(json.loads(filter_json))
|
||||
else:
|
||||
event_filter = None
|
||||
|
@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
|
|||
def on_GET(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
limit = int(request.args.get("limit", [10])[0])
|
||||
limit = parse_integer(request, "limit", default=10)
|
||||
|
||||
results = yield self.handlers.room_context_handler.get_event_context(
|
||||
requester.user,
|
||||
|
@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
|
||||
target = requester.user
|
||||
if membership_action in ["invite", "ban", "unban", "kick"]:
|
||||
if "user_id" not in content:
|
||||
raise SynapseError(400, "Missing user_id key.")
|
||||
assert_params_in_dict(content, ["user_id"])
|
||||
target = UserID.from_string(content["user_id"])
|
||||
|
||||
event_content = None
|
||||
|
@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
batch = request.args.get("next_batch", [None])[0]
|
||||
batch = parse_string(request, "next_batch")
|
||||
results = yield self.handlers.search_handler.search(
|
||||
requester.user,
|
||||
content,
|
||||
|
@ -832,10 +832,13 @@ def register_servlets(hs, http_server):
|
|||
RoomSendEventRestServlet(hs).register(http_server)
|
||||
PublicRoomListRestServlet(hs).register(http_server)
|
||||
RoomStateRestServlet(hs).register(http_server)
|
||||
RoomInitialSyncRestServlet(hs).register(http_server)
|
||||
RoomRedactEventRestServlet(hs).register(http_server)
|
||||
RoomTypingRestServlet(hs).register(http_server)
|
||||
SearchRestServlet(hs).register(http_server)
|
||||
JoinedRoomsRestServlet(hs).register(http_server)
|
||||
RoomEventServlet(hs).register(http_server)
|
||||
RoomEventContextServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def register_deprecated_servlets(hs, http_server):
|
||||
RoomInitialSyncRestServlet(hs).register(http_server)
|
||||
|
|
3
synapse/rest/client/v1_only/__init__.py
Normal file
3
synapse/rest/client/v1_only/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""
|
||||
REST APIs that are only used in v1 (the legacy API).
|
||||
"""
|
39
synapse/rest/client/v1_only/base.py
Normal file
39
synapse/rest/client/v1_only/base.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module contains base REST classes for constructing client v1 servlets.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from synapse.api.urls import CLIENT_PREFIX
|
||||
|
||||
|
||||
def v1_only_client_path_patterns(path_regex, include_in_unstable=True):
|
||||
"""Creates a regex compiled client path with the correct client path
|
||||
prefix.
|
||||
|
||||
Args:
|
||||
path_regex (str): The regex string to match. This should NOT have a ^
|
||||
as this will be prefixed.
|
||||
Returns:
|
||||
list of SRE_Pattern
|
||||
"""
|
||||
patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
|
||||
if include_in_unstable:
|
||||
unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
|
||||
patterns.append(re.compile("^" + unstable_prefix + path_regex))
|
||||
return patterns
|
|
@ -18,18 +18,16 @@ import hmac
|
|||
import logging
|
||||
from hashlib import sha1
|
||||
|
||||
from six import string_types
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
|
||||
from synapse.rest.client.v1.base import ClientV1RestServlet
|
||||
from synapse.types import create_requester
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
from .base import v1_only_client_path_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -52,7 +50,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
handler doesn't have a concept of multi-stages or sessions.
|
||||
"""
|
||||
|
||||
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
|
||||
PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -67,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
# TODO: persistent storage
|
||||
self.sessions = {}
|
||||
self.enable_registration = hs.config.enable_registration
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
|
@ -124,8 +123,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
session = (register_json["session"]
|
||||
if "session" in register_json else None)
|
||||
login_type = None
|
||||
if "type" not in register_json:
|
||||
raise SynapseError(400, "Missing 'type' key.")
|
||||
assert_params_in_dict(register_json, ["type"])
|
||||
|
||||
try:
|
||||
login_type = register_json["type"]
|
||||
|
@ -310,11 +308,9 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _do_app_service(self, request, register_json, session):
|
||||
as_token = get_access_token_from_request(request)
|
||||
|
||||
if "user" not in register_json:
|
||||
raise SynapseError(400, "Expected 'user' key.")
|
||||
as_token = self.auth.get_access_token_from_request(request)
|
||||
|
||||
assert_params_in_dict(register_json, ["user"])
|
||||
user_localpart = register_json["user"].encode("utf-8")
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
|
@ -331,12 +327,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _do_shared_secret(self, request, register_json, session):
|
||||
if not isinstance(register_json.get("mac", None), string_types):
|
||||
raise SynapseError(400, "Expected mac.")
|
||||
if not isinstance(register_json.get("user", None), string_types):
|
||||
raise SynapseError(400, "Expected 'user' key.")
|
||||
if not isinstance(register_json.get("password", None), string_types):
|
||||
raise SynapseError(400, "Expected 'password' key.")
|
||||
assert_params_in_dict(register_json, ["mac", "user", "password"])
|
||||
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
@ -389,7 +380,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
"""Handles user creation via a server-to-server interface
|
||||
"""
|
||||
|
||||
PATTERNS = client_path_patterns("/createUser$", releases=())
|
||||
PATTERNS = v1_only_client_path_patterns("/createUser$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(CreateUserRestServlet, self).__init__(hs)
|
||||
|
@ -400,7 +391,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
def on_POST(self, request):
|
||||
user_json = parse_json_object_from_request(request)
|
||||
|
||||
access_token = get_access_token_from_request(request)
|
||||
access_token = self.auth.get_access_token_from_request(request)
|
||||
app_service = self.store.get_app_service_by_token(
|
||||
access_token
|
||||
)
|
||||
|
@ -419,11 +410,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _do_create(self, requester, user_json):
|
||||
if "localpart" not in user_json:
|
||||
raise SynapseError(400, "Expected 'localpart' key.")
|
||||
|
||||
if "displayname" not in user_json:
|
||||
raise SynapseError(400, "Expected 'displayname' key.")
|
||||
assert_params_in_dict(user_json, ["localpart", "displayname"])
|
||||
|
||||
localpart = user_json["localpart"].encode("utf-8")
|
||||
displayname = user_json["displayname"].encode("utf-8")
|
|
@ -20,12 +20,11 @@ from six.moves import http_client
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.auth import has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_request,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
@ -48,7 +47,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_request(body, [
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
|
@ -81,7 +80,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_request(body, [
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number', 'send_attempt',
|
||||
])
|
||||
|
@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
|
|||
#
|
||||
# In the second case, we require a password to confirm their identity.
|
||||
|
||||
if has_access_token(request):
|
||||
if self.auth.has_access_token(request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
params = yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
|
@ -160,11 +159,10 @@ class PasswordRestServlet(RestServlet):
|
|||
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
||||
user_id = threepid_user_id
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
logger.error("Auth succeeded but no known type! %r", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
if 'new_password' not in params:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
assert_params_in_dict(params, ["new_password"])
|
||||
new_password = params['new_password']
|
||||
|
||||
yield self._set_password_handler.set_password(
|
||||
|
@ -229,15 +227,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
assert_params_in_dict(
|
||||
body,
|
||||
['id_server', 'client_secret', 'email', 'send_attempt'],
|
||||
)
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
|
@ -267,18 +260,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = [
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number', 'send_attempt',
|
||||
]
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
])
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
|
@ -373,15 +358,7 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['medium', 'address']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
assert_params_in_dict(body, ['medium', 'address'])
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
|
|
@ -18,14 +18,18 @@ import logging
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.http import servlet
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DevicesRestServlet(servlet.RestServlet):
|
||||
class DevicesRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
|
|||
defer.returnValue((200, {"devices": devices}))
|
||||
|
||||
|
||||
class DeleteDevicesRestServlet(servlet.RestServlet):
|
||||
class DeleteDevicesRestServlet(RestServlet):
|
||||
"""
|
||||
API for bulk deletion of devices. Accepts a JSON object with a devices
|
||||
key which lists the device_ids to delete. Requires user interactive auth.
|
||||
|
@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
try:
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
except errors.SynapseError as e:
|
||||
if e.errcode == errors.Codes.NOT_JSON:
|
||||
# deal with older clients which didn't pass a J*DELETESON dict
|
||||
# DELETE
|
||||
# deal with older clients which didn't pass a JSON dict
|
||||
# the same as those that pass an empty dict
|
||||
body = {}
|
||||
else:
|
||||
raise e
|
||||
|
||||
if 'devices' not in body:
|
||||
raise errors.SynapseError(
|
||||
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
||||
)
|
||||
assert_params_in_dict(body, ["devices"])
|
||||
|
||||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
|
@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
class DeviceRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
try:
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
except errors.SynapseError as e:
|
||||
if e.errcode == errors.Codes.NOT_JSON:
|
||||
|
@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
def on_PUT(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
|
|
|
@ -24,12 +24,11 @@ from twisted.internet import defer
|
|||
|
||||
import synapse
|
||||
import synapse.types
|
||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_request,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
)
|
||||
|
@ -69,7 +68,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_request(body, [
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
|
@ -105,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_request(body, [
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number',
|
||||
'send_attempt',
|
||||
|
@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
|
|||
desired_username = body['username']
|
||||
|
||||
appservice = None
|
||||
if has_access_token(request):
|
||||
if self.auth.has_access_token(request):
|
||||
appservice = yield self.auth.get_appservice_by_req(request)
|
||||
|
||||
# fork off as soon as possible for ASes and shared secret auth which
|
||||
|
@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
|
|||
# because the IRC bridges rely on being able to register stupid
|
||||
# IDs.
|
||||
|
||||
access_token = get_access_token_from_request(request)
|
||||
access_token = self.auth.get_access_token_from_request(request)
|
||||
|
||||
if isinstance(desired_username, string_types):
|
||||
result = yield self._do_appservice_registration(
|
||||
|
@ -387,9 +386,7 @@ class RegisterRestServlet(RestServlet):
|
|||
add_msisdn = False
|
||||
else:
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "Missing password.",
|
||||
Codes.MISSING_PARAM)
|
||||
assert_params_in_dict(params, ["password"])
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
new_password = params.get("password", None)
|
||||
|
@ -566,11 +563,14 @@ class RegisterRestServlet(RestServlet):
|
|||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
reqd = ('medium', 'address', 'validated_at')
|
||||
if any(x not in threepid for x in reqd):
|
||||
# This will only happen if the ID server returns a malformed response
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
defer.returnValue()
|
||||
try:
|
||||
assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
|
||||
except SynapseError as ex:
|
||||
if ex.errcode == Codes.MISSING_PARAM:
|
||||
# This will only happen if the ID server returns a malformed response
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
defer.returnValue(None)
|
||||
raise
|
||||
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
|
@ -643,7 +643,7 @@ class RegisterRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self, params):
|
||||
if not self.hs.config.allow_guest_access:
|
||||
defer.returnValue((403, "Guest access is disabled"))
|
||||
raise SynapseError(403, "Guest access is disabled")
|
||||
user_id, _ = yield self.registration_handler.register(
|
||||
generate_token=False,
|
||||
make_guest=True
|
||||
|
|
|
@ -15,9 +15,17 @@
|
|||
|
||||
import logging
|
||||
|
||||
from six import string_types
|
||||
from six.moves import http_client
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
|
@ -42,12 +50,26 @@ class ReportEventRestServlet(RestServlet):
|
|||
user_id = requester.user.to_string()
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ("reason", "score"))
|
||||
|
||||
if not isinstance(body["reason"], string_types):
|
||||
raise SynapseError(
|
||||
http_client.BAD_REQUEST,
|
||||
"Param 'reason' must be a string",
|
||||
Codes.BAD_JSON,
|
||||
)
|
||||
if not isinstance(body["score"], int):
|
||||
raise SynapseError(
|
||||
http_client.BAD_REQUEST,
|
||||
"Param 'score' must be an integer",
|
||||
Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
yield self.store.add_event_report(
|
||||
room_id=room_id,
|
||||
event_id=event_id,
|
||||
user_id=user_id,
|
||||
reason=body.get("reason"),
|
||||
reason=body["reason"],
|
||||
content=body,
|
||||
received_ts=self.clock.time_msec(),
|
||||
)
|
||||
|
|
|
@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
|||
super(SendToDeviceRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.txns = HttpTransactionCache(hs.get_clock())
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
self.device_message_handler = hs.get_device_message_handler()
|
||||
|
||||
def on_PUT(self, request, message_type, txn_id):
|
||||
|
|
|
@ -16,6 +16,8 @@ from pydenticon import Generator
|
|||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.http.servlet import parse_integer
|
||||
|
||||
FOREGROUND = [
|
||||
"rgb(45,79,255)",
|
||||
"rgb(254,180,44)",
|
||||
|
@ -56,8 +58,8 @@ class IdenticonResource(Resource):
|
|||
|
||||
def render_GET(self, request):
|
||||
name = "/".join(request.postpath)
|
||||
width = int(request.args.get("width", [96])[0])
|
||||
height = int(request.args.get("height", [96])[0])
|
||||
width = parse_integer(request, "width", default=96)
|
||||
height = parse_integer(request, "height", default=96)
|
||||
identicon_bytes = self.generate_identicon(name, width, height)
|
||||
request.setHeader(b"Content-Type", b"image/png")
|
||||
request.setHeader(
|
||||
|
|
|
@ -40,6 +40,7 @@ from synapse.http.server import (
|
|||
respond_with_json_bytes,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
|
@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
url = request.args.get("url")[0]
|
||||
url = parse_string(request, "url")
|
||||
if "ts" in request.args:
|
||||
ts = int(request.args.get("ts")[0])
|
||||
ts = parse_integer(request, "ts")
|
||||
else:
|
||||
ts = self.clock.time_msec()
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
|
|||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import respond_with_json, wrap_json_request_handler
|
||||
from synapse.http.servlet import parse_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -65,10 +66,10 @@ class UploadResource(Resource):
|
|||
code=413,
|
||||
)
|
||||
|
||||
upload_name = request.args.get("filename", None)
|
||||
upload_name = parse_string(request, "filename")
|
||||
if upload_name:
|
||||
try:
|
||||
upload_name = upload_name[0].decode('UTF-8')
|
||||
upload_name = upload_name.decode('UTF-8')
|
||||
except UnicodeDecodeError:
|
||||
raise SynapseError(
|
||||
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
|
||||
|
|
42
synapse/secrets.py
Normal file
42
synapse/secrets.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Injectable secrets module for Synapse.
|
||||
|
||||
See https://docs.python.org/3/library/secrets.html#module-secrets for the API
|
||||
used in Python 3.6, and the API emulated in Python 2.7.
|
||||
"""
|
||||
|
||||
import six
|
||||
|
||||
if six.PY3:
|
||||
import secrets
|
||||
|
||||
def Secrets():
|
||||
return secrets
|
||||
|
||||
|
||||
else:
|
||||
|
||||
import os
|
||||
import binascii
|
||||
|
||||
class Secrets(object):
|
||||
def token_bytes(self, nbytes=32):
|
||||
return os.urandom(nbytes)
|
||||
|
||||
def token_hex(self, nbytes=32):
|
||||
return binascii.hexlify(self.token_bytes(nbytes))
|
|
@ -74,6 +74,7 @@ from synapse.rest.media.v1.media_repository import (
|
|||
MediaRepository,
|
||||
MediaRepositoryResource,
|
||||
)
|
||||
from synapse.secrets import Secrets
|
||||
from synapse.server_notices.server_notices_manager import ServerNoticesManager
|
||||
from synapse.server_notices.server_notices_sender import ServerNoticesSender
|
||||
from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
|
||||
|
@ -158,6 +159,7 @@ class HomeServer(object):
|
|||
'groups_server_handler',
|
||||
'groups_attestation_signing',
|
||||
'groups_attestation_renewer',
|
||||
'secrets',
|
||||
'spam_checker',
|
||||
'room_member_handler',
|
||||
'federation_registry',
|
||||
|
@ -405,6 +407,9 @@ class HomeServer(object):
|
|||
def build_groups_attestation_renewer(self):
|
||||
return GroupAttestionRenewer(self)
|
||||
|
||||
def build_secrets(self):
|
||||
return Secrets()
|
||||
|
||||
def build_spam_checker(self):
|
||||
return SpamChecker(self)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ import hashlib
|
|||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from six import iteritems, itervalues
|
||||
from six import iteritems, iterkeys, itervalues
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
|
@ -647,7 +647,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
|||
for event_id in event_ids
|
||||
)
|
||||
if event_map is not None:
|
||||
needed_events -= set(event_map.iterkeys())
|
||||
needed_events -= set(iterkeys(event_map))
|
||||
|
||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||
|
||||
|
@ -668,7 +668,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
|||
new_needed_events = set(itervalues(auth_events))
|
||||
new_needed_events -= needed_events
|
||||
if event_map is not None:
|
||||
new_needed_events -= set(event_map.iterkeys())
|
||||
new_needed_events -= set(iterkeys(event_map))
|
||||
|
||||
logger.info("Asking for %d auth events", len(new_needed_events))
|
||||
|
||||
|
|
|
@ -344,7 +344,7 @@ class SQLBaseStore(object):
|
|||
parent_context = LoggingContext.current_context()
|
||||
if parent_context == LoggingContext.sentinel:
|
||||
logger.warn(
|
||||
"Running db txn from sentinel context: metrics will be lost",
|
||||
"Starting db connection from sentinel context: metrics will be lost",
|
||||
)
|
||||
parent_context = None
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@ from canonicaljson import json
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
||||
from . import engines
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
@ -87,10 +89,14 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
self._background_update_handlers = {}
|
||||
self._all_done = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_doing_background_updates(self):
|
||||
logger.info("Starting background schema updates")
|
||||
run_as_background_process(
|
||||
"background_updates", self._run_background_updates,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _run_background_updates(self):
|
||||
logger.info("Starting background schema updates")
|
||||
while True:
|
||||
yield self.hs.get_clock().sleep(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
||||
|
|
|
@ -19,6 +19,7 @@ from six import iteritems
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||
|
||||
from . import background_updates
|
||||
|
@ -93,10 +94,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
|||
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||
|
||||
def _update_client_ips_batch(self):
|
||||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
return self.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||
def update():
|
||||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
return self.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn,
|
||||
to_update,
|
||||
)
|
||||
|
||||
run_as_background_process(
|
||||
"update_client_ips", update,
|
||||
)
|
||||
|
||||
def _update_client_ips_batch_txn(self, txn, to_update):
|
||||
|
|
|
@ -33,12 +33,13 @@ from synapse.api.errors import SynapseError
|
|||
# these are only included to make the type annotations work
|
||||
from synapse.events import EventBase # noqa: F401
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
from synapse.types import RoomStreamToken, get_domain_from_id
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
|
@ -155,11 +156,8 @@ class _EventPeristenceQueue(object):
|
|||
self._event_persist_queues[room_id] = queue
|
||||
self._currently_persisting_rooms.discard(room_id)
|
||||
|
||||
# set handle_queue_loop off on the background. We don't want to
|
||||
# attribute work done in it to the current request, so we drop the
|
||||
# logcontext altogether.
|
||||
with PreserveLoggingContext():
|
||||
handle_queue_loop()
|
||||
# set handle_queue_loop off in the background
|
||||
run_as_background_process("persist_events", handle_queue_loop)
|
||||
|
||||
def _get_drainining_queue(self, room_id):
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
|
|
|
@ -25,6 +25,7 @@ from synapse.events import EventBase # noqa: F401
|
|||
from synapse.events import FrozenEvent
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.logcontext import (
|
||||
LoggingContext,
|
||||
PreserveLoggingContext,
|
||||
|
@ -322,10 +323,11 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
should_start = False
|
||||
|
||||
if should_start:
|
||||
with PreserveLoggingContext():
|
||||
self.runWithConnection(
|
||||
self._do_fetch
|
||||
)
|
||||
run_as_background_process(
|
||||
"fetch_events",
|
||||
self.runWithConnection,
|
||||
self._do_fetch,
|
||||
)
|
||||
|
||||
logger.debug("Loading %d events", len(events))
|
||||
with PreserveLoggingContext():
|
||||
|
|
|
@ -140,7 +140,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
"""
|
||||
room_ids = set(room_ids)
|
||||
|
||||
if from_key:
|
||||
if from_key is not None:
|
||||
# Only ask the database about rooms where there have been new
|
||||
# receipts added since `from_key`
|
||||
room_ids = yield self._receipts_stream_cache.get_entities_changed(
|
||||
room_ids, from_key
|
||||
)
|
||||
|
@ -151,7 +153,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue([ev for res in results.values() for ev in res])
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, tree=True)
|
||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
"""Get receipts for a single room for sending to clients.
|
||||
|
||||
|
@ -162,7 +163,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
from the start.
|
||||
|
||||
Returns:
|
||||
list: A list of receipts.
|
||||
Deferred[list]: A list of receipts.
|
||||
"""
|
||||
if from_key is not None:
|
||||
# Check the cache first to see if any new receipts have been added
|
||||
# since`from_key`. If not we can no-op.
|
||||
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
|
||||
defer.succeed([])
|
||||
|
||||
return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, tree=True)
|
||||
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||
"""See get_linearized_receipts_for_room
|
||||
"""
|
||||
def f(txn):
|
||||
if from_key:
|
||||
|
@ -211,7 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
"content": content,
|
||||
}])
|
||||
|
||||
@cachedList(cached_method_name="get_linearized_receipts_for_room",
|
||||
@cachedList(cached_method_name="_get_linearized_receipts_for_room",
|
||||
list_name="room_ids", num_args=3, inlineCallbacks=True)
|
||||
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
if not room_ids:
|
||||
|
@ -373,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
|||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||
)
|
||||
# FIXME: This shouldn't invalidate the whole cache
|
||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
|
||||
txn.call_after(
|
||||
self._receipts_stream_cache.entity_has_changed,
|
||||
|
@ -493,7 +506,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
|||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||
)
|
||||
# FIXME: This shouldn't invalidate the whole cache
|
||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.types import StreamToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -56,23 +57,10 @@ class PaginationConfig(object):
|
|||
@classmethod
|
||||
def from_request(cls, request, raise_invalid_params=True,
|
||||
default_limit=None):
|
||||
def get_param(name, default=None):
|
||||
lst = request.args.get(name, [])
|
||||
if len(lst) > 1:
|
||||
raise SynapseError(
|
||||
400, "%s must be specified only once" % (name,)
|
||||
)
|
||||
elif len(lst) == 1:
|
||||
return lst[0]
|
||||
else:
|
||||
return default
|
||||
direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
|
||||
|
||||
direction = get_param("dir", 'f')
|
||||
if direction not in ['f', 'b']:
|
||||
raise SynapseError(400, "'dir' parameter is invalid.")
|
||||
|
||||
from_tok = get_param("from")
|
||||
to_tok = get_param("to")
|
||||
from_tok = parse_string(request, "from")
|
||||
to_tok = parse_string(request, "to")
|
||||
|
||||
try:
|
||||
if from_tok == "END":
|
||||
|
@ -88,12 +76,10 @@ class PaginationConfig(object):
|
|||
except Exception:
|
||||
raise SynapseError(400, "'to' paramater is invalid")
|
||||
|
||||
limit = get_param("limit", None)
|
||||
if limit is not None and not limit.isdigit():
|
||||
raise SynapseError(400, "'limit' parameter must be an integer.")
|
||||
limit = parse_integer(request, "limit", default=default_limit)
|
||||
|
||||
if limit is None:
|
||||
limit = default_limit
|
||||
if limit and limit < 0:
|
||||
raise SynapseError(400, "Limit must be 0 or above")
|
||||
|
||||
try:
|
||||
return PaginationConfig(from_tok, to_tok, direction, limit)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,7 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
@ -156,54 +157,72 @@ def concurrently_execute(func, args, limit):
|
|||
|
||||
|
||||
class Linearizer(object):
|
||||
"""Linearizes access to resources based on a key. Useful to ensure only one
|
||||
thing is happening at a time on a given resource.
|
||||
"""Limits concurrent access to resources based on a key. Useful to ensure
|
||||
only a few things happen at a time on a given resource.
|
||||
|
||||
Example:
|
||||
|
||||
with (yield linearizer.queue("test_key")):
|
||||
with (yield limiter.queue("test_key")):
|
||||
# do some work.
|
||||
|
||||
"""
|
||||
def __init__(self, name=None, clock=None):
|
||||
def __init__(self, name=None, max_count=1, clock=None):
|
||||
"""
|
||||
Args:
|
||||
max_count(int): The maximum number of concurrent accesses
|
||||
"""
|
||||
if name is None:
|
||||
self.name = id(self)
|
||||
else:
|
||||
self.name = name
|
||||
self.key_to_defer = {}
|
||||
|
||||
if not clock:
|
||||
from twisted.internet import reactor
|
||||
clock = Clock(reactor)
|
||||
self._clock = clock
|
||||
self.max_count = max_count
|
||||
|
||||
# key_to_defer is a map from the key to a 2 element list where
|
||||
# the first element is the number of things executing, and
|
||||
# the second element is an OrderedDict, where the keys are deferreds for the
|
||||
# things blocked from executing.
|
||||
self.key_to_defer = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def queue(self, key):
|
||||
# If there is already a deferred in the queue, we pull it out so that
|
||||
# we can wait on it later.
|
||||
# Then we replace it with a deferred that we resolve *after* the
|
||||
# context manager has exited.
|
||||
# We only return the context manager after the previous deferred has
|
||||
# resolved.
|
||||
# This all has the net effect of creating a chain of deferreds that
|
||||
# wait for the previous deferred before starting their work.
|
||||
current_defer = self.key_to_defer.get(key)
|
||||
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
|
||||
|
||||
new_defer = defer.Deferred()
|
||||
self.key_to_defer[key] = new_defer
|
||||
# If the number of things executing is greater than the maximum
|
||||
# then add a deferred to the list of blocked items
|
||||
# When on of the things currently executing finishes it will callback
|
||||
# this item so that it can continue executing.
|
||||
if entry[0] >= self.max_count:
|
||||
new_defer = defer.Deferred()
|
||||
entry[1][new_defer] = 1
|
||||
|
||||
if current_defer:
|
||||
logger.info(
|
||||
"Waiting to acquire linearizer lock %r for key %r", self.name, key
|
||||
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
||||
)
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
yield current_defer
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception in Linearizer")
|
||||
yield make_deferred_yieldable(new_defer)
|
||||
except Exception as e:
|
||||
if isinstance(e, CancelledError):
|
||||
logger.info(
|
||||
"Cancelling wait for linearizer lock %r for key %r",
|
||||
self.name, key,
|
||||
)
|
||||
else:
|
||||
logger.warn(
|
||||
"Unexpected exception waiting for linearizer lock %r for key %r",
|
||||
self.name, key,
|
||||
)
|
||||
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name,
|
||||
key)
|
||||
# we just have to take ourselves back out of the queue.
|
||||
del entry[1][new_defer]
|
||||
raise
|
||||
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
entry[0] += 1
|
||||
|
||||
# if the code holding the lock completes synchronously, then it
|
||||
# will recursively run the next claimant on the list. That can
|
||||
|
@ -213,15 +232,15 @@ class Linearizer(object):
|
|||
# In order to break the cycle, we add a cheeky sleep(0) here to
|
||||
# ensure that we fall back to the reactor between each iteration.
|
||||
#
|
||||
# (There's no particular need for it to happen before we return
|
||||
# the context manager, but it needs to happen while we hold the
|
||||
# lock, and the context manager's exit code must be synchronous,
|
||||
# so actually this is the only sensible place.
|
||||
# (This needs to happen while we hold the lock, and the context manager's exit
|
||||
# code must be synchronous, so this is the only sensible place.)
|
||||
yield self._clock.sleep(0)
|
||||
|
||||
else:
|
||||
logger.info("Acquired uncontended linearizer lock %r for key %r",
|
||||
self.name, key)
|
||||
logger.info(
|
||||
"Acquired uncontended linearizer lock %r for key %r", self.name, key,
|
||||
)
|
||||
entry[0] += 1
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
@ -229,73 +248,15 @@ class Linearizer(object):
|
|||
yield
|
||||
finally:
|
||||
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
|
||||
with PreserveLoggingContext():
|
||||
new_defer.callback(None)
|
||||
current_d = self.key_to_defer.get(key)
|
||||
if current_d is new_defer:
|
||||
self.key_to_defer.pop(key, None)
|
||||
|
||||
defer.returnValue(_ctx_manager())
|
||||
|
||||
|
||||
class Limiter(object):
|
||||
"""Limits concurrent access to resources based on a key. Useful to ensure
|
||||
only a few thing happen at a time on a given resource.
|
||||
|
||||
Example:
|
||||
|
||||
with (yield limiter.queue("test_key")):
|
||||
# do some work.
|
||||
|
||||
"""
|
||||
def __init__(self, max_count):
|
||||
"""
|
||||
Args:
|
||||
max_count(int): The maximum number of concurrent access
|
||||
"""
|
||||
self.max_count = max_count
|
||||
|
||||
# key_to_defer is a map from the key to a 2 element list where
|
||||
# the first element is the number of things executing
|
||||
# the second element is a list of deferreds for the things blocked from
|
||||
# executing.
|
||||
self.key_to_defer = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def queue(self, key):
|
||||
entry = self.key_to_defer.setdefault(key, [0, []])
|
||||
|
||||
# If the number of things executing is greater than the maximum
|
||||
# then add a deferred to the list of blocked items
|
||||
# When on of the things currently executing finishes it will callback
|
||||
# this item so that it can continue executing.
|
||||
if entry[0] >= self.max_count:
|
||||
new_defer = defer.Deferred()
|
||||
entry[1].append(new_defer)
|
||||
|
||||
logger.info("Waiting to acquire limiter lock for key %r", key)
|
||||
with PreserveLoggingContext():
|
||||
yield new_defer
|
||||
logger.info("Acquired limiter lock for key %r", key)
|
||||
else:
|
||||
logger.info("Acquired uncontended limiter lock for key %r", key)
|
||||
|
||||
entry[0] += 1
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Releasing limiter lock for key %r", key)
|
||||
|
||||
# We've finished executing so check if there are any things
|
||||
# blocked waiting to execute and start one of them
|
||||
entry[0] -= 1
|
||||
|
||||
if entry[1]:
|
||||
next_def = entry[1].pop(0)
|
||||
(next_def, _) = entry[1].popitem(last=False)
|
||||
|
||||
# we need to run the next thing in the sentinel context.
|
||||
with PreserveLoggingContext():
|
||||
next_def.callback(None)
|
||||
elif entry[0] == 0:
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.caches import register_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -63,7 +64,10 @@ class ExpiringCache(object):
|
|||
return
|
||||
|
||||
def f():
|
||||
self._prune_cache()
|
||||
run_as_background_process(
|
||||
"prune_cache_%s" % self._cache_name,
|
||||
self._prune_cache,
|
||||
)
|
||||
|
||||
self._clock.looping_call(f, self._expiry_ms / 2)
|
||||
|
||||
|
|
|
@ -74,14 +74,13 @@ class StreamChangeCache(object):
|
|||
assert type(stream_pos) is int
|
||||
|
||||
if stream_pos >= self._earliest_known_stream_pos:
|
||||
not_known_entities = set(entities) - set(self._entity_to_key)
|
||||
changed_entities = {
|
||||
self._cache[k] for k in self._cache.islice(
|
||||
start=self._cache.bisect_right(stream_pos),
|
||||
)
|
||||
}
|
||||
|
||||
result = (
|
||||
{self._cache[k] for k in self._cache.islice(
|
||||
start=self._cache.bisect_right(stream_pos))}
|
||||
.intersection(entities)
|
||||
.union(not_known_entities)
|
||||
)
|
||||
result = changed_entities.intersection(entities)
|
||||
|
||||
self.metrics.inc_hits()
|
||||
else:
|
||||
|
|
|
@ -17,20 +17,18 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def user_left_room(distributor, user, room_id):
|
||||
with PreserveLoggingContext():
|
||||
distributor.fire("user_left_room", user=user, room_id=room_id)
|
||||
distributor.fire("user_left_room", user=user, room_id=room_id)
|
||||
|
||||
|
||||
def user_joined_room(distributor, user, room_id):
|
||||
with PreserveLoggingContext():
|
||||
distributor.fire("user_joined_room", user=user, room_id=room_id)
|
||||
distributor.fire("user_joined_room", user=user, room_id=room_id)
|
||||
|
||||
|
||||
class Distributor(object):
|
||||
|
@ -44,9 +42,7 @@ class Distributor(object):
|
|||
model will do for today.
|
||||
"""
|
||||
|
||||
def __init__(self, suppress_failures=True):
|
||||
self.suppress_failures = suppress_failures
|
||||
|
||||
def __init__(self):
|
||||
self.signals = {}
|
||||
self.pre_registration = {}
|
||||
|
||||
|
@ -56,7 +52,6 @@ class Distributor(object):
|
|||
|
||||
self.signals[name] = Signal(
|
||||
name,
|
||||
suppress_failures=self.suppress_failures,
|
||||
)
|
||||
|
||||
if name in self.pre_registration:
|
||||
|
@ -75,10 +70,18 @@ class Distributor(object):
|
|||
self.pre_registration[name].append(observer)
|
||||
|
||||
def fire(self, name, *args, **kwargs):
|
||||
"""Dispatches the given signal to the registered observers.
|
||||
|
||||
Runs the observers as a background process. Does not return a deferred.
|
||||
"""
|
||||
if name not in self.signals:
|
||||
raise KeyError("%r does not have a signal named %s" % (self, name))
|
||||
|
||||
return self.signals[name].fire(*args, **kwargs)
|
||||
run_as_background_process(
|
||||
name,
|
||||
self.signals[name].fire,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class Signal(object):
|
||||
|
@ -91,9 +94,8 @@ class Signal(object):
|
|||
method into all of the observers.
|
||||
"""
|
||||
|
||||
def __init__(self, name, suppress_failures):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.suppress_failures = suppress_failures
|
||||
self.observers = []
|
||||
|
||||
def observe(self, observer):
|
||||
|
@ -103,7 +105,6 @@ class Signal(object):
|
|||
Each observer callable may return a Deferred."""
|
||||
self.observers.append(observer)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fire(self, *args, **kwargs):
|
||||
"""Invokes every callable in the observer list, passing in the args and
|
||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||
|
@ -121,22 +122,17 @@ class Signal(object):
|
|||
failure.type,
|
||||
failure.value,
|
||||
failure.getTracebackObject()))
|
||||
if not self.suppress_failures:
|
||||
return failure
|
||||
|
||||
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
deferreds = [
|
||||
do(observer)
|
||||
for observer in self.observers
|
||||
]
|
||||
deferreds = [
|
||||
run_in_background(do, o)
|
||||
for o in self.observers
|
||||
]
|
||||
|
||||
res = yield defer.gatherResults(
|
||||
deferreds, consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
defer.returnValue(res)
|
||||
return make_deferred_yieldable(defer.gatherResults(
|
||||
deferreds, consumeErrors=True,
|
||||
))
|
||||
|
||||
def __repr__(self):
|
||||
return "<Signal name=%r>" % (self.name,)
|
||||
|
|
|
@ -99,6 +99,17 @@ class ContextResourceUsage(object):
|
|||
self.db_sched_duration_sec = 0
|
||||
self.evt_db_fetch_count = 0
|
||||
|
||||
def __repr__(self):
|
||||
return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
|
||||
"db_txn_count='%r', db_txn_duration_sec='%r', "
|
||||
"db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
|
||||
self.ru_stime,
|
||||
self.ru_utime,
|
||||
self.db_txn_count,
|
||||
self.db_txn_duration_sec,
|
||||
self.db_sched_duration_sec,
|
||||
self.evt_db_fetch_count,)
|
||||
|
||||
def __iadd__(self, other):
|
||||
"""Add another ContextResourceUsage's stats to this one's.
|
||||
|
||||
|
|
|
@ -104,12 +104,19 @@ class Measure(object):
|
|||
logger.warn("Expected context. (%r)", self.name)
|
||||
return
|
||||
|
||||
usage = context.get_resource_usage() - self.start_usage
|
||||
block_ru_utime.labels(self.name).inc(usage.ru_utime)
|
||||
block_ru_stime.labels(self.name).inc(usage.ru_stime)
|
||||
block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
|
||||
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
|
||||
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
|
||||
current = context.get_resource_usage()
|
||||
usage = current - self.start_usage
|
||||
try:
|
||||
block_ru_utime.labels(self.name).inc(usage.ru_utime)
|
||||
block_ru_stime.labels(self.name).inc(usage.ru_stime)
|
||||
block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
|
||||
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
|
||||
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
|
||||
except ValueError:
|
||||
logger.warn(
|
||||
"Failed to save metrics! OLD: %r, NEW: %r",
|
||||
self.start_usage, current
|
||||
)
|
||||
|
||||
if self.created_context:
|
||||
self.start_context.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
|
|
@ -92,13 +92,22 @@ class _PerHostRatelimiter(object):
|
|||
|
||||
self.window_size = window_size
|
||||
self.sleep_limit = sleep_limit
|
||||
self.sleep_msec = sleep_msec
|
||||
self.sleep_sec = sleep_msec / 1000.0
|
||||
self.reject_limit = reject_limit
|
||||
self.concurrent_requests = concurrent_requests
|
||||
|
||||
# request_id objects for requests which have been slept
|
||||
self.sleeping_requests = set()
|
||||
|
||||
# map from request_id object to Deferred for requests which are ready
|
||||
# for processing but have been queued
|
||||
self.ready_request_queue = collections.OrderedDict()
|
||||
|
||||
# request id objects for requests which are in progress
|
||||
self.current_processing = set()
|
||||
|
||||
# times at which we have recently (within the last window_size ms)
|
||||
# received requests.
|
||||
self.request_times = []
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -117,11 +126,15 @@ class _PerHostRatelimiter(object):
|
|||
|
||||
def _on_enter(self, request_id):
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# remove any entries from request_times which aren't within the window
|
||||
self.request_times[:] = [
|
||||
r for r in self.request_times
|
||||
if time_now - r < self.window_size
|
||||
]
|
||||
|
||||
# reject the request if we already have too many queued up (either
|
||||
# sleeping or in the ready queue).
|
||||
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
|
||||
if queue_size > self.reject_limit:
|
||||
raise LimitExceededError(
|
||||
|
@ -134,9 +147,13 @@ class _PerHostRatelimiter(object):
|
|||
|
||||
def queue_request():
|
||||
if len(self.current_processing) > self.concurrent_requests:
|
||||
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
|
||||
queue_defer = defer.Deferred()
|
||||
self.ready_request_queue[request_id] = queue_defer
|
||||
logger.info(
|
||||
"Ratelimiter: queueing request (queue now %i items)",
|
||||
len(self.ready_request_queue),
|
||||
)
|
||||
|
||||
return queue_defer
|
||||
else:
|
||||
return defer.succeed(None)
|
||||
|
@ -148,10 +165,9 @@ class _PerHostRatelimiter(object):
|
|||
|
||||
if len(self.request_times) > self.sleep_limit:
|
||||
logger.debug(
|
||||
"Ratelimit [%s]: sleeping req",
|
||||
id(request_id),
|
||||
"Ratelimiter: sleeping request for %f sec", self.sleep_sec,
|
||||
)
|
||||
ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
|
||||
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
|
||||
|
||||
self.sleeping_requests.add(request_id)
|
||||
|
||||
|
@ -200,11 +216,8 @@ class _PerHostRatelimiter(object):
|
|||
)
|
||||
self.current_processing.discard(request_id)
|
||||
try:
|
||||
request_id, deferred = self.ready_request_queue.popitem()
|
||||
|
||||
# XXX: why do we do the following? the on_start callback above will
|
||||
# do it for us.
|
||||
self.current_processing.add(request_id)
|
||||
# start processing the next item on the queue.
|
||||
_, deferred = self.ready_request_queue.popitem(last=False)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
deferred.callback(None)
|
||||
|
|
|
@ -12,14 +12,18 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
|
||||
import logging
|
||||
import operator
|
||||
|
||||
from six import iteritems, itervalues
|
||||
from six.moves import map
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -218,10 +222,161 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
|
|||
return event
|
||||
|
||||
# check each event: gives an iterable[None|EventBase]
|
||||
filtered_events = itertools.imap(allowed, events)
|
||||
filtered_events = map(allowed, events)
|
||||
|
||||
# remove the None entries
|
||||
filtered_events = filter(operator.truth, filtered_events)
|
||||
|
||||
# we turn it into a list before returning it.
|
||||
defer.returnValue(list(filtered_events))
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def filter_events_for_server(store, server_name, events):
|
||||
# Whatever else we do, we need to check for senders which have requested
|
||||
# erasure of their data.
|
||||
erased_senders = yield store.are_users_erased(
|
||||
e.sender for e in events,
|
||||
)
|
||||
|
||||
def redact_disallowed(event, state):
|
||||
# if the sender has been gdpr17ed, always return a redacted
|
||||
# copy of the event.
|
||||
if erased_senders[event.sender]:
|
||||
logger.info(
|
||||
"Sender of %s has been erased, redacting",
|
||||
event.event_id,
|
||||
)
|
||||
return prune_event(event)
|
||||
|
||||
# state will be None if we decided we didn't need to filter by
|
||||
# room membership.
|
||||
if not state:
|
||||
return event
|
||||
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history:
|
||||
visibility = history.content.get("history_visibility", "shared")
|
||||
if visibility in ["invited", "joined"]:
|
||||
# We now loop through all state events looking for
|
||||
# membership states for the requesting server to determine
|
||||
# if the server is either in the room or has been invited
|
||||
# into the room.
|
||||
for ev in itervalues(state):
|
||||
if ev.type != EventTypes.Member:
|
||||
continue
|
||||
try:
|
||||
domain = get_domain_from_id(ev.state_key)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if domain != server_name:
|
||||
continue
|
||||
|
||||
memtype = ev.membership
|
||||
if memtype == Membership.JOIN:
|
||||
return event
|
||||
elif memtype == Membership.INVITE:
|
||||
if visibility == "invited":
|
||||
return event
|
||||
else:
|
||||
# server has no users in the room: redact
|
||||
return prune_event(event)
|
||||
|
||||
return event
|
||||
|
||||
# Next lets check to see if all the events have a history visibility
|
||||
# of "shared" or "world_readable". If thats the case then we don't
|
||||
# need to check membership (as we know the server is in the room).
|
||||
event_to_state_ids = yield store.get_state_ids_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
)
|
||||
)
|
||||
|
||||
visibility_ids = set()
|
||||
for sids in itervalues(event_to_state_ids):
|
||||
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
|
||||
if hist:
|
||||
visibility_ids.add(hist)
|
||||
|
||||
# If we failed to find any history visibility events then the default
|
||||
# is "shared" visiblity.
|
||||
if not visibility_ids:
|
||||
all_open = True
|
||||
else:
|
||||
event_map = yield store.get_events(visibility_ids)
|
||||
all_open = all(
|
||||
e.content.get("history_visibility") in (None, "shared", "world_readable")
|
||||
for e in itervalues(event_map)
|
||||
)
|
||||
|
||||
if all_open:
|
||||
# all the history_visibility state affecting these events is open, so
|
||||
# we don't need to filter by membership state. We *do* need to check
|
||||
# for user erasure, though.
|
||||
if erased_senders:
|
||||
events = [
|
||||
redact_disallowed(e, None)
|
||||
for e in events
|
||||
]
|
||||
|
||||
defer.returnValue(events)
|
||||
|
||||
# Ok, so we're dealing with events that have non-trivial visibility
|
||||
# rules, so we need to also get the memberships of the room.
|
||||
|
||||
# first, for each event we're wanting to return, get the event_ids
|
||||
# of the history vis and membership state at those events.
|
||||
event_to_state_ids = yield store.get_state_ids_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.Member, None),
|
||||
)
|
||||
)
|
||||
|
||||
# We only want to pull out member events that correspond to the
|
||||
# server's domain.
|
||||
#
|
||||
# event_to_state_ids contains lots of duplicates, so it turns out to be
|
||||
# cheaper to build a complete set of unique
|
||||
# ((type, state_key), event_id) tuples, and then filter out the ones we
|
||||
# don't want.
|
||||
#
|
||||
state_key_to_event_id_set = {
|
||||
e
|
||||
for key_to_eid in itervalues(event_to_state_ids)
|
||||
for e in key_to_eid.items()
|
||||
}
|
||||
|
||||
def include(typ, state_key):
|
||||
if typ != EventTypes.Member:
|
||||
return True
|
||||
|
||||
# we avoid using get_domain_from_id here for efficiency.
|
||||
idx = state_key.find(":")
|
||||
if idx == -1:
|
||||
return False
|
||||
return state_key[idx + 1:] == server_name
|
||||
|
||||
event_map = yield store.get_events([
|
||||
e_id
|
||||
for key, e_id in state_key_to_event_id_set
|
||||
if include(key[0], key[1])
|
||||
])
|
||||
|
||||
event_to_state = {
|
||||
e_id: {
|
||||
key: event_map[inner_e_id]
|
||||
for key, inner_e_id in iteritems(key_to_eid)
|
||||
if inner_e_id in event_map
|
||||
}
|
||||
for e_id, key_to_eid in iteritems(event_to_state_ids)
|
||||
}
|
||||
|
||||
defer.returnValue([
|
||||
redact_disallowed(e, event_to_state[e.event_id])
|
||||
for e in events
|
||||
])
|
||||
|
|
|
@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
self.clock = MockClock()
|
||||
self.cache = HttpTransactionCache(self.clock)
|
||||
self.hs = Mock()
|
||||
self.hs.get_clock = Mock(return_value=self.clock)
|
||||
self.hs.get_auth = Mock()
|
||||
self.cache = HttpTransactionCache(self.hs)
|
||||
|
||||
self.mock_http_response = (200, "GOOD JOB!")
|
||||
self.mock_key = "foo"
|
||||
|
|
305
tests/rest/client/v1/test_admin.py
Normal file
305
tests/rest/client/v1/test_admin.py
Normal file
|
@ -0,0 +1,305 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v1.admin import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
|
||||
|
||||
class UserRegisterTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
self.clock = ThreadedMemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
self.url = "/_matrix/client/r0/admin/register"
|
||||
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
self.device_handler = Mock()
|
||||
self.device_handler.check_device_registered = Mock(return_value="FAKE")
|
||||
|
||||
self.datastore = Mock(return_value=Mock())
|
||||
self.datastore.get_current_state_deltas = Mock(return_value=[])
|
||||
|
||||
self.secrets = Mock()
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
|
||||
self.hs.config.registration_shared_secret = u"shared"
|
||||
|
||||
self.hs.get_media_repository = Mock()
|
||||
self.hs.get_deactivate_account_handler = Mock()
|
||||
|
||||
self.resource = JsonResource(self.hs)
|
||||
register_servlets(self.hs, self.resource)
|
||||
|
||||
def test_disabled(self):
|
||||
"""
|
||||
If there is no shared secret, registration through this method will be
|
||||
prevented.
|
||||
"""
|
||||
self.hs.config.registration_shared_secret = None
|
||||
|
||||
request, channel = make_request("POST", self.url, b'{}')
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(
|
||||
'Shared secret registration is not enabled', channel.json_body["error"]
|
||||
)
|
||||
|
||||
def test_get_nonce(self):
|
||||
"""
|
||||
Calling GET on the endpoint will return a randomised nonce, using the
|
||||
homeserver's secrets provider.
|
||||
"""
|
||||
secrets = Mock()
|
||||
secrets.token_hex = Mock(return_value="abcd")
|
||||
|
||||
self.hs.get_secrets = Mock(return_value=secrets)
|
||||
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(channel.json_body, {"nonce": "abcd"})
|
||||
|
||||
def test_expired_nonce(self):
|
||||
"""
|
||||
Calling GET on the endpoint will return a randomised nonce, which will
|
||||
only last for SALT_TIMEOUT (60s).
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
# 59 seconds
|
||||
self.clock.advance(59)
|
||||
|
||||
body = json.dumps({"nonce": nonce})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||
|
||||
# 61 seconds
|
||||
self.clock.advance(2)
|
||||
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||
|
||||
def test_register_incorrect_nonce(self):
|
||||
"""
|
||||
Only the provided nonce can be used, as it's checked in the MAC.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
body = json.dumps(
|
||||
{
|
||||
"nonce": nonce,
|
||||
"username": "bob",
|
||||
"password": "abc123",
|
||||
"admin": True,
|
||||
"mac": want_mac,
|
||||
}
|
||||
).encode('utf8')
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("HMAC incorrect", channel.json_body["error"])
|
||||
|
||||
def test_register_correct_nonce(self):
|
||||
"""
|
||||
When the correct nonce is provided, and the right key is provided, the
|
||||
user is registered.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
body = json.dumps(
|
||||
{
|
||||
"nonce": nonce,
|
||||
"username": "bob",
|
||||
"password": "abc123",
|
||||
"admin": True,
|
||||
"mac": want_mac,
|
||||
}
|
||||
).encode('utf8')
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||
|
||||
def test_nonce_reuse(self):
|
||||
"""
|
||||
A valid unrecognised nonce.
|
||||
"""
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
nonce = channel.json_body["nonce"]
|
||||
|
||||
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
|
||||
want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
body = json.dumps(
|
||||
{
|
||||
"nonce": nonce,
|
||||
"username": "bob",
|
||||
"password": "abc123",
|
||||
"admin": True,
|
||||
"mac": want_mac,
|
||||
}
|
||||
).encode('utf8')
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("@bob:test", channel.json_body["user_id"])
|
||||
|
||||
# Now, try and reuse it
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('unrecognised nonce', channel.json_body["error"])
|
||||
|
||||
def test_missing_parts(self):
|
||||
"""
|
||||
Synapse will complain if you don't give nonce, username, password, and
|
||||
mac. Admin is optional. Additional checks are done for length and
|
||||
type.
|
||||
"""
|
||||
def nonce():
|
||||
request, channel = make_request("GET", self.url)
|
||||
render(request, self.resource, self.clock)
|
||||
return channel.json_body["nonce"]
|
||||
|
||||
#
|
||||
# Nonce check
|
||||
#
|
||||
|
||||
# Must be present
|
||||
body = json.dumps({})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('nonce must be specified', channel.json_body["error"])
|
||||
|
||||
#
|
||||
# Username checks
|
||||
#
|
||||
|
||||
# Must be present
|
||||
body = json.dumps({"nonce": nonce()})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('username must be specified', channel.json_body["error"])
|
||||
|
||||
# Must be a string
|
||||
body = json.dumps({"nonce": nonce(), "username": 1234})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid username', channel.json_body["error"])
|
||||
|
||||
#
|
||||
# Username checks
|
||||
#
|
||||
|
||||
# Must be present
|
||||
body = json.dumps({"nonce": nonce(), "username": "a"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('password must be specified', channel.json_body["error"])
|
||||
|
||||
# Must be a string
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
||||
# Must not have null bytes
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
||||
|
||||
# Super long
|
||||
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
|
||||
request, channel = make_request("POST", self.url, body.encode('utf8'))
|
||||
render(request, self.resource, self.clock)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual('Invalid password', channel.json_body["error"])
|
|
@ -14,100 +14,30 @@
|
|||
# limitations under the License.
|
||||
|
||||
""" Tests REST events for /events paths."""
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
from six import PY3
|
||||
|
||||
# twisted imports
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.rest.client.v1.events
|
||||
import synapse.rest.client.v1.register
|
||||
import synapse.rest.client.v1.room
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
from .utils import RestTestCase
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/api/v1"
|
||||
|
||||
|
||||
class EventStreamPaginationApiTestCase(unittest.TestCase):
|
||||
""" Tests event streaming query parameters and start/end keys used in the
|
||||
Pagination stream API. """
|
||||
user_id = "sid1"
|
||||
|
||||
def setUp(self):
|
||||
# configure stream and inject items
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def TODO_test_long_poll(self):
|
||||
# stream from 'end' key, send (self+other) message, expect message.
|
||||
|
||||
# stream from 'END', send (self+other) message, expect message.
|
||||
|
||||
# stream from 'end' key, send (self+other) topic, expect topic.
|
||||
|
||||
# stream from 'END', send (self+other) topic, expect topic.
|
||||
|
||||
# stream from 'end' key, send (self+other) invite, expect invite.
|
||||
|
||||
# stream from 'END', send (self+other) invite, expect invite.
|
||||
|
||||
pass
|
||||
|
||||
def TODO_test_stream_forward(self):
|
||||
# stream from START, expect injected items
|
||||
|
||||
# stream from 'start' key, expect same content
|
||||
|
||||
# stream from 'end' key, expect nothing
|
||||
|
||||
# stream from 'END', expect nothing
|
||||
|
||||
# The following is needed for cases where content is removed e.g. you
|
||||
# left a room, so the token you're streaming from is > the one that
|
||||
# would be returned naturally from START>END.
|
||||
# stream from very new token (higher than end key), expect same token
|
||||
# returned as end key
|
||||
pass
|
||||
|
||||
def TODO_test_limits(self):
|
||||
# stream from a key, expect limit_num items
|
||||
|
||||
# stream from START, expect limit_num items
|
||||
|
||||
pass
|
||||
|
||||
def TODO_test_range(self):
|
||||
# stream from key to key, expect X items
|
||||
|
||||
# stream from key to END, expect X items
|
||||
|
||||
# stream from START to key, expect X items
|
||||
|
||||
# stream from START to END, expect all items
|
||||
pass
|
||||
|
||||
def TODO_test_direction(self):
|
||||
# stream from END to START and fwds, expect newest first
|
||||
|
||||
# stream from END to START and bwds, expect oldest first
|
||||
|
||||
# stream from START to END and fwds, expect oldest first
|
||||
|
||||
# stream from START to END and bwds, expect newest first
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EventStreamPermissionsTestCase(RestTestCase):
|
||||
""" Tests event streaming (GET /events). """
|
||||
|
||||
if PY3:
|
||||
skip = "Skip on Py3 until ported to use not V1 only register."
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
import synapse.rest.client.v1.events
|
||||
import synapse.rest.client.v1_only.register
|
||||
import synapse.rest.client.v1.room
|
||||
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
|
@ -125,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource)
|
||||
synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
|
||||
synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
|
||||
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
|
||||
|
||||
|
|
|
@ -16,27 +16,26 @@
|
|||
import json
|
||||
|
||||
from mock import Mock
|
||||
from six import PY3
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
|
||||
from synapse.rest.client.v1.register import CreateUserRestServlet
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v1_only.register import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import mock_getRawHeaders
|
||||
from tests.server import make_request, setup_test_homeserver
|
||||
|
||||
|
||||
class CreateUserServletTestCase(unittest.TestCase):
|
||||
"""
|
||||
Tests for CreateUserRestServlet.
|
||||
"""
|
||||
if PY3:
|
||||
skip = "Not ported to Python 3."
|
||||
|
||||
def setUp(self):
|
||||
# do the dance to hook up request data to self.request_data
|
||||
self.request_data = ""
|
||||
self.request = Mock(
|
||||
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||
path='/_matrix/client/api/v1/createUser'
|
||||
)
|
||||
self.request.args = {}
|
||||
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
self.registration_handler = Mock()
|
||||
|
||||
self.appservice = Mock(sender="@as:test")
|
||||
|
@ -44,39 +43,49 @@ class CreateUserServletTestCase(unittest.TestCase):
|
|||
get_app_service_by_token=Mock(return_value=self.appservice)
|
||||
)
|
||||
|
||||
# do the dance to hook things up to the hs global
|
||||
handlers = Mock(
|
||||
registration_handler=self.registration_handler,
|
||||
handlers = Mock(registration_handler=self.registration_handler)
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
|
||||
self.hs = self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
self.hs = Mock()
|
||||
self.hs.hostname = "superbig~testing~thing.com"
|
||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||
self.hs.get_handlers = Mock(return_value=handlers)
|
||||
self.servlet = CreateUserRestServlet(self.hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_createuser_with_valid_user(self):
|
||||
|
||||
res = JsonResource(self.hs)
|
||||
register_servlets(self.hs, res)
|
||||
|
||||
request_data = json.dumps(
|
||||
{
|
||||
"localpart": "someone",
|
||||
"displayname": "someone interesting",
|
||||
"duration_seconds": 200,
|
||||
}
|
||||
)
|
||||
|
||||
url = b'/_matrix/client/api/v1/createUser?access_token=i_am_an_app_service'
|
||||
|
||||
user_id = "@someone:interesting"
|
||||
token = "my token"
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"localpart": "someone",
|
||||
"displayname": "someone interesting",
|
||||
"duration_seconds": 200
|
||||
})
|
||||
|
||||
self.registration_handler.get_or_create_user = Mock(
|
||||
return_value=(user_id, token)
|
||||
)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
request, channel = make_request(b"POST", url, request_data)
|
||||
request.render(res)
|
||||
|
||||
# Advance the clock because it waits
|
||||
self.clock.advance(1)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200")
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -16,13 +16,14 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
# twisted imports
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
|
||||
# trial imports
|
||||
from tests import unittest
|
||||
from tests.server import make_request, wait_until_result
|
||||
|
||||
|
||||
class RestTestCase(unittest.TestCase):
|
||||
|
@ -133,3 +134,113 @@ class RestTestCase(unittest.TestCase):
|
|||
for key in required:
|
||||
self.assertEquals(required[key], actual[key],
|
||||
msg="%s mismatch. %s" % (key, actual))
|
||||
|
||||
|
||||
@attr.s
|
||||
class RestHelper(object):
|
||||
"""Contains extra helper functions to quickly and clearly perform a given
|
||||
REST action, which isn't the focus of the test.
|
||||
"""
|
||||
|
||||
hs = attr.ib()
|
||||
resource = attr.ib()
|
||||
auth_user_id = attr.ib()
|
||||
|
||||
def create_room_as(self, room_creator, is_public=True, tok=None):
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = room_creator
|
||||
path = b"/_matrix/client/r0/createRoom"
|
||||
content = {}
|
||||
if not is_public:
|
||||
content["visibility"] = "private"
|
||||
if tok:
|
||||
path = path + b"?access_token=%s" % tok.encode('ascii')
|
||||
|
||||
request, channel = make_request(b"POST", path, json.dumps(content).encode('utf8'))
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.hs.get_reactor(), channel)
|
||||
|
||||
assert channel.result["code"] == b"200", channel.result
|
||||
self.auth_user_id = temp_id
|
||||
return channel.json_body["room_id"]
|
||||
|
||||
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
room=room,
|
||||
src=src,
|
||||
targ=targ,
|
||||
tok=tok,
|
||||
membership=Membership.INVITE,
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def join(self, room=None, user=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
room=room,
|
||||
src=user,
|
||||
targ=user,
|
||||
tok=tok,
|
||||
membership=Membership.JOIN,
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def leave(self, room=None, user=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
room=room,
|
||||
src=user,
|
||||
targ=user,
|
||||
tok=tok,
|
||||
membership=Membership.LEAVE,
|
||||
expect_code=expect_code,
|
||||
)
|
||||
|
||||
def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = src
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
data = {"membership": membership}
|
||||
|
||||
request, channel = make_request(
|
||||
b"PUT", path.encode('ascii'), json.dumps(data).encode('utf8')
|
||||
)
|
||||
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.hs.get_reactor(), channel)
|
||||
|
||||
assert int(channel.result["code"]) == expect_code, (
|
||||
"Expected: %d, got: %d, resp: %r"
|
||||
% (expect_code, int(channel.result["code"]), channel.result["body"])
|
||||
)
|
||||
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, user_id):
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST",
|
||||
"/_matrix/client/r0/register",
|
||||
json.dumps(
|
||||
{"user": user_id, "password": "test", "type": "m.login.password"}
|
||||
),
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
if body is None:
|
||||
body = "body_text_here"
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
||||
content = '{"msgtype":"m.text","body":"%s"}' % body
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
|
||||
self.assertEquals(expect_code, code, msg=str(response))
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||
|
||||
|
||||
class V2AlphaRestTestCase(unittest.TestCase):
|
||||
# Consumer must define
|
||||
# USER_ID = <some string>
|
||||
# TO_REGISTER = [<list of REST servlets to register>]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
datastore=self.make_datastore_mock(),
|
||||
http_client=None,
|
||||
resource_for_client=self.mock_resource,
|
||||
resource_for_federation=self.mock_resource,
|
||||
)
|
||||
|
||||
def get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.USER_ID),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_auth().get_user_by_access_token = get_user_by_access_token
|
||||
|
||||
for r in self.TO_REGISTER:
|
||||
r.register_servlets(hs, self.mock_resource)
|
||||
|
||||
def make_datastore_mock(self):
|
||||
store = Mock(spec=[
|
||||
"insert_client_ip",
|
||||
])
|
||||
store.get_app_service_by_token = Mock(return_value=None)
|
||||
return store
|
|
@ -13,35 +13,37 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha import filter
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
||||
make_request,
|
||||
setup_test_homeserver,
|
||||
wait_until_result,
|
||||
)
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||
|
||||
|
||||
class FilterTestCase(unittest.TestCase):
|
||||
|
||||
USER_ID = "@apple:test"
|
||||
USER_ID = b"@apple:test"
|
||||
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
|
||||
EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||
TO_REGISTER = [filter]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
|
||||
self.hs = yield setup_test_homeserver(
|
||||
http_client=None,
|
||||
resource_for_client=self.mock_resource,
|
||||
resource_for_federation=self.mock_resource,
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
|
||||
self.auth = self.hs.get_auth()
|
||||
|
@ -55,82 +57,103 @@ class FilterTestCase(unittest.TestCase):
|
|||
|
||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||
return synapse.types.create_requester(
|
||||
UserID.from_string(self.USER_ID), 1, False, None)
|
||||
UserID.from_string(self.USER_ID), 1, False, None
|
||||
)
|
||||
|
||||
self.auth.get_user_by_access_token = get_user_by_access_token
|
||||
self.auth.get_user_by_req = get_user_by_req
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.filtering = self.hs.get_filtering()
|
||||
self.resource = JsonResource(self.hs)
|
||||
|
||||
for r in self.TO_REGISTER:
|
||||
r.register_servlets(self.hs, self.mock_resource)
|
||||
r.register_servlets(self.hs, self.resource)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_filter(self):
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
|
||||
request, channel = make_request(
|
||||
b"POST",
|
||||
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
self.assertEquals({"filter_id": "0"}, response)
|
||||
filter = yield self.store.get_user_filter(
|
||||
user_localpart='apple',
|
||||
filter_id=0,
|
||||
)
|
||||
self.assertEquals(filter, self.EXAMPLE_FILTER)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
self.clock.advance(0)
|
||||
self.assertEquals(filter.result, self.EXAMPLE_FILTER)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_filter_for_other_user(self):
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
|
||||
request, channel = make_request(
|
||||
b"POST",
|
||||
b"/_matrix/client/r0/user/%s/filter" % (b"@watermelon:test"),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
self.assertEquals(403, code)
|
||||
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_add_filter_non_local_user(self):
|
||||
_is_mine = self.hs.is_mine
|
||||
self.hs.is_mine = lambda target_user: False
|
||||
(code, response) = yield self.mock_resource.trigger(
|
||||
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
|
||||
request, channel = make_request(
|
||||
b"POST",
|
||||
b"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
|
||||
self.EXAMPLE_FILTER_JSON,
|
||||
)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.hs.is_mine = _is_mine
|
||||
self.assertEquals(403, code)
|
||||
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter(self):
|
||||
filter_id = yield self.filtering.add_user_filter(
|
||||
user_localpart='apple',
|
||||
user_filter=self.EXAMPLE_FILTER
|
||||
filter_id = self.filtering.add_user_filter(
|
||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||
)
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
||||
self.clock.advance(1)
|
||||
filter_id = filter_id.result
|
||||
request, channel = make_request(
|
||||
b"GET", b"/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
self.assertEquals(self.EXAMPLE_FILTER, response)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter_non_existant(self):
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/12382148321" % (self.USER_ID)
|
||||
request, channel = make_request(
|
||||
b"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
|
||||
)
|
||||
self.assertEquals(400, code)
|
||||
self.assertEquals(response['errcode'], Codes.NOT_FOUND)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
|
||||
|
||||
# Currently invalid params do not have an appropriate errcode
|
||||
# in errors.py
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter_invalid_id(self):
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/foobar" % (self.USER_ID)
|
||||
request, channel = make_request(
|
||||
b"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
|
||||
)
|
||||
self.assertEquals(400, code)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
|
||||
# No ID also returns an invalid_id error
|
||||
@defer.inlineCallbacks
|
||||
def test_get_filter_no_id(self):
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/user/%s/filter/" % (self.USER_ID)
|
||||
request, channel = make_request(
|
||||
b"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
|
||||
)
|
||||
self.assertEquals(400, code)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
|
|
|
@ -2,165 +2,192 @@ import json
|
|||
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python import failure
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError
|
||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha.register import register_servlets
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import mock_getRawHeaders
|
||||
from tests.server import make_request, setup_test_homeserver, wait_until_result
|
||||
|
||||
|
||||
class RegisterRestServletTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# do the dance to hook up request data to self.request_data
|
||||
self.request_data = ""
|
||||
self.request = Mock(
|
||||
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||
path='/_matrix/api/v2_alpha/register'
|
||||
)
|
||||
self.request.args = {}
|
||||
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
self.url = b"/_matrix/client/r0/register"
|
||||
|
||||
self.appservice = None
|
||||
self.auth = Mock(get_appservice_by_req=Mock(
|
||||
side_effect=lambda x: self.appservice)
|
||||
self.auth = Mock(
|
||||
get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
|
||||
)
|
||||
|
||||
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||
get_session_data=Mock(return_value=None)
|
||||
get_session_data=Mock(return_value=None),
|
||||
)
|
||||
self.registration_handler = Mock()
|
||||
self.identity_handler = Mock()
|
||||
self.login_handler = Mock()
|
||||
self.device_handler = Mock()
|
||||
self.device_handler.check_device_registered = Mock(return_value="FAKE")
|
||||
|
||||
self.datastore = Mock(return_value=Mock())
|
||||
self.datastore.get_current_state_deltas = Mock(return_value=[])
|
||||
|
||||
# do the dance to hook it up to the hs global
|
||||
self.handlers = Mock(
|
||||
registration_handler=self.registration_handler,
|
||||
identity_handler=self.identity_handler,
|
||||
login_handler=self.login_handler
|
||||
login_handler=self.login_handler,
|
||||
)
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
self.hs = Mock()
|
||||
self.hs.hostname = "superbig~testing~thing.com"
|
||||
self.hs.get_auth = Mock(return_value=self.auth)
|
||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
||||
self.hs.config.enable_registration = True
|
||||
self.hs.config.registrations_require_3pid = []
|
||||
self.hs.config.auto_join_rooms = []
|
||||
|
||||
# init the thing we're testing
|
||||
self.servlet = RegisterRestServlet(self.hs)
|
||||
self.resource = JsonResource(self.hs)
|
||||
register_servlets(self.hs, self.resource)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_appservice_registration_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit"
|
||||
})
|
||||
self.appservice = {
|
||||
"id": "1234"
|
||||
}
|
||||
self.registration_handler.appservice_register = Mock(
|
||||
return_value=user_id
|
||||
)
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||
return_value=token
|
||||
)
|
||||
self.appservice = {"id": "1234"}
|
||||
self.registration_handler.appservice_register = Mock(return_value=user_id)
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
request_data = json.dumps({"username": "kermit"})
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
request, channel = make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_appservice_registration_invalid(self):
|
||||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit"
|
||||
})
|
||||
self.appservice = None # no application service exists
|
||||
result = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(result, (401, None))
|
||||
request_data = json.dumps({"username": "kermit"})
|
||||
request, channel = make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"401", channel.result)
|
||||
|
||||
def test_POST_bad_password(self):
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": 666
|
||||
})
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
request_data = json.dumps({"username": "kermit", "password": 666})
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"], "Invalid password"
|
||||
)
|
||||
|
||||
def test_POST_bad_username(self):
|
||||
self.request_data = json.dumps({
|
||||
"username": 777,
|
||||
"password": "monkey"
|
||||
})
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
request_data = json.dumps({"username": 777, "password": "monkey"})
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"400", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"], "Invalid username"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_POST_user_valid(self):
|
||||
user_id = "@kermit:muppet"
|
||||
token = "kermits_access_token"
|
||||
device_id = "frogfone"
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
"device_id": device_id,
|
||||
})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(
|
||||
return_value=token
|
||||
request_data = json.dumps(
|
||||
{"username": "kermit", "password": "monkey", "device_id": device_id}
|
||||
)
|
||||
self.device_handler.check_device_registered = \
|
||||
Mock(return_value=device_id)
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
||||
self.device_handler.check_device_registered = Mock(return_value=device_id)
|
||||
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id, initial_device_display_name=None)
|
||||
user_id, device_id=device_id, initial_device_display_name=None
|
||||
)
|
||||
|
||||
def test_POST_disabled_registration(self):
|
||||
self.hs.config.enable_registration = False
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
})
|
||||
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
}, None)
|
||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
||||
d = self.servlet.on_POST(self.request)
|
||||
return self.assertFailure(d, SynapseError)
|
||||
|
||||
request, channel = make_request(b"POST", self.url, request_data)
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"],
|
||||
"Registration has been disabled",
|
||||
)
|
||||
|
||||
def test_POST_guest_registration(self):
|
||||
user_id = "a@b"
|
||||
self.hs.config.macaroon_secret_key = "test"
|
||||
self.hs.config.allow_guest_access = True
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
|
||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": "guest_device",
|
||||
}
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
self.assertDictContainsSubset(det_data, json.loads(channel.result["body"]))
|
||||
|
||||
def test_POST_disabled_guest_registration(self):
|
||||
self.hs.config.allow_guest_access = False
|
||||
|
||||
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEquals(channel.result["code"], b"403", channel.result)
|
||||
self.assertEquals(
|
||||
json.loads(channel.result["body"])["error"], "Guest access is disabled"
|
||||
)
|
||||
|
|
87
tests/rest/client/v2_alpha/test_sync.py
Normal file
87
tests/rest/client/v2_alpha/test_sync.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse.types
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock as MemoryReactorClock,
|
||||
make_request,
|
||||
setup_test_homeserver,
|
||||
wait_until_result,
|
||||
)
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/v2_alpha"
|
||||
|
||||
|
||||
class FilterTestCase(unittest.TestCase):
|
||||
|
||||
USER_ID = b"@apple:test"
|
||||
TO_REGISTER = [sync]
|
||||
|
||||
def setUp(self):
|
||||
self.clock = MemoryReactorClock()
|
||||
self.hs_clock = Clock(self.clock)
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
http_client=None, clock=self.hs_clock, reactor=self.clock
|
||||
)
|
||||
|
||||
self.auth = self.hs.get_auth()
|
||||
|
||||
def get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.USER_ID),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||
return synapse.types.create_requester(
|
||||
UserID.from_string(self.USER_ID), 1, False, None
|
||||
)
|
||||
|
||||
self.auth.get_user_by_access_token = get_user_by_access_token
|
||||
self.auth.get_user_by_req = get_user_by_req
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.filtering = self.hs.get_filtering()
|
||||
self.resource = JsonResource(self.hs)
|
||||
|
||||
for r in self.TO_REGISTER:
|
||||
r.register_servlets(self.hs, self.resource)
|
||||
|
||||
def test_sync_argless(self):
|
||||
request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
|
||||
request.render(self.resource)
|
||||
wait_until_result(self.clock, channel)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertTrue(
|
||||
set(
|
||||
[
|
||||
"next_batch",
|
||||
"rooms",
|
||||
"presence",
|
||||
"account_data",
|
||||
"to_device",
|
||||
"device_lists",
|
||||
]
|
||||
).issubset(set(channel.json_body.keys()))
|
||||
)
|
|
@ -80,6 +80,11 @@ def make_request(method, path, content=b""):
|
|||
content, and return the Request and the Channel underneath.
|
||||
"""
|
||||
|
||||
# Decorate it to be the full path
|
||||
if not path.startswith(b"/_matrix"):
|
||||
path = b"/_matrix/client/r0/" + path
|
||||
path = path.replace("//", "/")
|
||||
|
||||
if isinstance(content, text_type):
|
||||
content = content.encode('utf8')
|
||||
|
||||
|
@ -110,6 +115,11 @@ def wait_until_result(clock, channel, timeout=100):
|
|||
clock.advance(0.1)
|
||||
|
||||
|
||||
def render(request, resource, clock):
|
||||
request.render(resource)
|
||||
wait_until_result(clock, request._channel)
|
||||
|
||||
|
||||
class ThreadedMemoryReactorClock(MemoryReactorClock):
|
||||
"""
|
||||
A MemoryReactorClock that supports callFromThread.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,8 +16,6 @@
|
|||
|
||||
from mock import Mock, patch
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.distributor import Distributor
|
||||
|
||||
from . import unittest
|
||||
|
@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase):
|
|||
def setUp(self):
|
||||
self.dist = Distributor()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch(self):
|
||||
self.dist.declare("alert")
|
||||
|
||||
observer = Mock()
|
||||
self.dist.observe("alert", observer)
|
||||
|
||||
d = self.dist.fire("alert", 1, 2, 3)
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
self.dist.fire("alert", 1, 2, 3)
|
||||
observer.assert_called_with(1, 2, 3)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch_deferred(self):
|
||||
self.dist.declare("whine")
|
||||
|
||||
d_inner = defer.Deferred()
|
||||
|
||||
def observer():
|
||||
return d_inner
|
||||
|
||||
self.dist.observe("whine", observer)
|
||||
|
||||
d_outer = self.dist.fire("whine")
|
||||
|
||||
self.assertFalse(d_outer.called)
|
||||
|
||||
d_inner.callback(None)
|
||||
yield d_outer
|
||||
self.assertTrue(d_outer.called)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_catch(self):
|
||||
self.dist.declare("alarm")
|
||||
|
||||
|
@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase):
|
|||
with patch(
|
||||
"synapse.util.distributor.logger", spec=["warning"]
|
||||
) as mock_logger:
|
||||
d = self.dist.fire("alarm", "Go")
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
self.dist.fire("alarm", "Go")
|
||||
|
||||
observers[0].assert_called_once_with("Go")
|
||||
observers[1].assert_called_once_with("Go")
|
||||
|
@ -83,34 +57,12 @@ class DistributorTestCase(unittest.TestCase):
|
|||
mock_logger.warning.call_args[0][0], str
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_catch_no_suppress(self):
|
||||
# Gut-wrenching
|
||||
self.dist.suppress_failures = False
|
||||
|
||||
self.dist.declare("whail")
|
||||
|
||||
class MyException(Exception):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def observer():
|
||||
raise MyException("Oopsie")
|
||||
|
||||
self.dist.observe("whail", observer)
|
||||
|
||||
d = self.dist.fire("whail")
|
||||
|
||||
yield self.assertFailure(d, MyException)
|
||||
self.dist.suppress_failures = True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_prereg(self):
|
||||
observer = Mock()
|
||||
self.dist.observe("flare", observer)
|
||||
|
||||
self.dist.declare("flare")
|
||||
yield self.dist.fire("flare", 4, 5)
|
||||
self.dist.fire("flare", 4, 5)
|
||||
|
||||
observer.assert_called_with(4, 5)
|
||||
|
||||
|
|
|
@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
|
||||
|
||||
@unittest.DEBUG
|
||||
def test_cant_hide_past_history(self):
|
||||
"""
|
||||
If you send a message, you must be able to provide the direct
|
||||
|
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
for x, y in d.items()
|
||||
if x == ("m.room.member", "@us:test")
|
||||
],
|
||||
"auth_chain_ids": d.values(),
|
||||
"auth_chain_ids": list(d.values()),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -33,9 +33,11 @@ class JsonResourceTests(unittest.TestCase):
|
|||
return (200, kwargs)
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback)
|
||||
res.register_paths(
|
||||
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
|
||||
)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
|
||||
|
@ -51,9 +53,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
raise Exception("boo")
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'500')
|
||||
|
@ -74,9 +76,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
return d
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request.render(res)
|
||||
|
||||
# No error has been raised yet
|
||||
|
@ -96,9 +98,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foo")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foo")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'403')
|
||||
|
@ -118,9 +120,9 @@ class JsonResourceTests(unittest.TestCase):
|
|||
self.fail("shouldn't ever get here")
|
||||
|
||||
res = JsonResource(self.homeserver)
|
||||
res.register_paths("GET", [re.compile("^/foo$")], _callback)
|
||||
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
|
||||
|
||||
request, channel = make_request(b"GET", b"/foobar")
|
||||
request, channel = make_request(b"GET", b"/_matrix/foobar")
|
||||
request.render(res)
|
||||
|
||||
self.assertEqual(channel.result["code"], b'400')
|
||||
|
|
324
tests/test_visibility.py
Normal file
324
tests/test_visibility.py
Normal file
|
@ -0,0 +1,324 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import succeed
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.visibility import filter_events_for_server
|
||||
|
||||
import tests.unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEST_ROOM_ID = "!TEST:ROOM"
|
||||
|
||||
|
||||
class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver()
|
||||
self.event_creation_handler = self.hs.get_event_creation_handler()
|
||||
self.event_builder_factory = self.hs.get_event_builder_factory()
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_filtering(self):
|
||||
#
|
||||
# The events to be filtered consist of 10 membership events (it doesn't
|
||||
# really matter if they are joins or leaves, so let's make them joins).
|
||||
# One of those membership events is going to be for a user on the
|
||||
# server we are filtering for (so we can check the filtering is doing
|
||||
# the right thing).
|
||||
#
|
||||
|
||||
# before we do that, we persist some other events to act as state.
|
||||
self.inject_visibility("@admin:hs", "joined")
|
||||
for i in range(0, 10):
|
||||
yield self.inject_room_member("@resident%i:hs" % i)
|
||||
|
||||
events_to_filter = []
|
||||
|
||||
for i in range(0, 10):
|
||||
user = "@user%i:%s" % (
|
||||
i, "test_server" if i == 5 else "other_server"
|
||||
)
|
||||
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
|
||||
events_to_filter.append(evt)
|
||||
|
||||
filtered = yield filter_events_for_server(
|
||||
self.store, "test_server", events_to_filter,
|
||||
)
|
||||
|
||||
# the result should be 5 redacted events, and 5 unredacted events.
|
||||
for i in range(0, 5):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertNotIn("a", filtered[i].content)
|
||||
|
||||
for i in range(5, 10):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertEqual(filtered[i].content["a"], "b")
|
||||
|
||||
@tests.unittest.DEBUG
|
||||
@defer.inlineCallbacks
|
||||
def test_erased_user(self):
|
||||
# 4 message events, from erased and unerased users, with a membership
|
||||
# change in the middle of them.
|
||||
events_to_filter = []
|
||||
|
||||
evt = yield self.inject_message("@unerased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_message("@erased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_room_member("@joiner:remote_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_message("@unerased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
evt = yield self.inject_message("@erased:local_hs")
|
||||
events_to_filter.append(evt)
|
||||
|
||||
# the erasey user gets erased
|
||||
self.hs.get_datastore().mark_user_erased("@erased:local_hs")
|
||||
|
||||
# ... and the filtering happens.
|
||||
filtered = yield filter_events_for_server(
|
||||
self.store, "test_server", events_to_filter,
|
||||
)
|
||||
|
||||
for i in range(0, len(events_to_filter)):
|
||||
self.assertEqual(
|
||||
events_to_filter[i].event_id, filtered[i].event_id,
|
||||
"Unexpected event at result position %i" % (i, )
|
||||
)
|
||||
|
||||
for i in (0, 3):
|
||||
self.assertEqual(
|
||||
events_to_filter[i].content["body"], filtered[i].content["body"],
|
||||
"Unexpected event content at result position %i" % (i,)
|
||||
)
|
||||
|
||||
for i in (1, 4):
|
||||
self.assertNotIn("body", filtered[i].content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_visibility(self, user_id, visibility):
|
||||
content = {"history_visibility": visibility}
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": "m.room.history_visibility",
|
||||
"sender": user_id,
|
||||
"state_key": "",
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
yield self.hs.get_datastore().persist_event(event, context)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, user_id, membership="join", extra_content={}):
|
||||
content = {"membership": membership}
|
||||
content.update(extra_content)
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": "m.room.member",
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.hs.get_datastore().persist_event(event, context)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_message(self, user_id, content=None):
|
||||
if content is None:
|
||||
content = {"body": "testytest"}
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": "m.room.message",
|
||||
"sender": user_id,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.hs.get_datastore().persist_event(event, context)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_large_room(self):
|
||||
# see what happens when we have a large room with hundreds of thousands
|
||||
# of membership events
|
||||
|
||||
# As above, the events to be filtered consist of 10 membership events,
|
||||
# where one of them is for a user on the server we are filtering for.
|
||||
|
||||
import cProfile
|
||||
import pstats
|
||||
import time
|
||||
|
||||
# we stub out the store, because building up all that state the normal
|
||||
# way is very slow.
|
||||
test_store = _TestStore()
|
||||
|
||||
# our initial state is 100000 membership events and one
|
||||
# history_visibility event.
|
||||
room_state = []
|
||||
|
||||
history_visibility_evt = FrozenEvent({
|
||||
"event_id": "$history_vis",
|
||||
"type": "m.room.history_visibility",
|
||||
"sender": "@resident_user_0:test.com",
|
||||
"state_key": "",
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": {"history_visibility": "joined"},
|
||||
})
|
||||
room_state.append(history_visibility_evt)
|
||||
test_store.add_event(history_visibility_evt)
|
||||
|
||||
for i in range(0, 100000):
|
||||
user = "@resident_user_%i:test.com" % (i, )
|
||||
evt = FrozenEvent({
|
||||
"event_id": "$res_event_%i" % (i, ),
|
||||
"type": "m.room.member",
|
||||
"state_key": user,
|
||||
"sender": user,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": {
|
||||
"membership": "join",
|
||||
"extra": "zzz,"
|
||||
},
|
||||
})
|
||||
room_state.append(evt)
|
||||
test_store.add_event(evt)
|
||||
|
||||
events_to_filter = []
|
||||
for i in range(0, 10):
|
||||
user = "@user%i:%s" % (
|
||||
i, "test_server" if i == 5 else "other_server"
|
||||
)
|
||||
evt = FrozenEvent({
|
||||
"event_id": "$evt%i" % (i, ),
|
||||
"type": "m.room.member",
|
||||
"state_key": user,
|
||||
"sender": user,
|
||||
"room_id": TEST_ROOM_ID,
|
||||
"content": {
|
||||
"membership": "join",
|
||||
"extra": "zzz",
|
||||
},
|
||||
})
|
||||
events_to_filter.append(evt)
|
||||
room_state.append(evt)
|
||||
|
||||
test_store.add_event(evt)
|
||||
test_store.set_state_ids_for_event(evt, {
|
||||
(e.type, e.state_key): e.event_id for e in room_state
|
||||
})
|
||||
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
|
||||
logger.info("Starting filtering")
|
||||
start = time.time()
|
||||
filtered = yield filter_events_for_server(
|
||||
test_store, "test_server", events_to_filter,
|
||||
)
|
||||
logger.info("Filtering took %f seconds", time.time() - start)
|
||||
|
||||
pr.disable()
|
||||
with open("filter_events_for_server.profile", "w+") as f:
|
||||
ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
|
||||
ps.print_stats()
|
||||
|
||||
# the result should be 5 redacted events, and 5 unredacted events.
|
||||
for i in range(0, 5):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertNotIn("extra", filtered[i].content)
|
||||
|
||||
for i in range(5, 10):
|
||||
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
|
||||
self.assertEqual(filtered[i].content["extra"], "zzz")
|
||||
|
||||
test_large_room.skip = "Disabled by default because it's slow"
|
||||
|
||||
|
||||
class _TestStore(object):
|
||||
"""Implements a few methods of the DataStore, so that we can test
|
||||
filter_events_for_server
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
# data for get_events: a map from event_id to event
|
||||
self.events = {}
|
||||
|
||||
# data for get_state_ids_for_events mock: a map from event_id to
|
||||
# a map from (type_state_key) -> event_id for the state at that
|
||||
# event
|
||||
self.state_ids_for_events = {}
|
||||
|
||||
def add_event(self, event):
|
||||
self.events[event.event_id] = event
|
||||
|
||||
def set_state_ids_for_event(self, event, state):
|
||||
self.state_ids_for_events[event.event_id] = state
|
||||
|
||||
def get_state_ids_for_events(self, events, types):
|
||||
res = {}
|
||||
include_memberships = False
|
||||
for (type, state_key) in types:
|
||||
if type == "m.room.history_visibility":
|
||||
continue
|
||||
if type != "m.room.member" or state_key is not None:
|
||||
raise RuntimeError(
|
||||
"Unimplemented: get_state_ids with type (%s, %s)" %
|
||||
(type, state_key),
|
||||
)
|
||||
include_memberships = True
|
||||
|
||||
if include_memberships:
|
||||
for event_id in events:
|
||||
res[event_id] = self.state_ids_for_events[event_id]
|
||||
|
||||
else:
|
||||
k = ("m.room.history_visibility", "")
|
||||
for event_id in events:
|
||||
hve = self.state_ids_for_events[event_id][k]
|
||||
res[event_id] = {k: hve}
|
||||
|
||||
return succeed(res)
|
||||
|
||||
def get_events(self, events):
|
||||
return succeed({
|
||||
event_id: self.events[event_id] for event_id in events
|
||||
})
|
||||
|
||||
def are_users_erased(self, users):
|
||||
return succeed({u: False for u in users})
|
|
@ -109,6 +109,17 @@ class TestCase(unittest.TestCase):
|
|||
except AssertionError as e:
|
||||
raise (type(e))(e.message + " for '.%s'" % key)
|
||||
|
||||
def assert_dict(self, required, actual):
|
||||
"""Does a partial assert of a dict.
|
||||
|
||||
Args:
|
||||
required (dict): The keys and value which MUST be in 'actual'.
|
||||
actual (dict): The test result. Extra keys will not be checked.
|
||||
"""
|
||||
for key in required:
|
||||
self.assertEquals(required[key], actual[key],
|
||||
msg="%s mismatch. %s" % (key, actual))
|
||||
|
||||
|
||||
def DEBUG(target):
|
||||
"""A decorator to set the .loglevel attribute to logging.DEBUG.
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.async import Limiter
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class LimiterTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_limiter(self):
|
||||
limiter = Limiter(3)
|
||||
|
||||
key = object()
|
||||
|
||||
d1 = limiter.queue(key)
|
||||
cm1 = yield d1
|
||||
|
||||
d2 = limiter.queue(key)
|
||||
cm2 = yield d2
|
||||
|
||||
d3 = limiter.queue(key)
|
||||
cm3 = yield d3
|
||||
|
||||
d4 = limiter.queue(key)
|
||||
self.assertFalse(d4.called)
|
||||
|
||||
d5 = limiter.queue(key)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm1:
|
||||
self.assertFalse(d4.called)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
self.assertTrue(d4.called)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm3:
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
self.assertTrue(d5.called)
|
||||
|
||||
with cm2:
|
||||
pass
|
||||
|
||||
with (yield d4):
|
||||
pass
|
||||
|
||||
with (yield d5):
|
||||
pass
|
||||
|
||||
d6 = limiter.queue(key)
|
||||
with (yield d6):
|
||||
pass
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +17,7 @@
|
|||
from six.moves import range
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import CancelledError
|
||||
|
||||
from synapse.util import Clock, logcontext
|
||||
from synapse.util.async import Linearizer
|
||||
|
@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
|
|||
func(i)
|
||||
|
||||
return func(1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_multiple_entries(self):
|
||||
limiter = Linearizer(max_count=3)
|
||||
|
||||
key = object()
|
||||
|
||||
d1 = limiter.queue(key)
|
||||
cm1 = yield d1
|
||||
|
||||
d2 = limiter.queue(key)
|
||||
cm2 = yield d2
|
||||
|
||||
d3 = limiter.queue(key)
|
||||
cm3 = yield d3
|
||||
|
||||
d4 = limiter.queue(key)
|
||||
self.assertFalse(d4.called)
|
||||
|
||||
d5 = limiter.queue(key)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm1:
|
||||
self.assertFalse(d4.called)
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
cm4 = yield d4
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
with cm3:
|
||||
self.assertFalse(d5.called)
|
||||
|
||||
cm5 = yield d5
|
||||
|
||||
with cm2:
|
||||
pass
|
||||
|
||||
with cm4:
|
||||
pass
|
||||
|
||||
with cm5:
|
||||
pass
|
||||
|
||||
d6 = limiter.queue(key)
|
||||
with (yield d6):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cancellation(self):
|
||||
linearizer = Linearizer()
|
||||
|
||||
key = object()
|
||||
|
||||
d1 = linearizer.queue(key)
|
||||
cm1 = yield d1
|
||||
|
||||
d2 = linearizer.queue(key)
|
||||
self.assertFalse(d2.called)
|
||||
|
||||
d3 = linearizer.queue(key)
|
||||
self.assertFalse(d3.called)
|
||||
|
||||
d2.cancel()
|
||||
|
||||
with cm1:
|
||||
pass
|
||||
|
||||
self.assertTrue(d2.called)
|
||||
try:
|
||||
yield d2
|
||||
self.fail("Expected d2 to raise CancelledError")
|
||||
except CancelledError:
|
||||
pass
|
||||
|
||||
with (yield d3):
|
||||
pass
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue