mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 04:03:51 +01:00
Merge branch 'release-v0.20.0' of github.com:matrix-org/synapse
This commit is contained in:
commit
4902db1fc9
100 changed files with 3724 additions and 1355 deletions
53
CHANGES.rst
53
CHANGES.rst
|
@ -1,3 +1,56 @@
|
|||
Changes in synapse v0.20.0 (2017-04-11)
|
||||
=======================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix joining rooms over federation where not all servers in the room saw the
|
||||
new server had joined (PR #2094)
|
||||
|
||||
|
||||
Changes in synapse v0.20.0-rc1 (2017-03-30)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add delete_devices API (PR #1993)
|
||||
* Add phone number registration/login support (PR #1994, #2055)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Use JSONSchema for validation of filters. Thanks @pik! (PR #1783)
|
||||
* Reread log config on SIGHUP (PR #1982)
|
||||
* Speed up public room list (PR #1989)
|
||||
* Add helpful texts to logger config options (PR #1990)
|
||||
* Minor ``/sync`` performance improvements. (PR #2002, #2013, #2022)
|
||||
* Add some debug to help diagnose weird federation issue (PR #2035)
|
||||
* Correctly limit retries for all federation requests (PR #2050, #2061)
|
||||
* Don't lock table when persisting new one time keys (PR #2053)
|
||||
* Reduce some CPU work on DB threads (PR #2054)
|
||||
* Cache hosts in room (PR #2060)
|
||||
* Batch sending of device list pokes (PR #2063)
|
||||
* Speed up persist event path in certain edge cases (PR #2070)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix bug where current_state_events renamed to current_state_ids (PR #1849)
|
||||
* Fix routing loop when fetching remote media (PR #1992)
|
||||
* Fix current_state_events table to not lie (PR #1996)
|
||||
* Fix CAS login to handle PartialDownloadError (PR #1997)
|
||||
* Fix assertion to stop transaction queue getting wedged (PR #2010)
|
||||
* Fix presence to fallback to last_active_ts if it beats the last sync time.
|
||||
Thanks @Half-Shot! (PR #2014)
|
||||
* Fix bug when federation received a PDU while a room join is in progress (PR
|
||||
#2016)
|
||||
* Fix resetting state on rejected events (PR #2025)
|
||||
* Fix installation issues in readme. Thanks @ricco386 (PR #2037)
|
||||
* Fix caching of remote servers' signature keys (PR #2042)
|
||||
* Fix some leaking log context (PR #2048, #2049, #2057, #2058)
|
||||
* Fix rejection of invites not reaching sync (PR #2056)
|
||||
|
||||
|
||||
|
||||
Changes in synapse v0.19.3 (2017-03-20)
|
||||
=======================================
|
||||
|
||||
|
|
|
@ -146,6 +146,7 @@ To install the synapse homeserver run::
|
|||
|
||||
virtualenv -p python2.7 ~/.synapse
|
||||
source ~/.synapse/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade setuptools
|
||||
pip install https://github.com/matrix-org/synapse/tarball/master
|
||||
|
||||
|
@ -228,6 +229,7 @@ To get started, it is easiest to use the command line to register new users::
|
|||
New user localpart: erikj
|
||||
Password:
|
||||
Confirm password:
|
||||
Make admin [no]:
|
||||
Success!
|
||||
|
||||
This process uses a setting ``registration_shared_secret`` in
|
||||
|
@ -808,7 +810,7 @@ directory of your choice::
|
|||
Synapse has a number of external dependencies, that are easiest
|
||||
to install using pip and a virtualenv::
|
||||
|
||||
virtualenv env
|
||||
virtualenv -p python2.7 env
|
||||
source env/bin/activate
|
||||
python synapse/python_dependencies.py | xargs pip install
|
||||
pip install lxml mock
|
||||
|
|
|
@ -39,9 +39,11 @@ loggers:
|
|||
synapse:
|
||||
level: INFO
|
||||
|
||||
synapse.storage:
|
||||
synapse.storage.SQL:
|
||||
# beware: increasing this to DEBUG will make synapse log sensitive
|
||||
# information such as access tokens.
|
||||
level: INFO
|
||||
|
||||
|
||||
# example of enabling debugging for a component:
|
||||
#
|
||||
# synapse.federation.transport.server:
|
||||
|
|
|
@ -1,10 +1,446 @@
|
|||
What do I do about "Unexpected logging context" debug log-lines everywhere?
|
||||
Log contexts
|
||||
============
|
||||
|
||||
<Mjark> The logging context lives in thread local storage
|
||||
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
|
||||
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
|
||||
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
|
||||
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
|
||||
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
|
||||
.. contents::
|
||||
|
||||
Unanswered: how and when should we preserve log context?
|
||||
To help track the processing of individual requests, synapse uses a
|
||||
'log context' to track which request it is handling at any given moment. This
|
||||
is done via a thread-local variable; a ``logging.Filter`` is then used to fish
|
||||
the information back out of the thread-local variable and add it to each log
|
||||
record.
|
||||
|
||||
Logcontexts are also used for CPU and database accounting, so that we can track
|
||||
which requests were responsible for high CPU use or database activity.
|
||||
|
||||
The ``synapse.util.logcontext`` module provides a facilities for managing the
|
||||
current log context (as well as providing the ``LoggingContextFilter`` class).
|
||||
|
||||
Deferreds make the whole thing complicated, so this document describes how it
|
||||
all works, and how to write code which follows the rules.
|
||||
|
||||
Logcontexts without Deferreds
|
||||
-----------------------------
|
||||
|
||||
In the absence of any Deferred voodoo, things are simple enough. As with any
|
||||
code of this nature, the rule is that our function should leave things as it
|
||||
found them:
|
||||
|
||||
.. code:: python
|
||||
|
||||
from synapse.util import logcontext # omitted from future snippets
|
||||
|
||||
def handle_request(request_id):
|
||||
request_context = logcontext.LoggingContext()
|
||||
|
||||
calling_context = logcontext.LoggingContext.current_context()
|
||||
logcontext.LoggingContext.set_current_context(request_context)
|
||||
try:
|
||||
request_context.request = request_id
|
||||
do_request_handling()
|
||||
logger.debug("finished")
|
||||
finally:
|
||||
logcontext.LoggingContext.set_current_context(calling_context)
|
||||
|
||||
def do_request_handling():
|
||||
logger.debug("phew") # this will be logged against request_id
|
||||
|
||||
|
||||
LoggingContext implements the context management methods, so the above can be
|
||||
written much more succinctly as:
|
||||
|
||||
.. code:: python
|
||||
|
||||
def handle_request(request_id):
|
||||
with logcontext.LoggingContext() as request_context:
|
||||
request_context.request = request_id
|
||||
do_request_handling()
|
||||
logger.debug("finished")
|
||||
|
||||
def do_request_handling():
|
||||
logger.debug("phew")
|
||||
|
||||
|
||||
Using logcontexts with Deferreds
|
||||
--------------------------------
|
||||
|
||||
Deferreds — and in particular, ``defer.inlineCallbacks`` — break
|
||||
the linear flow of code so that there is no longer a single entry point where
|
||||
we should set the logcontext and a single exit point where we should remove it.
|
||||
|
||||
Consider the example above, where ``do_request_handling`` needs to do some
|
||||
blocking operation, and returns a deferred:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_request(request_id):
|
||||
with logcontext.LoggingContext() as request_context:
|
||||
request_context.request = request_id
|
||||
yield do_request_handling()
|
||||
logger.debug("finished")
|
||||
|
||||
|
||||
In the above flow:
|
||||
|
||||
* The logcontext is set
|
||||
* ``do_request_handling`` is called, and returns a deferred
|
||||
* ``handle_request`` yields the deferred
|
||||
* The ``inlineCallbacks`` wrapper of ``handle_request`` returns a deferred
|
||||
|
||||
So we have stopped processing the request (and will probably go on to start
|
||||
processing the next), without clearing the logcontext.
|
||||
|
||||
To circumvent this problem, synapse code assumes that, wherever you have a
|
||||
deferred, you will want to yield on it. To that end, whereever functions return
|
||||
a deferred, we adopt the following conventions:
|
||||
|
||||
**Rules for functions returning deferreds:**
|
||||
|
||||
* If the deferred is already complete, the function returns with the same
|
||||
logcontext it started with.
|
||||
* If the deferred is incomplete, the function clears the logcontext before
|
||||
returning; when the deferred completes, it restores the logcontext before
|
||||
running any callbacks.
|
||||
|
||||
That sounds complicated, but actually it means a lot of code (including the
|
||||
example above) "just works". There are two cases:
|
||||
|
||||
* If ``do_request_handling`` returns a completed deferred, then the logcontext
|
||||
will still be in place. In this case, execution will continue immediately
|
||||
after the ``yield``; the "finished" line will be logged against the right
|
||||
context, and the ``with`` block restores the original context before we
|
||||
return to the caller.
|
||||
|
||||
* If the returned deferred is incomplete, ``do_request_handling`` clears the
|
||||
logcontext before returning. The logcontext is therefore clear when
|
||||
``handle_request`` yields the deferred. At that point, the ``inlineCallbacks``
|
||||
wrapper adds a callback to the deferred, and returns another (incomplete)
|
||||
deferred to the caller, and it is safe to begin processing the next request.
|
||||
|
||||
Once ``do_request_handling``'s deferred completes, it will reinstate the
|
||||
logcontext, before running the callback added by the ``inlineCallbacks``
|
||||
wrapper. That callback runs the second half of ``handle_request``, so again
|
||||
the "finished" line will be logged against the right
|
||||
context, and the ``with`` block restores the original context.
|
||||
|
||||
As an aside, it's worth noting that ``handle_request`` follows our rules -
|
||||
though that only matters if the caller has its own logcontext which it cares
|
||||
about.
|
||||
|
||||
The following sections describe pitfalls and helpful patterns when implementing
|
||||
these rules.
|
||||
|
||||
Always yield your deferreds
|
||||
---------------------------
|
||||
|
||||
Whenever you get a deferred back from a function, you should ``yield`` on it
|
||||
as soon as possible. (Returning it directly to your caller is ok too, if you're
|
||||
not doing ``inlineCallbacks``.) Do not pass go; do not do any logging; do not
|
||||
call any other functions.
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fun():
|
||||
logger.debug("starting")
|
||||
yield do_some_stuff() # just like this
|
||||
|
||||
d = more_stuff()
|
||||
result = yield d # also fine, of course
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
def nonInlineCallbacksFun():
|
||||
logger.debug("just a wrapper really")
|
||||
return do_some_stuff() # this is ok too - the caller will yield on
|
||||
# it anyway.
|
||||
|
||||
Provided this pattern is followed all the way back up to the callchain to where
|
||||
the logcontext was set, this will make things work out ok: provided
|
||||
``do_some_stuff`` and ``more_stuff`` follow the rules above, then so will
|
||||
``fun`` (as wrapped by ``inlineCallbacks``) and ``nonInlineCallbacksFun``.
|
||||
|
||||
It's all too easy to forget to ``yield``: for instance if we forgot that
|
||||
``do_some_stuff`` returned a deferred, we might plough on regardless. This
|
||||
leads to a mess; it will probably work itself out eventually, but not before
|
||||
a load of stuff has been logged against the wrong content. (Normally, other
|
||||
things will break, more obviously, if you forget to ``yield``, so this tends
|
||||
not to be a major problem in practice.)
|
||||
|
||||
Of course sometimes you need to do something a bit fancier with your Deferreds
|
||||
- not all code follows the linear A-then-B-then-C pattern. Notes on
|
||||
implementing more complex patterns are in later sections.
|
||||
|
||||
Where you create a new Deferred, make it follow the rules
|
||||
---------------------------------------------------------
|
||||
|
||||
Most of the time, a Deferred comes from another synapse function. Sometimes,
|
||||
though, we need to make up a new Deferred, or we get a Deferred back from
|
||||
external code. We need to make it follow our rules.
|
||||
|
||||
The easy way to do it is with a combination of ``defer.inlineCallbacks``, and
|
||||
``logcontext.PreserveLoggingContext``. Suppose we want to implement ``sleep``,
|
||||
which returns a deferred which will run its callbacks after a given number of
|
||||
seconds. That might look like:
|
||||
|
||||
.. code:: python
|
||||
|
||||
# not a logcontext-rules-compliant function
|
||||
def get_sleep_deferred(seconds):
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(seconds, d.callback, None)
|
||||
return d
|
||||
|
||||
That doesn't follow the rules, but we can fix it by wrapping it with
|
||||
``PreserveLoggingContext`` and ``yield`` ing on it:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sleep(seconds):
|
||||
with PreserveLoggingContext():
|
||||
yield get_sleep_deferred(seconds)
|
||||
|
||||
This technique works equally for external functions which return deferreds,
|
||||
or deferreds we have made ourselves.
|
||||
|
||||
You can also use ``logcontext.make_deferred_yieldable``, which just does the
|
||||
boilerplate for you, so the above could be written:
|
||||
|
||||
.. code:: python
|
||||
|
||||
def sleep(seconds):
|
||||
return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds))
|
||||
|
||||
|
||||
Fire-and-forget
|
||||
---------------
|
||||
|
||||
Sometimes you want to fire off a chain of execution, but not wait for its
|
||||
result. That might look a bit like this:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_request_handling():
|
||||
yield foreground_operation()
|
||||
|
||||
# *don't* do this
|
||||
background_operation()
|
||||
|
||||
logger.debug("Request handling complete")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def background_operation():
|
||||
yield first_background_step()
|
||||
logger.debug("Completed first step")
|
||||
yield second_background_step()
|
||||
logger.debug("Completed second step")
|
||||
|
||||
The above code does a couple of steps in the background after
|
||||
``do_request_handling`` has finished. The log lines are still logged against
|
||||
the ``request_context`` logcontext, which may or may not be desirable. There
|
||||
are two big problems with the above, however. The first problem is that, if
|
||||
``background_operation`` returns an incomplete Deferred, it will expect its
|
||||
caller to ``yield`` immediately, so will have cleared the logcontext. In this
|
||||
example, that means that 'Request handling complete' will be logged without any
|
||||
context.
|
||||
|
||||
The second problem, which is potentially even worse, is that when the Deferred
|
||||
returned by ``background_operation`` completes, it will restore the original
|
||||
logcontext. There is nothing waiting on that Deferred, so the logcontext will
|
||||
leak into the reactor and possibly get attached to some arbitrary future
|
||||
operation.
|
||||
|
||||
There are two potential solutions to this.
|
||||
|
||||
One option is to surround the call to ``background_operation`` with a
|
||||
``PreserveLoggingContext`` call. That will reset the logcontext before
|
||||
starting ``background_operation`` (so the context restored when the deferred
|
||||
completes will be the empty logcontext), and will restore the current
|
||||
logcontext before continuing the foreground process:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_request_handling():
|
||||
yield foreground_operation()
|
||||
|
||||
# start background_operation off in the empty logcontext, to
|
||||
# avoid leaking the current context into the reactor.
|
||||
with PreserveLoggingContext():
|
||||
background_operation()
|
||||
|
||||
# this will now be logged against the request context
|
||||
logger.debug("Request handling complete")
|
||||
|
||||
Obviously that option means that the operations done in
|
||||
``background_operation`` would be not be logged against a logcontext (though
|
||||
that might be fixed by setting a different logcontext via a ``with
|
||||
LoggingContext(...)`` in ``background_operation``).
|
||||
|
||||
The second option is to use ``logcontext.preserve_fn``, which wraps a function
|
||||
so that it doesn't reset the logcontext even when it returns an incomplete
|
||||
deferred, and adds a callback to the returned deferred to reset the
|
||||
logcontext. In other words, it turns a function that follows the Synapse rules
|
||||
about logcontexts and Deferreds into one which behaves more like an external
|
||||
function — the opposite operation to that described in the previous section.
|
||||
It can be used like this:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_request_handling():
|
||||
yield foreground_operation()
|
||||
|
||||
logcontext.preserve_fn(background_operation)()
|
||||
|
||||
# this will now be logged against the request context
|
||||
logger.debug("Request handling complete")
|
||||
|
||||
XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
|
||||
but the fact that it does ``preserve_context_over_deferred`` on its results
|
||||
means that its use is fraught with difficulty.
|
||||
|
||||
Passing synapse deferreds into third-party functions
|
||||
----------------------------------------------------
|
||||
|
||||
A typical example of this is where we want to collect together two or more
|
||||
deferred via ``defer.gatherResults``:
|
||||
|
||||
.. code:: python
|
||||
|
||||
d1 = operation1()
|
||||
d2 = operation2()
|
||||
d3 = defer.gatherResults([d1, d2])
|
||||
|
||||
This is really a variation of the fire-and-forget problem above, in that we are
|
||||
firing off ``d1`` and ``d2`` without yielding on them. The difference
|
||||
is that we now have third-party code attached to their callbacks. Anyway either
|
||||
technique given in the `Fire-and-forget`_ section will work.
|
||||
|
||||
Of course, the new Deferred returned by ``gatherResults`` needs to be wrapped
|
||||
in order to make it follow the logcontext rules before we can yield it, as
|
||||
described in `Where you create a new Deferred, make it follow the rules`_.
|
||||
|
||||
So, option one: reset the logcontext before starting the operations to be
|
||||
gathered:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_request_handling():
|
||||
with PreserveLoggingContext():
|
||||
d1 = operation1()
|
||||
d2 = operation2()
|
||||
result = yield defer.gatherResults([d1, d2])
|
||||
|
||||
In this case particularly, though, option two, of using
|
||||
``logcontext.preserve_fn`` almost certainly makes more sense, so that
|
||||
``operation1`` and ``operation2`` are both logged against the original
|
||||
logcontext. This looks like:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_request_handling():
|
||||
d1 = logcontext.preserve_fn(operation1)()
|
||||
d2 = logcontext.preserve_fn(operation2)()
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield defer.gatherResults([d1, d2])
|
||||
|
||||
|
||||
Was all this really necessary?
|
||||
------------------------------
|
||||
|
||||
The conventions used work fine for a linear flow where everything happens in
|
||||
series via ``defer.inlineCallbacks`` and ``yield``, but are certainly tricky to
|
||||
follow for any more exotic flows. It's hard not to wonder if we could have done
|
||||
something else.
|
||||
|
||||
We're not going to rewrite Synapse now, so the following is entirely of
|
||||
academic interest, but I'd like to record some thoughts on an alternative
|
||||
approach.
|
||||
|
||||
I briefly prototyped some code following an alternative set of rules. I think
|
||||
it would work, but I certainly didn't get as far as thinking how it would
|
||||
interact with concepts as complicated as the cache descriptors.
|
||||
|
||||
My alternative rules were:
|
||||
|
||||
* functions always preserve the logcontext of their caller, whether or not they
|
||||
are returning a Deferred.
|
||||
|
||||
* Deferreds returned by synapse functions run their callbacks in the same
|
||||
context as the function was orignally called in.
|
||||
|
||||
The main point of this scheme is that everywhere that sets the logcontext is
|
||||
responsible for clearing it before returning control to the reactor.
|
||||
|
||||
So, for example, if you were the function which started a ``with
|
||||
LoggingContext`` block, you wouldn't ``yield`` within it — instead you'd start
|
||||
off the background process, and then leave the ``with`` block to wait for it:
|
||||
|
||||
.. code:: python
|
||||
|
||||
def handle_request(request_id):
|
||||
with logcontext.LoggingContext() as request_context:
|
||||
request_context.request = request_id
|
||||
d = do_request_handling()
|
||||
|
||||
def cb(r):
|
||||
logger.debug("finished")
|
||||
|
||||
d.addCallback(cb)
|
||||
return d
|
||||
|
||||
(in general, mixing ``with LoggingContext`` blocks and
|
||||
``defer.inlineCallbacks`` in the same function leads to slighly
|
||||
counter-intuitive code, under this scheme).
|
||||
|
||||
Because we leave the original ``with`` block as soon as the Deferred is
|
||||
returned (as opposed to waiting for it to be resolved, as we do today), the
|
||||
logcontext is cleared before control passes back to the reactor; so if there is
|
||||
some code within ``do_request_handling`` which needs to wait for a Deferred to
|
||||
complete, there is no need for it to worry about clearing the logcontext before
|
||||
doing so:
|
||||
|
||||
.. code:: python
|
||||
|
||||
def handle_request():
|
||||
r = do_some_stuff()
|
||||
r.addCallback(do_some_more_stuff)
|
||||
return r
|
||||
|
||||
— and provided ``do_some_stuff`` follows the rules of returning a Deferred which
|
||||
runs its callbacks in the original logcontext, all is happy.
|
||||
|
||||
The business of a Deferred which runs its callbacks in the original logcontext
|
||||
isn't hard to achieve — we have it today, in the shape of
|
||||
``logcontext._PreservingContextDeferred``:
|
||||
|
||||
.. code:: python
|
||||
|
||||
def do_some_stuff():
|
||||
deferred = do_some_io()
|
||||
pcd = _PreservingContextDeferred(LoggingContext.current_context())
|
||||
deferred.chainDeferred(pcd)
|
||||
return pcd
|
||||
|
||||
It turns out that, thanks to the way that Deferreds chain together, we
|
||||
automatically get the property of a context-preserving deferred with
|
||||
``defer.inlineCallbacks``, provided the final Defered the function ``yields``
|
||||
on has that property. So we can just write:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_request():
|
||||
yield do_some_stuff()
|
||||
yield do_some_more_stuff()
|
||||
|
||||
To conclude: I think this scheme would have worked equally well, with less
|
||||
danger of messing it up, and probably made some more esoteric code easier to
|
||||
write. But again — changing the conventions of the entire Synapse codebase is
|
||||
not a sensible option for the marginal improvement offered.
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.19.3"
|
||||
__version__ = "0.20.0"
|
||||
|
|
|
@ -23,7 +23,7 @@ from synapse import event_auth
|
|||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes
|
||||
from synapse.types import UserID
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -209,8 +209,7 @@ class Auth(object):
|
|||
default=[""]
|
||||
)[0]
|
||||
if user and access_token and ip_addr:
|
||||
preserve_context_over_fn(
|
||||
self.store.insert_client_ip,
|
||||
logcontext.preserve_fn(self.store.insert_client_ip)(
|
||||
user=user,
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -44,6 +45,7 @@ class JoinRules(object):
|
|||
class LoginType(object):
|
||||
PASSWORD = u"m.login.password"
|
||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||
MSISDN = u"m.login.msisdn"
|
||||
RECAPTCHA = u"m.login.recaptcha"
|
||||
DUMMY = u"m.login.dummy"
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Contains exceptions and error codes."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -50,27 +51,35 @@ class Codes(object):
|
|||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
"""An exception with integer code and message string attributes."""
|
||||
"""An exception with integer code and message string attributes.
|
||||
|
||||
Attributes:
|
||||
code (int): HTTP error code
|
||||
msg (str): string describing the error
|
||||
"""
|
||||
def __init__(self, code, msg):
|
||||
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
|
||||
self.code = code
|
||||
self.msg = msg
|
||||
self.response_code_message = None
|
||||
|
||||
def error_dict(self):
|
||||
return cs_error(self.msg)
|
||||
|
||||
|
||||
class SynapseError(CodeMessageException):
|
||||
"""A base error which can be caught for all synapse events."""
|
||||
"""A base exception type for matrix errors which have an errcode and error
|
||||
message (as well as an HTTP status code).
|
||||
|
||||
Attributes:
|
||||
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
|
||||
"""
|
||||
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
|
||||
"""Constructs a synapse error.
|
||||
|
||||
Args:
|
||||
code (int): The integer error code (an HTTP response code)
|
||||
msg (str): The human-readable error message.
|
||||
err (str): The error code e.g 'M_FORBIDDEN'
|
||||
errcode (str): The matrix error code e.g 'M_FORBIDDEN'
|
||||
"""
|
||||
super(SynapseError, self).__init__(code, msg)
|
||||
self.errcode = errcode
|
||||
|
@ -81,6 +90,39 @@ class SynapseError(CodeMessageException):
|
|||
self.errcode,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_http_response_exception(cls, err):
|
||||
"""Make a SynapseError based on an HTTPResponseException
|
||||
|
||||
This is useful when a proxied request has failed, and we need to
|
||||
decide how to map the failure onto a matrix error to send back to the
|
||||
client.
|
||||
|
||||
An attempt is made to parse the body of the http response as a matrix
|
||||
error. If that succeeds, the errcode and error message from the body
|
||||
are used as the errcode and error message in the new synapse error.
|
||||
|
||||
Otherwise, the errcode is set to M_UNKNOWN, and the error message is
|
||||
set to the reason code from the HTTP response.
|
||||
|
||||
Args:
|
||||
err (HttpResponseException):
|
||||
|
||||
Returns:
|
||||
SynapseError:
|
||||
"""
|
||||
# try to parse the body as json, to get better errcode/msg, but
|
||||
# default to M_UNKNOWN with the HTTP status as the error text
|
||||
try:
|
||||
j = json.loads(err.response)
|
||||
except ValueError:
|
||||
j = {}
|
||||
errcode = j.get('errcode', Codes.UNKNOWN)
|
||||
errmsg = j.get('error', err.msg)
|
||||
|
||||
res = SynapseError(err.code, errmsg, errcode)
|
||||
return res
|
||||
|
||||
|
||||
class RegistrationError(SynapseError):
|
||||
"""An error raised when a registration event fails."""
|
||||
|
@ -106,13 +148,11 @@ class UnrecognizedRequestError(SynapseError):
|
|||
|
||||
class NotFoundError(SynapseError):
|
||||
"""An error indicating we can't find the thing you asked for"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "errcode" not in kwargs:
|
||||
kwargs["errcode"] = Codes.NOT_FOUND
|
||||
def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
|
||||
super(NotFoundError, self).__init__(
|
||||
404,
|
||||
"Not found",
|
||||
**kwargs
|
||||
msg,
|
||||
errcode=errcode
|
||||
)
|
||||
|
||||
|
||||
|
@ -173,7 +213,6 @@ class LimitExceededError(SynapseError):
|
|||
errcode=Codes.LIMIT_EXCEEDED):
|
||||
super(LimitExceededError, self).__init__(code, msg, errcode)
|
||||
self.retry_after_ms = retry_after_ms
|
||||
self.response_code_message = "Too Many Requests"
|
||||
|
||||
def error_dict(self):
|
||||
return cs_error(
|
||||
|
@ -243,6 +282,19 @@ class FederationError(RuntimeError):
|
|||
|
||||
|
||||
class HttpResponseException(CodeMessageException):
|
||||
"""
|
||||
Represents an HTTP-level failure of an outbound request
|
||||
|
||||
Attributes:
|
||||
response (str): body of response
|
||||
"""
|
||||
def __init__(self, code, msg, response):
|
||||
self.response = response
|
||||
"""
|
||||
|
||||
Args:
|
||||
code (int): HTTP status code
|
||||
msg (str): reason phrase from HTTP response status line
|
||||
response (str): body of response
|
||||
"""
|
||||
super(HttpResponseException, self).__init__(code, msg)
|
||||
self.response = response
|
||||
|
|
|
@ -13,11 +13,174 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import UserID, RoomID
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import ujson as json
|
||||
import jsonschema
|
||||
from jsonschema import FormatChecker
|
||||
|
||||
FILTER_SCHEMA = {
|
||||
"additionalProperties": False,
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"limit": {
|
||||
"type": "number"
|
||||
},
|
||||
"senders": {
|
||||
"$ref": "#/definitions/user_id_array"
|
||||
},
|
||||
"not_senders": {
|
||||
"$ref": "#/definitions/user_id_array"
|
||||
},
|
||||
# TODO: We don't limit event type values but we probably should...
|
||||
# check types are valid event types
|
||||
"types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"not_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ROOM_FILTER_SCHEMA = {
|
||||
"additionalProperties": False,
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"not_rooms": {
|
||||
"$ref": "#/definitions/room_id_array"
|
||||
},
|
||||
"rooms": {
|
||||
"$ref": "#/definitions/room_id_array"
|
||||
},
|
||||
"ephemeral": {
|
||||
"$ref": "#/definitions/room_event_filter"
|
||||
},
|
||||
"include_leave": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"state": {
|
||||
"$ref": "#/definitions/room_event_filter"
|
||||
},
|
||||
"timeline": {
|
||||
"$ref": "#/definitions/room_event_filter"
|
||||
},
|
||||
"account_data": {
|
||||
"$ref": "#/definitions/room_event_filter"
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
ROOM_EVENT_FILTER_SCHEMA = {
|
||||
"additionalProperties": False,
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"limit": {
|
||||
"type": "number"
|
||||
},
|
||||
"senders": {
|
||||
"$ref": "#/definitions/user_id_array"
|
||||
},
|
||||
"not_senders": {
|
||||
"$ref": "#/definitions/user_id_array"
|
||||
},
|
||||
"types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"not_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"rooms": {
|
||||
"$ref": "#/definitions/room_id_array"
|
||||
},
|
||||
"not_rooms": {
|
||||
"$ref": "#/definitions/room_id_array"
|
||||
},
|
||||
"contains_url": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
USER_ID_ARRAY_SCHEMA = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"format": "matrix_user_id"
|
||||
}
|
||||
}
|
||||
|
||||
ROOM_ID_ARRAY_SCHEMA = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"format": "matrix_room_id"
|
||||
}
|
||||
}
|
||||
|
||||
USER_FILTER_SCHEMA = {
|
||||
"$schema": "http://json-schema.org/draft-04/schema#",
|
||||
"description": "schema for a Sync filter",
|
||||
"type": "object",
|
||||
"definitions": {
|
||||
"room_id_array": ROOM_ID_ARRAY_SCHEMA,
|
||||
"user_id_array": USER_ID_ARRAY_SCHEMA,
|
||||
"filter": FILTER_SCHEMA,
|
||||
"room_filter": ROOM_FILTER_SCHEMA,
|
||||
"room_event_filter": ROOM_EVENT_FILTER_SCHEMA
|
||||
},
|
||||
"properties": {
|
||||
"presence": {
|
||||
"$ref": "#/definitions/filter"
|
||||
},
|
||||
"account_data": {
|
||||
"$ref": "#/definitions/filter"
|
||||
},
|
||||
"room": {
|
||||
"$ref": "#/definitions/room_filter"
|
||||
},
|
||||
"event_format": {
|
||||
"type": "string",
|
||||
"enum": ["client", "federation"]
|
||||
},
|
||||
"event_fields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
# Don't allow '\\' in event field filters. This makes matching
|
||||
# events a lot easier as we can then use a negative lookbehind
|
||||
# assertion to split '\.' If we allowed \\ then it would
|
||||
# incorrectly split '\\.' See synapse.events.utils.serialize_event
|
||||
"pattern": "^((?!\\\).)*$"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": False
|
||||
}
|
||||
|
||||
|
||||
@FormatChecker.cls_checks('matrix_room_id')
|
||||
def matrix_room_id_validator(room_id_str):
|
||||
return RoomID.from_string(room_id_str)
|
||||
|
||||
|
||||
@FormatChecker.cls_checks('matrix_user_id')
|
||||
def matrix_user_id_validator(user_id_str):
|
||||
return UserID.from_string(user_id_str)
|
||||
|
||||
|
||||
class Filtering(object):
|
||||
|
@ -52,98 +215,11 @@ class Filtering(object):
|
|||
# NB: Filters are the complete json blobs. "Definitions" are an
|
||||
# individual top-level key e.g. public_user_data. Filters are made of
|
||||
# many definitions.
|
||||
|
||||
top_level_definitions = [
|
||||
"presence", "account_data"
|
||||
]
|
||||
|
||||
room_level_definitions = [
|
||||
"state", "timeline", "ephemeral", "account_data"
|
||||
]
|
||||
|
||||
for key in top_level_definitions:
|
||||
if key in user_filter_json:
|
||||
self._check_definition(user_filter_json[key])
|
||||
|
||||
if "room" in user_filter_json:
|
||||
self._check_definition_room_lists(user_filter_json["room"])
|
||||
for key in room_level_definitions:
|
||||
if key in user_filter_json["room"]:
|
||||
self._check_definition(user_filter_json["room"][key])
|
||||
|
||||
if "event_fields" in user_filter_json:
|
||||
if type(user_filter_json["event_fields"]) != list:
|
||||
raise SynapseError(400, "event_fields must be a list of strings")
|
||||
for field in user_filter_json["event_fields"]:
|
||||
if not isinstance(field, basestring):
|
||||
raise SynapseError(400, "Event field must be a string")
|
||||
# Don't allow '\\' in event field filters. This makes matching
|
||||
# events a lot easier as we can then use a negative lookbehind
|
||||
# assertion to split '\.' If we allowed \\ then it would
|
||||
# incorrectly split '\\.' See synapse.events.utils.serialize_event
|
||||
if r'\\' in field:
|
||||
raise SynapseError(
|
||||
400, r'The escape character \ cannot itself be escaped'
|
||||
)
|
||||
|
||||
def _check_definition_room_lists(self, definition):
|
||||
"""Check that "rooms" and "not_rooms" are lists of room ids if they
|
||||
are present
|
||||
|
||||
Args:
|
||||
definition(dict): The filter definition
|
||||
Raises:
|
||||
SynapseError: If there was a problem with this definition.
|
||||
"""
|
||||
# check rooms are valid room IDs
|
||||
room_id_keys = ["rooms", "not_rooms"]
|
||||
for key in room_id_keys:
|
||||
if key in definition:
|
||||
if type(definition[key]) != list:
|
||||
raise SynapseError(400, "Expected %s to be a list." % key)
|
||||
for room_id in definition[key]:
|
||||
RoomID.from_string(room_id)
|
||||
|
||||
def _check_definition(self, definition):
|
||||
"""Check if the provided definition is valid.
|
||||
|
||||
This inspects not only the types but also the values to make sure they
|
||||
make sense.
|
||||
|
||||
Args:
|
||||
definition(dict): The filter definition
|
||||
Raises:
|
||||
SynapseError: If there was a problem with this definition.
|
||||
"""
|
||||
# NB: Filters are the complete json blobs. "Definitions" are an
|
||||
# individual top-level key e.g. public_user_data. Filters are made of
|
||||
# many definitions.
|
||||
if type(definition) != dict:
|
||||
raise SynapseError(
|
||||
400, "Expected JSON object, not %s" % (definition,)
|
||||
)
|
||||
|
||||
self._check_definition_room_lists(definition)
|
||||
|
||||
# check senders are valid user IDs
|
||||
user_id_keys = ["senders", "not_senders"]
|
||||
for key in user_id_keys:
|
||||
if key in definition:
|
||||
if type(definition[key]) != list:
|
||||
raise SynapseError(400, "Expected %s to be a list." % key)
|
||||
for user_id in definition[key]:
|
||||
UserID.from_string(user_id)
|
||||
|
||||
# TODO: We don't limit event type values but we probably should...
|
||||
# check types are valid event types
|
||||
event_keys = ["types", "not_types"]
|
||||
for key in event_keys:
|
||||
if key in definition:
|
||||
if type(definition[key]) != list:
|
||||
raise SynapseError(400, "Expected %s to be a list." % key)
|
||||
for event_type in definition[key]:
|
||||
if not isinstance(event_type, basestring):
|
||||
raise SynapseError(400, "Event type should be a string")
|
||||
try:
|
||||
jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
|
||||
format_checker=FormatChecker())
|
||||
except jsonschema.ValidationError as e:
|
||||
raise SynapseError(400, e.message)
|
||||
|
||||
|
||||
class FilterCollection(object):
|
||||
|
@ -253,19 +329,35 @@ class Filter(object):
|
|||
Returns:
|
||||
bool: True if the event matches
|
||||
"""
|
||||
sender = event.get("sender", None)
|
||||
if not sender:
|
||||
# Presence events have their 'sender' in content.user_id
|
||||
content = event.get("content")
|
||||
# account_data has been allowed to have non-dict content, so check type first
|
||||
if isinstance(content, dict):
|
||||
sender = content.get("user_id")
|
||||
# We usually get the full "events" as dictionaries coming through,
|
||||
# except for presence which actually gets passed around as its own
|
||||
# namedtuple type.
|
||||
if isinstance(event, UserPresenceState):
|
||||
sender = event.user_id
|
||||
room_id = None
|
||||
ev_type = "m.presence"
|
||||
is_url = False
|
||||
else:
|
||||
sender = event.get("sender", None)
|
||||
if not sender:
|
||||
# Presence events had their 'sender' in content.user_id, but are
|
||||
# now handled above. We don't know if anything else uses this
|
||||
# form. TODO: Check this and probably remove it.
|
||||
content = event.get("content")
|
||||
# account_data has been allowed to have non-dict content, so
|
||||
# check type first
|
||||
if isinstance(content, dict):
|
||||
sender = content.get("user_id")
|
||||
|
||||
room_id = event.get("room_id", None)
|
||||
ev_type = event.get("type", None)
|
||||
is_url = "url" in event.get("content", {})
|
||||
|
||||
return self.check_fields(
|
||||
event.get("room_id", None),
|
||||
room_id,
|
||||
sender,
|
||||
event.get("type", None),
|
||||
"url" in event.get("content", {})
|
||||
ev_type,
|
||||
is_url,
|
||||
)
|
||||
|
||||
def check_fields(self, room_id, sender, event_type, contains_url):
|
||||
|
|
|
@ -29,7 +29,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
|
|||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
@ -157,7 +157,7 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.appservice"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -187,7 +187,11 @@ def start(config_options):
|
|||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
|
|
|
@ -29,13 +29,14 @@ from synapse.replication.slave.storage.keys import SlavedKeyStore
|
|||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.rest.client.v1.room import PublicRoomListRestServlet
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.client_ips import ClientIpStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
@ -63,6 +64,7 @@ class ClientReaderSlavedStore(
|
|||
DirectoryStore,
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
TransactionStore,
|
||||
BaseSlavedStore,
|
||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
||||
):
|
||||
|
@ -171,7 +173,7 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.client_reader"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -193,7 +195,11 @@ def start(config_options):
|
|||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
|
|
|
@ -31,7 +31,7 @@ from synapse.server import HomeServer
|
|||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
@ -162,7 +162,7 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.federation_reader"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -184,7 +184,11 @@ def start(config_options):
|
|||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
|
|
|
@ -35,7 +35,7 @@ from synapse.storage.engines import create_engine
|
|||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
@ -160,7 +160,7 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.federation_sender"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -193,7 +193,11 @@ def start(config_options):
|
|||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
|
|
|
@ -20,6 +20,8 @@ import gc
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import synapse.config.logger
|
||||
from synapse.config._base import ConfigError
|
||||
|
||||
from synapse.python_dependencies import (
|
||||
|
@ -50,7 +52,7 @@ from synapse.api.urls import (
|
|||
)
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.metrics import register_memory_metrics, get_metrics_for
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||
|
@ -286,7 +288,7 @@ def setup(config_options):
|
|||
# generating config files and shouldn't try to continue.
|
||||
sys.exit(0)
|
||||
|
||||
config.setup_logging()
|
||||
synapse.config.logger.setup_logging(config, use_worker_options=False)
|
||||
|
||||
# check any extra requirements we have now we have a config
|
||||
check_requirements(config)
|
||||
|
@ -454,7 +456,12 @@ def run(hs):
|
|||
def in_thread():
|
||||
# Uncomment to enable tracing of log context changes.
|
||||
# sys.settrace(logcontext_tracer)
|
||||
with LoggingContext("run"):
|
||||
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
change_resource_limit(hs.config.soft_file_limit)
|
||||
if hs.config.gc_thresholds:
|
||||
gc.set_threshold(*hs.config.gc_thresholds)
|
||||
|
|
|
@ -24,6 +24,7 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
|||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.server import HomeServer
|
||||
|
@ -32,7 +33,7 @@ from synapse.storage.engines import create_engine
|
|||
from synapse.storage.media_repository import MediaRepositoryStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
@ -59,6 +60,7 @@ logger = logging.getLogger("synapse.app.media_repository")
|
|||
class MediaRepositorySlavedStore(
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
TransactionStore,
|
||||
BaseSlavedStore,
|
||||
MediaRepositoryStore,
|
||||
ClientIpStore,
|
||||
|
@ -168,7 +170,7 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.media_repository"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -190,7 +192,11 @@ def start(config_options):
|
|||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
|
|
|
@ -31,7 +31,8 @@ from synapse.storage.engines import create_engine
|
|||
from synapse.storage import DataStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||
from synapse.util.logcontext import LoggingContext, preserve_fn, \
|
||||
PreserveLoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
@ -245,7 +246,7 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.pusher"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -275,7 +276,11 @@ def start(config_options):
|
|||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
|
|
|
@ -20,7 +20,6 @@ from synapse.api.constants import EventTypes, PresenceState
|
|||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.http.server import JsonResource
|
||||
|
@ -48,7 +47,8 @@ from synapse.storage.presence import PresenceStore, UserPresenceState
|
|||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||
from synapse.util.logcontext import LoggingContext, preserve_fn, \
|
||||
PreserveLoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.stringutils import random_string
|
||||
|
@ -399,8 +399,7 @@ class SynchrotronServer(HomeServer):
|
|||
position = row[position_index]
|
||||
user_id = row[user_index]
|
||||
|
||||
rooms = yield store.get_rooms_for_user(user_id)
|
||||
room_ids = [r.room_id for r in rooms]
|
||||
room_ids = yield store.get_rooms_for_user(user_id)
|
||||
|
||||
notifier.on_new_event(
|
||||
"device_list_key", position, rooms=room_ids,
|
||||
|
@ -411,11 +410,16 @@ class SynchrotronServer(HomeServer):
|
|||
stream = result.get("events")
|
||||
if stream:
|
||||
max_position = stream["position"]
|
||||
|
||||
event_map = yield store.get_events([row[1] for row in stream["rows"]])
|
||||
|
||||
for row in stream["rows"]:
|
||||
position = row[0]
|
||||
internal = json.loads(row[1])
|
||||
event_json = json.loads(row[2])
|
||||
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
||||
event_id = row[1]
|
||||
event = event_map.get(event_id, None)
|
||||
if not event:
|
||||
continue
|
||||
|
||||
extra_users = ()
|
||||
if event.type == EventTypes.Member:
|
||||
extra_users = (event.state_key,)
|
||||
|
@ -478,7 +482,7 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.synchrotron"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
|
@ -497,7 +501,11 @@ def start(config_options):
|
|||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
|
|
|
@ -23,14 +23,27 @@ import signal
|
|||
import subprocess
|
||||
import sys
|
||||
import yaml
|
||||
import errno
|
||||
import time
|
||||
|
||||
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
||||
|
||||
GREEN = "\x1b[1;32m"
|
||||
YELLOW = "\x1b[1;33m"
|
||||
RED = "\x1b[1;31m"
|
||||
NORMAL = "\x1b[m"
|
||||
|
||||
|
||||
def pid_running(pid):
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except OSError, err:
|
||||
if err.errno == errno.EPERM:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def write(message, colour=NORMAL, stream=sys.stdout):
|
||||
if colour == NORMAL:
|
||||
stream.write(message + "\n")
|
||||
|
@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
|
|||
stream.write(colour + message + NORMAL + "\n")
|
||||
|
||||
|
||||
def abort(message, colour=RED, stream=sys.stderr):
|
||||
write(message, colour, stream)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def start(configfile):
|
||||
write("Starting ...")
|
||||
args = SYNAPSE
|
||||
|
@ -45,7 +63,8 @@ def start(configfile):
|
|||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
||||
write("started synapse.app.homeserver(%r)" %
|
||||
(configfile,), colour=GREEN)
|
||||
except subprocess.CalledProcessError as e:
|
||||
write(
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
||||
|
@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
|
|||
def stop(pidfile, app):
|
||||
if os.path.exists(pidfile):
|
||||
pid = int(open(pidfile).read())
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
write("stopped %s" % (app,), colour=GREEN)
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
write("stopped %s" % (app,), colour=GREEN)
|
||||
except OSError, err:
|
||||
if err.errno == errno.ESRCH:
|
||||
write("%s not running" % (app,), colour=YELLOW)
|
||||
elif err.errno == errno.EPERM:
|
||||
abort("Cannot stop %s: Operation not permitted" % (app,))
|
||||
else:
|
||||
abort("Cannot stop %s: Unknown error" % (app,))
|
||||
|
||||
|
||||
Worker = collections.namedtuple("Worker", [
|
||||
|
@ -190,7 +217,19 @@ def main():
|
|||
if start_stop_synapse:
|
||||
stop(pidfile, "synapse.app.homeserver")
|
||||
|
||||
# TODO: Wait for synapse to actually shutdown before starting it again
|
||||
# Wait for synapse to actually shutdown before starting it again
|
||||
if action == "restart":
|
||||
running_pids = []
|
||||
if start_stop_synapse and os.path.exists(pidfile):
|
||||
running_pids.append(int(open(pidfile).read()))
|
||||
for worker in workers:
|
||||
if os.path.exists(worker.pidfile):
|
||||
running_pids.append(int(open(worker.pidfile).read()))
|
||||
if len(running_pids) > 0:
|
||||
write("Waiting for process to exit before restarting...")
|
||||
for running_pid in running_pids:
|
||||
while pid_running(running_pid):
|
||||
time.sleep(0.2)
|
||||
|
||||
if action == "start" or action == "restart":
|
||||
if start_stop_synapse:
|
||||
|
|
|
@ -45,7 +45,6 @@ handlers:
|
|||
maxBytes: 104857600
|
||||
backupCount: 10
|
||||
filters: [context]
|
||||
level: INFO
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
|
@ -56,6 +55,8 @@ loggers:
|
|||
level: INFO
|
||||
|
||||
synapse.storage.SQL:
|
||||
# beware: increasing this to DEBUG will make synapse log sensitive
|
||||
# information such as access tokens.
|
||||
level: INFO
|
||||
|
||||
root:
|
||||
|
@ -68,6 +69,7 @@ class LoggingConfig(Config):
|
|||
|
||||
def read_config(self, config):
|
||||
self.verbosity = config.get("verbose", 0)
|
||||
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
|
||||
self.log_config = self.abspath(config.get("log_config"))
|
||||
self.log_file = self.abspath(config.get("log_file"))
|
||||
|
||||
|
@ -77,10 +79,10 @@ class LoggingConfig(Config):
|
|||
os.path.join(config_dir_path, server_name + ".log.config")
|
||||
)
|
||||
return """
|
||||
# Logging verbosity level.
|
||||
# Logging verbosity level. Ignored if log_config is specified.
|
||||
verbose: 0
|
||||
|
||||
# File to write logging to
|
||||
# File to write logging to. Ignored if log_config is specified.
|
||||
log_file: "%(log_file)s"
|
||||
|
||||
# A yaml python logging config file
|
||||
|
@ -90,6 +92,8 @@ class LoggingConfig(Config):
|
|||
def read_arguments(self, args):
|
||||
if args.verbose is not None:
|
||||
self.verbosity = args.verbose
|
||||
if args.no_redirect_stdio is not None:
|
||||
self.no_redirect_stdio = args.no_redirect_stdio
|
||||
if args.log_config is not None:
|
||||
self.log_config = args.log_config
|
||||
if args.log_file is not None:
|
||||
|
@ -99,16 +103,22 @@ class LoggingConfig(Config):
|
|||
logging_group = parser.add_argument_group("logging")
|
||||
logging_group.add_argument(
|
||||
'-v', '--verbose', dest="verbose", action='count',
|
||||
help="The verbosity level."
|
||||
help="The verbosity level. Specify multiple times to increase "
|
||||
"verbosity. (Ignored if --log-config is specified.)"
|
||||
)
|
||||
logging_group.add_argument(
|
||||
'-f', '--log-file', dest="log_file",
|
||||
help="File to log to."
|
||||
help="File to log to. (Ignored if --log-config is specified.)"
|
||||
)
|
||||
logging_group.add_argument(
|
||||
'--log-config', dest="log_config", default=None,
|
||||
help="Python logging config file"
|
||||
)
|
||||
logging_group.add_argument(
|
||||
'-n', '--no-redirect-stdio',
|
||||
action='store_true', default=None,
|
||||
help="Do not redirect stdout/stderr to the log"
|
||||
)
|
||||
|
||||
def generate_files(self, config):
|
||||
log_config = config.get("log_config")
|
||||
|
@ -118,11 +128,22 @@ class LoggingConfig(Config):
|
|||
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
|
||||
)
|
||||
|
||||
def setup_logging(self):
|
||||
setup_logging(self.log_config, self.log_file, self.verbosity)
|
||||
|
||||
def setup_logging(config, use_worker_options=False):
|
||||
""" Set up python logging
|
||||
|
||||
Args:
|
||||
config (LoggingConfig | synapse.config.workers.WorkerConfig):
|
||||
configuration data
|
||||
|
||||
use_worker_options (bool): True to use 'worker_log_config' and
|
||||
'worker_log_file' options instead of 'log_config' and 'log_file'.
|
||||
"""
|
||||
log_config = (config.worker_log_config if use_worker_options
|
||||
else config.log_config)
|
||||
log_file = (config.worker_log_file if use_worker_options
|
||||
else config.log_file)
|
||||
|
||||
def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
|
@ -131,9 +152,9 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
|||
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if verbosity:
|
||||
if config.verbosity:
|
||||
level = logging.DEBUG
|
||||
if verbosity > 1:
|
||||
if config.verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
|
@ -153,14 +174,6 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
|||
logger.info("Closing log file due to SIGHUP")
|
||||
handler.doRollover()
|
||||
logger.info("Opened new log file due to SIGHUP")
|
||||
|
||||
# TODO(paul): obviously this is a terrible mechanism for
|
||||
# stealing SIGHUP, because it means no other part of synapse
|
||||
# can use it instead. If we want to catch SIGHUP anywhere
|
||||
# else as well, I'd suggest we find a nicer way to broadcast
|
||||
# it around.
|
||||
if getattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, sighup)
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
@ -169,8 +182,25 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
|||
|
||||
logger.addHandler(handler)
|
||||
else:
|
||||
with open(log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
def load_log_config():
|
||||
with open(log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
|
||||
def sighup(signum, stack):
|
||||
# it might be better to use a file watcher or something for this.
|
||||
logging.info("Reloading log config from %s due to SIGHUP",
|
||||
log_config)
|
||||
load_log_config()
|
||||
|
||||
load_log_config()
|
||||
|
||||
# TODO(paul): obviously this is a terrible mechanism for
|
||||
# stealing SIGHUP, because it means no other part of synapse
|
||||
# can use it instead. If we want to catch SIGHUP anywhere
|
||||
# else as well, I'd suggest we find a nicer way to broadcast
|
||||
# it around.
|
||||
if getattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, sighup)
|
||||
|
||||
# It's critical to point twisted's internal logging somewhere, otherwise it
|
||||
# stacks up and leaks kup to 64K object;
|
||||
|
@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
|
|||
#
|
||||
# However this may not be too much of a problem if we are just writing to a file.
|
||||
observer = STDLibLogObserver()
|
||||
globalLogBeginner.beginLoggingTo([observer])
|
||||
globalLogBeginner.beginLoggingTo(
|
||||
[observer],
|
||||
redirectStandardIO=not config.no_redirect_stdio,
|
||||
)
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
from synapse.crypto.keyclient import fetch_server_key
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.util.retryutils import get_retry_limiter
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import (
|
||||
|
@ -96,10 +95,11 @@ class Keyring(object):
|
|||
verify_requests = []
|
||||
|
||||
for server_name, json_object in server_and_json:
|
||||
logger.debug("Verifying for %s", server_name)
|
||||
|
||||
key_ids = signature_ids(json_object, server_name)
|
||||
if not key_ids:
|
||||
logger.warn("Request from %s: no supported signature keys",
|
||||
server_name)
|
||||
deferred = defer.fail(SynapseError(
|
||||
400,
|
||||
"Not signed with a supported algorithm",
|
||||
|
@ -108,6 +108,9 @@ class Keyring(object):
|
|||
else:
|
||||
deferred = defer.Deferred()
|
||||
|
||||
logger.debug("Verifying for %s with key_ids %s",
|
||||
server_name, key_ids)
|
||||
|
||||
verify_request = VerifyKeyRequest(
|
||||
server_name, key_ids, json_object, deferred
|
||||
)
|
||||
|
@ -142,6 +145,9 @@ class Keyring(object):
|
|||
|
||||
json_object = verify_request.json_object
|
||||
|
||||
logger.debug("Got key %s %s:%s for server %s, verifying" % (
|
||||
key_id, verify_key.alg, verify_key.version, server_name,
|
||||
))
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
except:
|
||||
|
@ -231,8 +237,14 @@ class Keyring(object):
|
|||
d.addBoth(rm, server_name)
|
||||
|
||||
def get_server_verify_keys(self, verify_requests):
|
||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||
each group.
|
||||
"""Tries to find at least one key for each verify request
|
||||
|
||||
For each verify_request, verify_request.deferred is called back with
|
||||
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
|
||||
with a SynapseError if none of the keys are found.
|
||||
|
||||
Args:
|
||||
verify_requests (list[VerifyKeyRequest]): list of verify requests
|
||||
"""
|
||||
|
||||
# These are functions that produce keys given a list of key ids
|
||||
|
@ -245,8 +257,11 @@ class Keyring(object):
|
|||
@defer.inlineCallbacks
|
||||
def do_iterations():
|
||||
with Measure(self.clock, "get_server_verify_keys"):
|
||||
# dict[str, dict[str, VerifyKey]]: results so far.
|
||||
# map server_name -> key_id -> VerifyKey
|
||||
merged_results = {}
|
||||
|
||||
# dict[str, set(str)]: keys to fetch for each server
|
||||
missing_keys = {}
|
||||
for verify_request in verify_requests:
|
||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||
|
@ -308,6 +323,16 @@ class Keyring(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
"""
|
||||
|
||||
Args:
|
||||
server_name_and_key_ids (list[(str, iterable[str])]):
|
||||
list of (server_name, iterable[key_id]) tuples to fetch keys for
|
||||
|
||||
Returns:
|
||||
Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
|
||||
server_name -> key_id -> VerifyKey
|
||||
"""
|
||||
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.store.get_server_verify_keys)(
|
||||
|
@ -356,30 +381,24 @@ class Keyring(object):
|
|||
def get_keys_from_server(self, server_name_and_key_ids):
|
||||
@defer.inlineCallbacks
|
||||
def get_key(server_name, key_ids):
|
||||
limiter = yield get_retry_limiter(
|
||||
server_name,
|
||||
self.clock,
|
||||
self.store,
|
||||
)
|
||||
with limiter:
|
||||
keys = None
|
||||
try:
|
||||
keys = yield self.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Unable to get key %r for %r directly: %s %s",
|
||||
key_ids, server_name,
|
||||
type(e).__name__, str(e.message),
|
||||
)
|
||||
keys = None
|
||||
try:
|
||||
keys = yield self.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Unable to get key %r for %r directly: %s %s",
|
||||
key_ids, server_name,
|
||||
type(e).__name__, str(e.message),
|
||||
)
|
||||
|
||||
if not keys:
|
||||
keys = yield self.get_server_verify_key_v1_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
if not keys:
|
||||
keys = yield self.get_server_verify_key_v1_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
|
||||
keys = {server_name: keys}
|
||||
keys = {server_name: keys}
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
|
|
|
@ -15,6 +15,32 @@
|
|||
|
||||
|
||||
class EventContext(object):
|
||||
"""
|
||||
Attributes:
|
||||
current_state_ids (dict[(str, str), str]):
|
||||
The current state map including the current event.
|
||||
(type, state_key) -> event_id
|
||||
|
||||
prev_state_ids (dict[(str, str), str]):
|
||||
The current state map excluding the current event.
|
||||
(type, state_key) -> event_id
|
||||
|
||||
state_group (int): state group id
|
||||
rejected (bool|str): A rejection reason if the event was rejected, else
|
||||
False
|
||||
|
||||
push_actions (list[(str, list[object])]): list of (user_id, actions)
|
||||
tuples
|
||||
|
||||
prev_group (int): Previously persisted state group. ``None`` for an
|
||||
outlier.
|
||||
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
||||
(type, state_key) -> event_id. ``None`` for an outlier.
|
||||
|
||||
prev_state_events (?): XXX: is this ever set to anything other than
|
||||
the empty list?
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
"current_state_ids",
|
||||
"prev_state_ids",
|
||||
|
|
|
@ -29,7 +29,7 @@ from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
|||
from synapse.events import FrozenEvent, builder
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
|
@ -88,7 +88,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
@log_function
|
||||
def make_query(self, destination, query_type, args,
|
||||
retry_on_dns_fail=False):
|
||||
retry_on_dns_fail=False, ignore_backoff=False):
|
||||
"""Sends a federation Query to a remote homeserver of the given type
|
||||
and arguments.
|
||||
|
||||
|
@ -98,6 +98,8 @@ class FederationClient(FederationBase):
|
|||
handler name used in register_query_handler().
|
||||
args (dict): Mapping of strings to strings containing the details
|
||||
of the query request.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
|
||||
Returns:
|
||||
a Deferred which will eventually yield a JSON object from the
|
||||
|
@ -106,7 +108,8 @@ class FederationClient(FederationBase):
|
|||
sent_queries_counter.inc(query_type)
|
||||
|
||||
return self.transport_layer.make_query(
|
||||
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
|
||||
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
@log_function
|
||||
|
@ -234,31 +237,24 @@ class FederationClient(FederationBase):
|
|||
continue
|
||||
|
||||
try:
|
||||
limiter = yield get_retry_limiter(
|
||||
destination,
|
||||
self._clock,
|
||||
self.store,
|
||||
transaction_data = yield self.transport_layer.get_event(
|
||||
destination, event_id, timeout=timeout,
|
||||
)
|
||||
|
||||
with limiter:
|
||||
transaction_data = yield self.transport_layer.get_event(
|
||||
destination, event_id, timeout=timeout,
|
||||
)
|
||||
logger.debug("transaction_data %r", transaction_data)
|
||||
|
||||
logger.debug("transaction_data %r", transaction_data)
|
||||
pdu_list = [
|
||||
self.event_from_pdu_json(p, outlier=outlier)
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
pdu_list = [
|
||||
self.event_from_pdu_json(p, outlier=outlier)
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
if pdu_list and pdu_list[0]:
|
||||
pdu = pdu_list[0]
|
||||
|
||||
if pdu_list and pdu_list[0]:
|
||||
pdu = pdu_list[0]
|
||||
# Check signatures are correct.
|
||||
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||
|
||||
# Check signatures are correct.
|
||||
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||
|
||||
break
|
||||
break
|
||||
|
||||
pdu_attempts[destination] = now
|
||||
|
||||
|
|
|
@ -52,7 +52,6 @@ class FederationServer(FederationBase):
|
|||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
||||
self._server_linearizer = Linearizer("fed_server")
|
||||
|
||||
# We cache responses to state queries, as they take a while and often
|
||||
|
@ -147,11 +146,15 @@ class FederationServer(FederationBase):
|
|||
# check that it's actually being sent from a valid destination to
|
||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||
if transaction.origin != get_domain_from_id(pdu.event_id):
|
||||
# We continue to accept join events from any server; this is
|
||||
# necessary for the federation join dance to work correctly.
|
||||
# (When we join over federation, the "helper" server is
|
||||
# responsible for sending out the join event, rather than the
|
||||
# origin. See bug #1893).
|
||||
if not (
|
||||
pdu.type == 'm.room.member' and
|
||||
pdu.content and
|
||||
pdu.content.get("membership", None) == 'join' and
|
||||
self.hs.is_mine_id(pdu.state_key)
|
||||
pdu.content.get("membership", None) == 'join'
|
||||
):
|
||||
logger.info(
|
||||
"Discarding PDU %s from invalid origin %s",
|
||||
|
@ -165,7 +168,7 @@ class FederationServer(FederationBase):
|
|||
)
|
||||
|
||||
try:
|
||||
yield self._handle_new_pdu(transaction.origin, pdu)
|
||||
yield self._handle_received_pdu(transaction.origin, pdu)
|
||||
results.append({})
|
||||
except FederationError as e:
|
||||
self.send_failure(e, transaction.origin)
|
||||
|
@ -497,27 +500,16 @@ class FederationServer(FederationBase):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _handle_new_pdu(self, origin, pdu, get_missing=True):
|
||||
def _handle_received_pdu(self, origin, pdu):
|
||||
""" Process a PDU received in a federation /send/ transaction.
|
||||
|
||||
# We reprocess pdus when we have seen them only as outliers
|
||||
existing = yield self._get_persisted_pdu(
|
||||
origin, pdu.event_id, do_auth=False
|
||||
)
|
||||
|
||||
# FIXME: Currently we fetch an event again when we already have it
|
||||
# if it has been marked as an outlier.
|
||||
|
||||
already_seen = (
|
||||
existing and (
|
||||
not existing.internal_metadata.is_outlier()
|
||||
or pdu.internal_metadata.is_outlier()
|
||||
)
|
||||
)
|
||||
if already_seen:
|
||||
logger.debug("Already seen pdu %s", pdu.event_id)
|
||||
return
|
||||
Args:
|
||||
origin (str): server which sent the pdu
|
||||
pdu (FrozenEvent): received pdu
|
||||
|
||||
Returns (Deferred): completes with None
|
||||
Raises: FederationError if the signatures / hash do not match
|
||||
"""
|
||||
# Check signature.
|
||||
try:
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
|
@ -529,143 +521,7 @@ class FederationServer(FederationBase):
|
|||
affected=pdu.event_id,
|
||||
)
|
||||
|
||||
state = None
|
||||
|
||||
auth_chain = []
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
|
||||
fetch_state = False
|
||||
|
||||
# Get missing pdus if necessary.
|
||||
if not pdu.internal_metadata.is_outlier():
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = yield self.handler.get_min_depth_for_context(
|
||||
pdu.room_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_pdu min_depth for %s: %d",
|
||||
pdu.room_id, min_depth
|
||||
)
|
||||
|
||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||
seen = set(have_seen.keys())
|
||||
|
||||
if min_depth and pdu.depth < min_depth:
|
||||
# This is so that we don't notify the user about this
|
||||
# message, to work around the fact that some events will
|
||||
# reference really really old events we really don't want to
|
||||
# send to the clients.
|
||||
pdu.internal_metadata.outlier = True
|
||||
elif min_depth and pdu.depth > min_depth:
|
||||
if get_missing and prevs - seen:
|
||||
# If we're missing stuff, ensure we only fetch stuff one
|
||||
# at a time.
|
||||
logger.info(
|
||||
"Acquiring lock for room %r to fetch %d missing events: %r...",
|
||||
pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
|
||||
)
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
logger.info(
|
||||
"Acquired lock for room %r to fetch %d missing events",
|
||||
pdu.room_id, len(prevs - seen),
|
||||
)
|
||||
|
||||
# We recalculate seen, since it may have changed.
|
||||
have_seen = yield self.store.have_events(prevs)
|
||||
seen = set(have_seen.keys())
|
||||
|
||||
if prevs - seen:
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest |= seen
|
||||
|
||||
logger.info(
|
||||
"Missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
|
||||
# XXX: we set timeout to 10s to help workaround
|
||||
# https://github.com/matrix-org/synapse/issues/1733.
|
||||
# The reason is to avoid holding the linearizer lock
|
||||
# whilst processing inbound /send transactions, causing
|
||||
# FDs to stack up and block other inbound transactions
|
||||
# which empirically can currently take up to 30 minutes.
|
||||
#
|
||||
# N.B. this explicitly disables retry attempts.
|
||||
#
|
||||
# N.B. this also increases our chances of falling back to
|
||||
# fetching fresh state for the room if the missing event
|
||||
# can't be found, which slightly reduces our security.
|
||||
# it may also increase our DAG extremity count for the room,
|
||||
# causing additional state resolution? See #1760.
|
||||
# However, fetching state doesn't hold the linearizer lock
|
||||
# apparently.
|
||||
#
|
||||
# see https://github.com/matrix-org/synapse/pull/1744
|
||||
|
||||
missing_events = yield self.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
latest_events=[pdu],
|
||||
limit=10,
|
||||
min_depth=min_depth,
|
||||
timeout=10000,
|
||||
)
|
||||
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
missing_events.sort(key=lambda x: x.depth)
|
||||
|
||||
for e in missing_events:
|
||||
yield self._handle_new_pdu(
|
||||
origin,
|
||||
e,
|
||||
get_missing=False
|
||||
)
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
|
||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||
seen = set(have_seen.keys())
|
||||
if prevs - seen:
|
||||
logger.info(
|
||||
"Still missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
fetch_state = True
|
||||
|
||||
if fetch_state:
|
||||
# We need to get the state at this event, since we haven't
|
||||
# processed all the prev events.
|
||||
logger.debug(
|
||||
"_handle_new_pdu getting state for %s",
|
||||
pdu.room_id
|
||||
)
|
||||
try:
|
||||
state, auth_chain = yield self.get_state_for_room(
|
||||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to get state for event: %s", pdu.event_id)
|
||||
|
||||
yield self.handler.on_receive_pdu(
|
||||
origin,
|
||||
pdu,
|
||||
state=state,
|
||||
auth_chain=auth_chain,
|
||||
)
|
||||
yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
|
||||
|
||||
def __str__(self):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
|
|
@ -54,6 +54,7 @@ class FederationRemoteSendQueue(object):
|
|||
def __init__(self, hs):
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
self.presence_map = {}
|
||||
self.presence_changed = sorteddict()
|
||||
|
@ -186,6 +187,8 @@ class FederationRemoteSendQueue(object):
|
|||
else:
|
||||
self.edus[pos] = edu
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def send_presence(self, destination, states):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
|
@ -199,16 +202,20 @@ class FederationRemoteSendQueue(object):
|
|||
(destination, state.user_id) for state in states
|
||||
]
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def send_failure(self, failure, destination):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
|
||||
self.failures[pos] = (destination, str(failure))
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
"""As per TransactionQueue"""
|
||||
pos = self._next_pos()
|
||||
self.device_messages[pos] = destination
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def get_current_token(self):
|
||||
return self.pos - 1
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -22,9 +22,7 @@ from .units import Transaction, Edu
|
|||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.retryutils import (
|
||||
get_retry_limiter, NotRetryingDestination,
|
||||
)
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
|
@ -99,7 +97,12 @@ class TransactionQueue(object):
|
|||
# destination -> list of tuple(failure, deferred)
|
||||
self.pending_failures_by_dest = {}
|
||||
|
||||
# destination -> stream_id of last successfully sent to-device message.
|
||||
# NB: may be a long or an int.
|
||||
self.last_device_stream_id_by_dest = {}
|
||||
|
||||
# destination -> stream_id of last successfully sent device list
|
||||
# update.
|
||||
self.last_device_list_stream_id_by_dest = {}
|
||||
|
||||
# HACK to get unique tx id
|
||||
|
@ -300,20 +303,20 @@ class TransactionQueue(object):
|
|||
)
|
||||
return
|
||||
|
||||
pending_pdus = []
|
||||
try:
|
||||
self.pending_transactions[destination] = 1
|
||||
|
||||
# This will throw if we wouldn't retry. We do this here so we fail
|
||||
# quickly, but we will later check this again in the http client,
|
||||
# hence why we throw the result away.
|
||||
yield get_retry_limiter(destination, self.clock, self.store)
|
||||
|
||||
# XXX: what's this for?
|
||||
yield run_on_reactor()
|
||||
|
||||
pending_pdus = []
|
||||
while True:
|
||||
limiter = yield get_retry_limiter(
|
||||
destination,
|
||||
self.clock,
|
||||
self.store,
|
||||
backoff_on_404=True, # If we get a 404 the other side has gone
|
||||
)
|
||||
|
||||
device_message_edus, device_stream_id, dev_list_id = (
|
||||
yield self._get_new_device_messages(destination)
|
||||
)
|
||||
|
@ -369,7 +372,6 @@ class TransactionQueue(object):
|
|||
|
||||
success = yield self._send_new_transaction(
|
||||
destination, pending_pdus, pending_edus, pending_failures,
|
||||
limiter=limiter,
|
||||
)
|
||||
if success:
|
||||
# Remove the acknowledged device messages from the database
|
||||
|
@ -387,12 +389,24 @@ class TransactionQueue(object):
|
|||
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
|
||||
else:
|
||||
break
|
||||
except NotRetryingDestination:
|
||||
except NotRetryingDestination as e:
|
||||
logger.debug(
|
||||
"TX [%s] not ready for retry yet - "
|
||||
"TX [%s] not ready for retry yet (next retry at %s) - "
|
||||
"dropping transaction for now",
|
||||
destination,
|
||||
datetime.datetime.fromtimestamp(
|
||||
(e.retry_last_ts + e.retry_interval) / 1000.0
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"TX [%s] Failed to send transaction: %s",
|
||||
destination,
|
||||
e,
|
||||
)
|
||||
for p, _ in pending_pdus:
|
||||
logger.info("Failed to send event %s to %s", p.event_id,
|
||||
destination)
|
||||
finally:
|
||||
# We want to be *very* sure we delete this after we stop processing
|
||||
self.pending_transactions.pop(destination, None)
|
||||
|
@ -432,7 +446,7 @@ class TransactionQueue(object):
|
|||
@measure_func("_send_new_transaction")
|
||||
@defer.inlineCallbacks
|
||||
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
|
||||
pending_failures, limiter):
|
||||
pending_failures):
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[1])
|
||||
|
@ -442,132 +456,104 @@ class TransactionQueue(object):
|
|||
|
||||
success = True
|
||||
|
||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||
|
||||
txn_id = str(self._next_txn_id)
|
||||
|
||||
logger.debug(
|
||||
"TX [%s] {%s} Attempting new transaction"
|
||||
" (pdus: %d, edus: %d, failures: %d)",
|
||||
destination, txn_id,
|
||||
len(pdus),
|
||||
len(edus),
|
||||
len(failures)
|
||||
)
|
||||
|
||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||
|
||||
transaction = Transaction.create_new(
|
||||
origin_server_ts=int(self.clock.time_msec()),
|
||||
transaction_id=txn_id,
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
pdus=pdus,
|
||||
edus=edus,
|
||||
pdu_failures=failures,
|
||||
)
|
||||
|
||||
self._next_txn_id += 1
|
||||
|
||||
yield self.transaction_actions.prepare_to_send(transaction)
|
||||
|
||||
logger.debug("TX [%s] Persisted transaction", destination)
|
||||
logger.info(
|
||||
"TX [%s] {%s} Sending transaction [%s],"
|
||||
" (PDUs: %d, EDUs: %d, failures: %d)",
|
||||
destination, txn_id,
|
||||
transaction.transaction_id,
|
||||
len(pdus),
|
||||
len(edus),
|
||||
len(failures),
|
||||
)
|
||||
|
||||
# Actually send the transaction
|
||||
|
||||
# FIXME (erikj): This is a bit of a hack to make the Pdu age
|
||||
# keys work
|
||||
def json_data_cb():
|
||||
data = transaction.get_dict()
|
||||
now = int(self.clock.time_msec())
|
||||
if "pdus" in data:
|
||||
for p in data["pdus"]:
|
||||
if "age_ts" in p:
|
||||
unsigned = p.setdefault("unsigned", {})
|
||||
unsigned["age"] = now - int(p["age_ts"])
|
||||
del p["age_ts"]
|
||||
return data
|
||||
|
||||
try:
|
||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||
|
||||
txn_id = str(self._next_txn_id)
|
||||
|
||||
logger.debug(
|
||||
"TX [%s] {%s} Attempting new transaction"
|
||||
" (pdus: %d, edus: %d, failures: %d)",
|
||||
destination, txn_id,
|
||||
len(pdus),
|
||||
len(edus),
|
||||
len(failures)
|
||||
response = yield self.transport_layer.send_transaction(
|
||||
transaction, json_data_cb
|
||||
)
|
||||
code = 200
|
||||
|
||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||
|
||||
transaction = Transaction.create_new(
|
||||
origin_server_ts=int(self.clock.time_msec()),
|
||||
transaction_id=txn_id,
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
pdus=pdus,
|
||||
edus=edus,
|
||||
pdu_failures=failures,
|
||||
)
|
||||
|
||||
self._next_txn_id += 1
|
||||
|
||||
yield self.transaction_actions.prepare_to_send(transaction)
|
||||
|
||||
logger.debug("TX [%s] Persisted transaction", destination)
|
||||
logger.info(
|
||||
"TX [%s] {%s} Sending transaction [%s],"
|
||||
" (PDUs: %d, EDUs: %d, failures: %d)",
|
||||
destination, txn_id,
|
||||
transaction.transaction_id,
|
||||
len(pdus),
|
||||
len(edus),
|
||||
len(failures),
|
||||
)
|
||||
|
||||
with limiter:
|
||||
# Actually send the transaction
|
||||
|
||||
# FIXME (erikj): This is a bit of a hack to make the Pdu age
|
||||
# keys work
|
||||
def json_data_cb():
|
||||
data = transaction.get_dict()
|
||||
now = int(self.clock.time_msec())
|
||||
if "pdus" in data:
|
||||
for p in data["pdus"]:
|
||||
if "age_ts" in p:
|
||||
unsigned = p.setdefault("unsigned", {})
|
||||
unsigned["age"] = now - int(p["age_ts"])
|
||||
del p["age_ts"]
|
||||
return data
|
||||
|
||||
try:
|
||||
response = yield self.transport_layer.send_transaction(
|
||||
transaction, json_data_cb
|
||||
)
|
||||
code = 200
|
||||
|
||||
if response:
|
||||
for e_id, r in response.get("pdus", {}).items():
|
||||
if "error" in r:
|
||||
logger.warn(
|
||||
"Transaction returned error for %s: %s",
|
||||
e_id, r,
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
code = e.code
|
||||
response = e.response
|
||||
|
||||
if e.code in (401, 404, 429) or 500 <= e.code:
|
||||
logger.info(
|
||||
"TX [%s] {%s} got %d response",
|
||||
destination, txn_id, code
|
||||
if response:
|
||||
for e_id, r in response.get("pdus", {}).items():
|
||||
if "error" in r:
|
||||
logger.warn(
|
||||
"Transaction returned error for %s: %s",
|
||||
e_id, r,
|
||||
)
|
||||
raise e
|
||||
except HttpResponseException as e:
|
||||
code = e.code
|
||||
response = e.response
|
||||
|
||||
if e.code in (401, 404, 429) or 500 <= e.code:
|
||||
logger.info(
|
||||
"TX [%s] {%s} got %d response",
|
||||
destination, txn_id, code
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.debug("TX [%s] Sent transaction", destination)
|
||||
logger.debug("TX [%s] Marking as delivered...", destination)
|
||||
logger.info(
|
||||
"TX [%s] {%s} got %d response",
|
||||
destination, txn_id, code
|
||||
)
|
||||
|
||||
yield self.transaction_actions.delivered(
|
||||
transaction, code, response
|
||||
)
|
||||
logger.debug("TX [%s] Sent transaction", destination)
|
||||
logger.debug("TX [%s] Marking as delivered...", destination)
|
||||
|
||||
logger.debug("TX [%s] Marked as delivered", destination)
|
||||
yield self.transaction_actions.delivered(
|
||||
transaction, code, response
|
||||
)
|
||||
|
||||
if code != 200:
|
||||
for p in pdus:
|
||||
logger.info(
|
||||
"Failed to send event %s to %s", p.event_id, destination
|
||||
)
|
||||
success = False
|
||||
except RuntimeError as e:
|
||||
# We capture this here as there as nothing actually listens
|
||||
# for this finishing functions deferred.
|
||||
logger.warn(
|
||||
"TX [%s] Problem in _attempt_transaction: %s",
|
||||
destination,
|
||||
e,
|
||||
)
|
||||
|
||||
success = False
|
||||
logger.debug("TX [%s] Marked as delivered", destination)
|
||||
|
||||
if code != 200:
|
||||
for p in pdus:
|
||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||
except Exception as e:
|
||||
# We capture this here as there as nothing actually listens
|
||||
# for this finishing functions deferred.
|
||||
logger.warn(
|
||||
"TX [%s] Problem in _attempt_transaction: %s",
|
||||
destination,
|
||||
e,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Failed to send event %s to %s", p.event_id, destination
|
||||
)
|
||||
success = False
|
||||
|
||||
for p in pdus:
|
||||
logger.info("Failed to send event %s to %s", p.event_id, destination)
|
||||
|
||||
defer.returnValue(success)
|
||||
|
|
|
@ -163,6 +163,7 @@ class TransportLayerClient(object):
|
|||
data=json_data,
|
||||
json_data_callback=json_data_callback,
|
||||
long_retries=True,
|
||||
backoff_on_404=True, # If we get a 404 the other side has gone
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
|
@ -174,7 +175,8 @@ class TransportLayerClient(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def make_query(self, destination, query_type, args, retry_on_dns_fail):
|
||||
def make_query(self, destination, query_type, args, retry_on_dns_fail,
|
||||
ignore_backoff=False):
|
||||
path = PREFIX + "/query/%s" % query_type
|
||||
|
||||
content = yield self.client.get_json(
|
||||
|
@ -183,6 +185,7 @@ class TransportLayerClient(object):
|
|||
args=args,
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
timeout=10000,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
defer.returnValue(content)
|
||||
|
@ -242,6 +245,7 @@ class TransportLayerClient(object):
|
|||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
@ -269,6 +273,7 @@ class TransportLayerClient(object):
|
|||
destination=remote_server,
|
||||
path=path,
|
||||
args=args,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -47,6 +48,7 @@ class AuthHandler(BaseHandler):
|
|||
LoginType.PASSWORD: self._check_password_auth,
|
||||
LoginType.RECAPTCHA: self._check_recaptcha,
|
||||
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
||||
LoginType.MSISDN: self._check_msisdn,
|
||||
LoginType.DUMMY: self._check_dummy_auth,
|
||||
}
|
||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||
|
@ -307,31 +309,47 @@ class AuthHandler(BaseHandler):
|
|||
defer.returnValue(True)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_email_identity(self, authdict, _):
|
||||
return self._check_threepid('email', authdict)
|
||||
|
||||
def _check_msisdn(self, authdict, _):
|
||||
return self._check_threepid('msisdn', authdict)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_dummy_auth(self, authdict, _):
|
||||
yield run_on_reactor()
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_threepid(self, medium, authdict):
|
||||
yield run_on_reactor()
|
||||
|
||||
if 'threepid_creds' not in authdict:
|
||||
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
||||
|
||||
threepid_creds = authdict['threepid_creds']
|
||||
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
|
||||
logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
|
||||
logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
|
||||
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
|
||||
|
||||
if not threepid:
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
if threepid['medium'] != medium:
|
||||
raise LoginError(
|
||||
401,
|
||||
"Expecting threepid of type '%s', got '%s'" % (
|
||||
medium, threepid['medium'],
|
||||
),
|
||||
errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
threepid['threepid_creds'] = authdict['threepid_creds']
|
||||
|
||||
defer.returnValue(threepid)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_dummy_auth(self, authdict, _):
|
||||
yield run_on_reactor()
|
||||
defer.returnValue(True)
|
||||
|
||||
def _get_params_recaptcha(self):
|
||||
return {"public_key": self.hs.config.recaptcha_public_key}
|
||||
|
||||
|
|
|
@ -169,6 +169,40 @@ class DeviceHandler(BaseHandler):
|
|||
|
||||
yield self.notify_device_update(user_id, [device_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_devices(self, user_id, device_ids):
|
||||
""" Delete several devices
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_ids (str): The list of device IDs to delete
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.delete_devices(user_id, device_ids)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
# Delete access tokens and e2e keys for each device. Not optimised as it is not
|
||||
# considered as part of a critical path.
|
||||
for device_id in device_ids:
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, device_id=device_id,
|
||||
delete_refresh_tokens=True,
|
||||
)
|
||||
yield self.store.delete_e2e_keys_by_device(
|
||||
user_id=user_id, device_id=device_id
|
||||
)
|
||||
|
||||
yield self.notify_device_update(user_id, device_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_device(self, user_id, device_id, content):
|
||||
""" Update the given device
|
||||
|
@ -214,8 +248,7 @@ class DeviceHandler(BaseHandler):
|
|||
user_id, device_ids, list(hosts)
|
||||
)
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = [r.room_id for r in rooms]
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
yield self.notifier.on_new_event(
|
||||
"device_list_key", position, rooms=room_ids,
|
||||
|
@ -236,8 +269,7 @@ class DeviceHandler(BaseHandler):
|
|||
user_id (str)
|
||||
from_token (StreamToken)
|
||||
"""
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = set(r.room_id for r in rooms)
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
# First we check if any devices have changed
|
||||
changed = yield self.store.get_user_whose_devices_changed(
|
||||
|
@ -262,7 +294,7 @@ class DeviceHandler(BaseHandler):
|
|||
# ordering: treat it the same as a new room
|
||||
event_ids = []
|
||||
|
||||
current_state_ids = yield self.state.get_current_state_ids(room_id)
|
||||
current_state_ids = yield self.store.get_current_state_ids(room_id)
|
||||
|
||||
# special-case for an empty prev state: include all members
|
||||
# in the changed list
|
||||
|
@ -313,8 +345,8 @@ class DeviceHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
def user_left_room(self, user, room_id):
|
||||
user_id = user.to_string()
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
if not rooms:
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
# We no longer share rooms with this user, so we'll no longer
|
||||
# receive device updates. Mark this in DB.
|
||||
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
|
||||
|
@ -370,8 +402,8 @@ class DeviceListEduUpdater(object):
|
|||
logger.warning("Got device list update edu for %r from %r", user_id, origin)
|
||||
return
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
if not rooms:
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
# We don't share any rooms with this user. Ignore update, as we
|
||||
# probably won't get any further updates.
|
||||
return
|
||||
|
|
|
@ -175,6 +175,7 @@ class DirectoryHandler(BaseHandler):
|
|||
"room_alias": room_alias.to_string(),
|
||||
},
|
||||
retry_on_dns_fail=False,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
logging.warn("Error retrieving alias")
|
||||
|
|
|
@ -22,7 +22,7 @@ from twisted.internet import defer
|
|||
from synapse.api.errors import SynapseError, CodeMessageException
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -121,15 +121,11 @@ class E2eKeysHandler(object):
|
|||
def do_remote_query(destination):
|
||||
destination_query = remote_queries_not_in_cache[destination]
|
||||
try:
|
||||
limiter = yield get_retry_limiter(
|
||||
destination, self.clock, self.store
|
||||
remote_result = yield self.federation.query_client_keys(
|
||||
destination,
|
||||
{"device_keys": destination_query},
|
||||
timeout=timeout
|
||||
)
|
||||
with limiter:
|
||||
remote_result = yield self.federation.query_client_keys(
|
||||
destination,
|
||||
{"device_keys": destination_query},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
for user_id, keys in remote_result["device_keys"].items():
|
||||
if user_id in destination_query:
|
||||
|
@ -239,18 +235,14 @@ class E2eKeysHandler(object):
|
|||
def claim_client_keys(destination):
|
||||
device_keys = remote_queries[destination]
|
||||
try:
|
||||
limiter = yield get_retry_limiter(
|
||||
destination, self.clock, self.store
|
||||
remote_result = yield self.federation.claim_client_keys(
|
||||
destination,
|
||||
{"one_time_keys": device_keys},
|
||||
timeout=timeout
|
||||
)
|
||||
with limiter:
|
||||
remote_result = yield self.federation.claim_client_keys(
|
||||
destination,
|
||||
{"one_time_keys": device_keys},
|
||||
timeout=timeout
|
||||
)
|
||||
for user_id, keys in remote_result["one_time_keys"].items():
|
||||
if user_id in device_keys:
|
||||
json_result[user_id] = keys
|
||||
for user_id, keys in remote_result["one_time_keys"].items():
|
||||
if user_id in device_keys:
|
||||
json_result[user_id] = keys
|
||||
except CodeMessageException as e:
|
||||
failures[destination] = {
|
||||
"status": e.code, "message": e.message
|
||||
|
@ -316,7 +308,7 @@ class E2eKeysHandler(object):
|
|||
# old access_token without an associated device_id. Either way, we
|
||||
# need to double-check the device is registered to avoid ending up with
|
||||
# keys without a corresponding device.
|
||||
self.device_handler.check_device_registered(user_id, device_id)
|
||||
yield self.device_handler.check_device_registered(user_id, device_id)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Contains handlers for federation events."""
|
||||
import synapse.util.logcontext
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
@ -31,7 +32,7 @@ from synapse.util.logcontext import (
|
|||
)
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.async import run_on_reactor, Linearizer
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
from synapse.crypto.event_signing import (
|
||||
compute_event_signature, add_hashes_and_signatures,
|
||||
|
@ -79,29 +80,216 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# When joining a room we need to queue any events for that room up
|
||||
self.room_queues = {}
|
||||
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
|
||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
||||
do auth checks and put it through the StateHandler.
|
||||
@log_function
|
||||
def on_receive_pdu(self, origin, pdu, get_missing=True):
|
||||
""" Process a PDU received via a federation /send/ transaction, or
|
||||
via backfill of missing prev_events
|
||||
|
||||
auth_chain and state are None if we already have the necessary state
|
||||
and prev_events in the db
|
||||
Args:
|
||||
origin (str): server which initiated the /send/ transaction. Will
|
||||
be used to fetch missing events or state.
|
||||
pdu (FrozenEvent): received PDU
|
||||
get_missing (bool): True if we should fetch missing prev_events
|
||||
|
||||
Returns (Deferred): completes with None
|
||||
"""
|
||||
event = pdu
|
||||
|
||||
logger.debug("Got event: %s", event.event_id)
|
||||
# We reprocess pdus when we have seen them only as outliers
|
||||
existing = yield self.get_persisted_pdu(
|
||||
origin, pdu.event_id, do_auth=False
|
||||
)
|
||||
|
||||
# FIXME: Currently we fetch an event again when we already have it
|
||||
# if it has been marked as an outlier.
|
||||
|
||||
already_seen = (
|
||||
existing and (
|
||||
not existing.internal_metadata.is_outlier()
|
||||
or pdu.internal_metadata.is_outlier()
|
||||
)
|
||||
)
|
||||
if already_seen:
|
||||
logger.debug("Already seen pdu %s", pdu.event_id)
|
||||
return
|
||||
|
||||
# If we are currently in the process of joining this room, then we
|
||||
# queue up events for later processing.
|
||||
if event.room_id in self.room_queues:
|
||||
self.room_queues[event.room_id].append((pdu, origin))
|
||||
if pdu.room_id in self.room_queues:
|
||||
logger.info("Ignoring PDU %s for room %s from %s for now; join "
|
||||
"in progress", pdu.event_id, pdu.room_id, origin)
|
||||
self.room_queues[pdu.room_id].append((pdu, origin))
|
||||
return
|
||||
|
||||
logger.debug("Processing event: %s", event.event_id)
|
||||
state = None
|
||||
|
||||
logger.debug("Event: %s", event)
|
||||
auth_chain = []
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
|
||||
fetch_state = False
|
||||
|
||||
# Get missing pdus if necessary.
|
||||
if not pdu.internal_metadata.is_outlier():
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = yield self.get_min_depth_for_context(
|
||||
pdu.room_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_pdu min_depth for %s: %d",
|
||||
pdu.room_id, min_depth
|
||||
)
|
||||
|
||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||
seen = set(have_seen.keys())
|
||||
|
||||
if min_depth and pdu.depth < min_depth:
|
||||
# This is so that we don't notify the user about this
|
||||
# message, to work around the fact that some events will
|
||||
# reference really really old events we really don't want to
|
||||
# send to the clients.
|
||||
pdu.internal_metadata.outlier = True
|
||||
elif min_depth and pdu.depth > min_depth:
|
||||
if get_missing and prevs - seen:
|
||||
# If we're missing stuff, ensure we only fetch stuff one
|
||||
# at a time.
|
||||
logger.info(
|
||||
"Acquiring lock for room %r to fetch %d missing events: %r...",
|
||||
pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
|
||||
)
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
logger.info(
|
||||
"Acquired lock for room %r to fetch %d missing events",
|
||||
pdu.room_id, len(prevs - seen),
|
||||
)
|
||||
|
||||
yield self._get_missing_events_for_pdu(
|
||||
origin, pdu, prevs, min_depth
|
||||
)
|
||||
|
||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||
seen = set(have_seen.keys())
|
||||
if prevs - seen:
|
||||
logger.info(
|
||||
"Still missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
fetch_state = True
|
||||
|
||||
if fetch_state:
|
||||
# We need to get the state at this event, since we haven't
|
||||
# processed all the prev events.
|
||||
logger.debug(
|
||||
"_handle_new_pdu getting state for %s",
|
||||
pdu.room_id
|
||||
)
|
||||
try:
|
||||
state, auth_chain = yield self.replication_layer.get_state_for_room(
|
||||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to get state for event: %s", pdu.event_id)
|
||||
|
||||
yield self._process_received_pdu(
|
||||
origin,
|
||||
pdu,
|
||||
state=state,
|
||||
auth_chain=auth_chain,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
"""
|
||||
Args:
|
||||
origin (str): Origin of the pdu. Will be called to get the missing events
|
||||
pdu: received pdu
|
||||
prevs (str[]): List of event ids which we are missing
|
||||
min_depth (int): Minimum depth of events to return.
|
||||
|
||||
Returns:
|
||||
Deferred<dict(str, str?)>: updated have_seen dictionary
|
||||
"""
|
||||
# We recalculate seen, since it may have changed.
|
||||
have_seen = yield self.store.have_events(prevs)
|
||||
seen = set(have_seen.keys())
|
||||
|
||||
if not prevs - seen:
|
||||
# nothing left to do
|
||||
defer.returnValue(have_seen)
|
||||
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest |= seen
|
||||
|
||||
logger.info(
|
||||
"Missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
|
||||
# XXX: we set timeout to 10s to help workaround
|
||||
# https://github.com/matrix-org/synapse/issues/1733.
|
||||
# The reason is to avoid holding the linearizer lock
|
||||
# whilst processing inbound /send transactions, causing
|
||||
# FDs to stack up and block other inbound transactions
|
||||
# which empirically can currently take up to 30 minutes.
|
||||
#
|
||||
# N.B. this explicitly disables retry attempts.
|
||||
#
|
||||
# N.B. this also increases our chances of falling back to
|
||||
# fetching fresh state for the room if the missing event
|
||||
# can't be found, which slightly reduces our security.
|
||||
# it may also increase our DAG extremity count for the room,
|
||||
# causing additional state resolution? See #1760.
|
||||
# However, fetching state doesn't hold the linearizer lock
|
||||
# apparently.
|
||||
#
|
||||
# see https://github.com/matrix-org/synapse/pull/1744
|
||||
|
||||
missing_events = yield self.replication_layer.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
latest_events=[pdu],
|
||||
limit=10,
|
||||
min_depth=min_depth,
|
||||
timeout=10000,
|
||||
)
|
||||
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
missing_events.sort(key=lambda x: x.depth)
|
||||
|
||||
for e in missing_events:
|
||||
yield self.on_receive_pdu(
|
||||
origin,
|
||||
e,
|
||||
get_missing=False
|
||||
)
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
defer.returnValue(have_seen)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def _process_received_pdu(self, origin, pdu, state, auth_chain):
|
||||
""" Called when we have a new pdu. We need to do auth checks and put it
|
||||
through the StateHandler.
|
||||
"""
|
||||
event = pdu
|
||||
|
||||
logger.debug("Processing event: %s", event)
|
||||
|
||||
# FIXME (erikj): Awful hack to make the case where we are not currently
|
||||
# in the room work
|
||||
|
@ -670,8 +858,6 @@ class FederationHandler(BaseHandler):
|
|||
"""
|
||||
logger.debug("Joining %s to %s", joinee, room_id)
|
||||
|
||||
yield self.store.clean_room_for_join(room_id)
|
||||
|
||||
origin, event = yield self._make_and_verify_event(
|
||||
target_hosts,
|
||||
room_id,
|
||||
|
@ -680,7 +866,15 @@ class FederationHandler(BaseHandler):
|
|||
content,
|
||||
)
|
||||
|
||||
# This shouldn't happen, because the RoomMemberHandler has a
|
||||
# linearizer lock which only allows one operation per user per room
|
||||
# at a time - so this is just paranoia.
|
||||
assert (room_id not in self.room_queues)
|
||||
|
||||
self.room_queues[room_id] = []
|
||||
|
||||
yield self.store.clean_room_for_join(room_id)
|
||||
|
||||
handled_events = set()
|
||||
|
||||
try:
|
||||
|
@ -733,17 +927,36 @@ class FederationHandler(BaseHandler):
|
|||
room_queue = self.room_queues[room_id]
|
||||
del self.room_queues[room_id]
|
||||
|
||||
for p, origin in room_queue:
|
||||
if p.event_id in handled_events:
|
||||
continue
|
||||
# we don't need to wait for the queued events to be processed -
|
||||
# it's just a best-effort thing at this point. We do want to do
|
||||
# them roughly in order, though, otherwise we'll end up making
|
||||
# lots of requests for missing prev_events which we do actually
|
||||
# have. Hence we fire off the deferred, but don't wait for it.
|
||||
|
||||
try:
|
||||
self.on_receive_pdu(origin, p)
|
||||
except:
|
||||
logger.exception("Couldn't handle pdu")
|
||||
synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
|
||||
room_queue
|
||||
)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_queued_pdus(self, room_queue):
|
||||
"""Process PDUs which got queued up while we were busy send_joining.
|
||||
|
||||
Args:
|
||||
room_queue (list[FrozenEvent, str]): list of PDUs to be processed
|
||||
and the servers that sent them
|
||||
"""
|
||||
for p, origin in room_queue:
|
||||
try:
|
||||
logger.info("Processing queued PDU %s which was received "
|
||||
"while we were joining %s", p.event_id, p.room_id)
|
||||
yield self.on_receive_pdu(origin, p)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"Error handling queued PDU %s from %s: %s",
|
||||
p.event_id, origin, e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_make_join_request(self, room_id, user_id):
|
||||
|
@ -791,9 +1004,19 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
event.internal_metadata.outlier = False
|
||||
# Send this event on behalf of the origin server since they may not
|
||||
# have an up to data view of the state of the room at this event so
|
||||
# will not know which servers to send the event to.
|
||||
# Send this event on behalf of the origin server.
|
||||
#
|
||||
# The reasons we have the destination server rather than the origin
|
||||
# server send it are slightly mysterious: the origin server should have
|
||||
# all the neccessary state once it gets the response to the send_join,
|
||||
# so it could send the event itself if it wanted to. It may be that
|
||||
# doing it this way reduces failure modes, or avoids certain attacks
|
||||
# where a new server selectively tells a subset of the federation that
|
||||
# it has joined.
|
||||
#
|
||||
# The fact is that, as of the current writing, Synapse doesn't send out
|
||||
# the join event over federation after joining, and changing it now
|
||||
# would introduce the danger of backwards-compatibility problems.
|
||||
event.internal_metadata.send_on_behalf_of = origin
|
||||
|
||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||
|
@ -878,15 +1101,15 @@ class FederationHandler(BaseHandler):
|
|||
user_id,
|
||||
"leave"
|
||||
)
|
||||
signed_event = self._sign_event(event)
|
||||
event = self._sign_event(event)
|
||||
except SynapseError:
|
||||
raise
|
||||
except CodeMessageException as e:
|
||||
logger.warn("Failed to reject invite: %s", e)
|
||||
raise SynapseError(500, "Failed to reject invite")
|
||||
|
||||
# Try the host we successfully got a response to /make_join/
|
||||
# request first.
|
||||
# Try the host that we succesfully called /make_leave/ on first for
|
||||
# the /send_leave/ request.
|
||||
try:
|
||||
target_hosts.remove(origin)
|
||||
target_hosts.insert(0, origin)
|
||||
|
@ -896,7 +1119,7 @@ class FederationHandler(BaseHandler):
|
|||
try:
|
||||
yield self.replication_layer.send_leave(
|
||||
target_hosts,
|
||||
signed_event
|
||||
event
|
||||
)
|
||||
except SynapseError:
|
||||
raise
|
||||
|
@ -1325,7 +1548,17 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _prep_event(self, origin, event, state=None, auth_events=None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
origin:
|
||||
event:
|
||||
state:
|
||||
auth_events:
|
||||
|
||||
Returns:
|
||||
Deferred, which resolves to synapse.events.snapshot.EventContext
|
||||
"""
|
||||
context = yield self.state_handler.compute_event_context(
|
||||
event, old_state=state,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -150,7 +151,7 @@ class IdentityHandler(BaseHandler):
|
|||
params.update(kwargs)
|
||||
|
||||
try:
|
||||
data = yield self.http_client.post_urlencoded_get_json(
|
||||
data = yield self.http_client.post_json_get_json(
|
||||
"https://%s%s" % (
|
||||
id_server,
|
||||
"/_matrix/identity/api/v1/validate/email/requestToken"
|
||||
|
@ -161,3 +162,37 @@ class IdentityHandler(BaseHandler):
|
|||
except CodeMessageException as e:
|
||||
logger.info("Proxied requestToken failed: %r", e)
|
||||
raise e
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def requestMsisdnToken(
|
||||
self, id_server, country, phone_number,
|
||||
client_secret, send_attempt, **kwargs
|
||||
):
|
||||
yield run_on_reactor()
|
||||
|
||||
if not self._should_trust_id_server(id_server):
|
||||
raise SynapseError(
|
||||
400, "Untrusted ID server '%s'" % id_server,
|
||||
Codes.SERVER_NOT_TRUSTED
|
||||
)
|
||||
|
||||
params = {
|
||||
'country': country,
|
||||
'phone_number': phone_number,
|
||||
'client_secret': client_secret,
|
||||
'send_attempt': send_attempt,
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
try:
|
||||
data = yield self.http_client.post_json_get_json(
|
||||
"https://%s%s" % (
|
||||
id_server,
|
||||
"/_matrix/identity/api/v1/validate/msisdn/requestToken"
|
||||
),
|
||||
params
|
||||
)
|
||||
defer.returnValue(data)
|
||||
except CodeMessageException as e:
|
||||
logger.info("Proxied requestToken failed: %r", e)
|
||||
raise e
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership
|
|||
from synapse.api.errors import AuthError, Codes
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import (
|
||||
UserID, StreamToken,
|
||||
|
@ -225,9 +226,17 @@ class InitialSyncHandler(BaseHandler):
|
|||
"content": content,
|
||||
})
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
ret = {
|
||||
"rooms": rooms_ret,
|
||||
"presence": presence,
|
||||
"presence": [
|
||||
{
|
||||
"type": "m.presence",
|
||||
"content": format_user_presence_state(event, now),
|
||||
}
|
||||
for event in presence
|
||||
],
|
||||
"account_data": account_data_events,
|
||||
"receipts": receipt,
|
||||
"end": now_token.to_string(),
|
||||
|
|
|
@ -29,6 +29,7 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.api.constants import PresenceState
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
|
@ -556,9 +557,9 @@ class PresenceHandler(object):
|
|||
room_ids_to_states = {}
|
||||
users_to_states = {}
|
||||
for state in states:
|
||||
events = yield self.store.get_rooms_for_user(state.user_id)
|
||||
for e in events:
|
||||
room_ids_to_states.setdefault(e.room_id, []).append(state)
|
||||
room_ids = yield self.store.get_rooms_for_user(state.user_id)
|
||||
for room_id in room_ids:
|
||||
room_ids_to_states.setdefault(room_id, []).append(state)
|
||||
|
||||
plist = yield self.store.get_presence_list_observers_accepted(state.user_id)
|
||||
for u in plist:
|
||||
|
@ -574,8 +575,7 @@ class PresenceHandler(object):
|
|||
if not local_states:
|
||||
continue
|
||||
|
||||
users = yield self.store.get_users_in_room(room_id)
|
||||
hosts = set(get_domain_from_id(u) for u in users)
|
||||
hosts = yield self.store.get_hosts_in_room(room_id)
|
||||
|
||||
for host in hosts:
|
||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||
|
@ -719,9 +719,7 @@ class PresenceHandler(object):
|
|||
for state in updates
|
||||
])
|
||||
else:
|
||||
defer.returnValue([
|
||||
format_user_presence_state(state, now) for state in updates
|
||||
])
|
||||
defer.returnValue(updates)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_state(self, target_user, state, ignore_status_msg=False):
|
||||
|
@ -795,6 +793,9 @@ class PresenceHandler(object):
|
|||
as_event=False,
|
||||
)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
results[:] = [format_user_presence_state(r, now) for r in results]
|
||||
|
||||
is_accepted = {
|
||||
row["observed_user_id"]: row["accepted"] for row in presence_list
|
||||
}
|
||||
|
@ -847,6 +848,7 @@ class PresenceHandler(object):
|
|||
)
|
||||
|
||||
state_dict = yield self.get_state(observed_user, as_event=False)
|
||||
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
|
||||
|
||||
self.federation.send_edu(
|
||||
destination=observer_user.domain,
|
||||
|
@ -910,11 +912,12 @@ class PresenceHandler(object):
|
|||
def is_visible(self, observed_user, observer_user):
|
||||
"""Returns whether a user can see another user's presence.
|
||||
"""
|
||||
observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string())
|
||||
observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string())
|
||||
|
||||
observer_room_ids = set(r.room_id for r in observer_rooms)
|
||||
observed_room_ids = set(r.room_id for r in observed_rooms)
|
||||
observer_room_ids = yield self.store.get_rooms_for_user(
|
||||
observer_user.to_string()
|
||||
)
|
||||
observed_room_ids = yield self.store.get_rooms_for_user(
|
||||
observed_user.to_string()
|
||||
)
|
||||
|
||||
if observer_room_ids & observed_room_ids:
|
||||
defer.returnValue(True)
|
||||
|
@ -979,14 +982,18 @@ def should_notify(old_state, new_state):
|
|||
return False
|
||||
|
||||
|
||||
def format_user_presence_state(state, now):
|
||||
def format_user_presence_state(state, now, include_user_id=True):
|
||||
"""Convert UserPresenceState to a format that can be sent down to clients
|
||||
and to other servers.
|
||||
|
||||
The "user_id" is optional so that this function can be used to format presence
|
||||
updates for client /sync responses and for federation /send requests.
|
||||
"""
|
||||
content = {
|
||||
"presence": state.state,
|
||||
"user_id": state.user_id,
|
||||
}
|
||||
if include_user_id:
|
||||
content["user_id"] = state.user_id
|
||||
if state.last_active_ts:
|
||||
content["last_active_ago"] = now - state.last_active_ts
|
||||
if state.status_msg and state.state != PresenceState.OFFLINE:
|
||||
|
@ -1025,7 +1032,6 @@ class PresenceEventSource(object):
|
|||
# sending down the rare duplicate is not a concern.
|
||||
|
||||
with Measure(self.clock, "presence.get_new_events"):
|
||||
user_id = user.to_string()
|
||||
if from_key is not None:
|
||||
from_key = int(from_key)
|
||||
|
||||
|
@ -1034,18 +1040,7 @@ class PresenceEventSource(object):
|
|||
|
||||
max_token = self.store.get_current_presence_token()
|
||||
|
||||
plist = yield self.store.get_presence_list_accepted(user.localpart)
|
||||
users_interested_in = set(row["observed_user_id"] for row in plist)
|
||||
users_interested_in.add(user_id) # So that we receive our own presence
|
||||
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
users_interested_in.update(users_who_share_room)
|
||||
|
||||
if explicit_room_id:
|
||||
user_ids = yield self.store.get_users_in_room(explicit_room_id)
|
||||
users_interested_in.update(user_ids)
|
||||
users_interested_in = yield self._get_interested_in(user, explicit_room_id)
|
||||
|
||||
user_ids_changed = set()
|
||||
changed = None
|
||||
|
@ -1073,16 +1068,13 @@ class PresenceEventSource(object):
|
|||
|
||||
updates = yield presence.current_state_for_users(user_ids_changed)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
defer.returnValue(([
|
||||
{
|
||||
"type": "m.presence",
|
||||
"content": format_user_presence_state(s, now),
|
||||
}
|
||||
for s in updates.values()
|
||||
if include_offline or s.state != PresenceState.OFFLINE
|
||||
], max_token))
|
||||
if include_offline:
|
||||
defer.returnValue((updates.values(), max_token))
|
||||
else:
|
||||
defer.returnValue(([
|
||||
s for s in updates.itervalues()
|
||||
if s.state != PresenceState.OFFLINE
|
||||
], max_token))
|
||||
|
||||
def get_current_key(self):
|
||||
return self.store.get_current_presence_token()
|
||||
|
@ -1090,6 +1082,31 @@ class PresenceEventSource(object):
|
|||
def get_pagination_rows(self, user, pagination_config, key):
|
||||
return self.get_new_events(user, from_key=None, include_offline=False)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||
def _get_interested_in(self, user, explicit_room_id, cache_context):
|
||||
"""Returns the set of users that the given user should see presence
|
||||
updates for
|
||||
"""
|
||||
user_id = user.to_string()
|
||||
plist = yield self.store.get_presence_list_accepted(
|
||||
user.localpart, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
users_interested_in = set(row["observed_user_id"] for row in plist)
|
||||
users_interested_in.add(user_id) # So that we receive our own presence
|
||||
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
user_id, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
users_interested_in.update(users_who_share_room)
|
||||
|
||||
if explicit_room_id:
|
||||
user_ids = yield self.store.get_users_in_room(
|
||||
explicit_room_id, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
users_interested_in.update(user_ids)
|
||||
|
||||
defer.returnValue(users_interested_in)
|
||||
|
||||
|
||||
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
|
||||
"""Checks the presence of users that have timed out and updates as
|
||||
|
@ -1157,7 +1174,10 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
|
|||
# If there are have been no sync for a while (and none ongoing),
|
||||
# set presence to offline
|
||||
if user_id not in syncing_user_ids:
|
||||
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
|
||||
# If the user has done something recently but hasn't synced,
|
||||
# don't set them as offline.
|
||||
sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
|
||||
if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
|
||||
state = state.copy_and_replace(
|
||||
state=PresenceState.OFFLINE,
|
||||
status_msg=None,
|
||||
|
|
|
@ -52,7 +52,8 @@ class ProfileHandler(BaseHandler):
|
|||
args={
|
||||
"user_id": target_user.to_string(),
|
||||
"field": "displayname",
|
||||
}
|
||||
},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
if e.code != 404:
|
||||
|
@ -99,7 +100,8 @@ class ProfileHandler(BaseHandler):
|
|||
args={
|
||||
"user_id": target_user.to_string(),
|
||||
"field": "avatar_url",
|
||||
}
|
||||
},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
if e.code != 404:
|
||||
|
@ -156,11 +158,11 @@ class ProfileHandler(BaseHandler):
|
|||
|
||||
self.ratelimit(requester)
|
||||
|
||||
joins = yield self.store.get_rooms_for_user(
|
||||
room_ids = yield self.store.get_rooms_for_user(
|
||||
user.to_string(),
|
||||
)
|
||||
|
||||
for j in joins:
|
||||
for room_id in room_ids:
|
||||
handler = self.hs.get_handlers().room_member_handler
|
||||
try:
|
||||
# Assume the user isn't a guest because we don't let guests set
|
||||
|
@ -171,12 +173,12 @@ class ProfileHandler(BaseHandler):
|
|||
yield handler.update_membership(
|
||||
requester,
|
||||
user,
|
||||
j.room_id,
|
||||
room_id,
|
||||
"join", # We treat a profile update like a join.
|
||||
ratelimit=False, # Try to hide that these events aren't atomic.
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"Failed to update join event for room %s - %s",
|
||||
j.room_id, str(e.message)
|
||||
room_id, str(e.message)
|
||||
)
|
||||
|
|
|
@ -210,10 +210,9 @@ class ReceiptEventSource(object):
|
|||
else:
|
||||
from_key = None
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
rooms = [room.room_id for room in rooms]
|
||||
room_ids = yield self.store.get_rooms_for_user(user.to_string())
|
||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||
rooms,
|
||||
room_ids,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
|
|
@ -21,6 +21,7 @@ from synapse.api.constants import (
|
|||
EventTypes, JoinRules,
|
||||
)
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
|
||||
|
@ -62,6 +63,10 @@ class RoomListHandler(BaseHandler):
|
|||
appservice and network id to use an appservice specific one.
|
||||
Setting to None returns all public rooms across all lists.
|
||||
"""
|
||||
logger.info(
|
||||
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
|
||||
limit, since_token, bool(search_filter), network_tuple,
|
||||
)
|
||||
if search_filter:
|
||||
# We explicitly don't bother caching searches or requests for
|
||||
# appservice specific lists.
|
||||
|
@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
rooms_to_order_value = {}
|
||||
rooms_to_num_joined = {}
|
||||
rooms_to_latest_event_ids = {}
|
||||
|
||||
newly_visible = []
|
||||
newly_unpublished = []
|
||||
|
@ -116,19 +120,26 @@ class RoomListHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_order_for_room(room_id):
|
||||
latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
|
||||
if not latest_event_ids:
|
||||
# Most of the rooms won't have changed between the since token and
|
||||
# now (especially if the since token is "now"). So, we can ask what
|
||||
# the current users are in a room (that will hit a cache) and then
|
||||
# check if the room has changed since the since token. (We have to
|
||||
# do it in that order to avoid races).
|
||||
# If things have changed then fall back to getting the current state
|
||||
# at the since token.
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
if self.store.has_room_changed_since(room_id, stream_token):
|
||||
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
|
||||
room_id, stream_token
|
||||
)
|
||||
rooms_to_latest_event_ids[room_id] = latest_event_ids
|
||||
|
||||
if not latest_event_ids:
|
||||
return
|
||||
if not latest_event_ids:
|
||||
return
|
||||
|
||||
joined_users = yield self.state_handler.get_current_user_in_room(
|
||||
room_id, latest_event_ids,
|
||||
)
|
||||
|
||||
joined_users = yield self.state_handler.get_current_user_in_room(
|
||||
room_id, latest_event_ids,
|
||||
)
|
||||
num_joined_users = len(joined_users)
|
||||
rooms_to_num_joined[room_id] = num_joined_users
|
||||
|
||||
|
@ -165,19 +176,19 @@ class RoomListHandler(BaseHandler):
|
|||
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
|
||||
rooms_to_scan.reverse()
|
||||
|
||||
# Actually generate the entries. _generate_room_entry will append to
|
||||
# Actually generate the entries. _append_room_entry_to_chunk will append to
|
||||
# chunk but will stop if len(chunk) > limit
|
||||
chunk = []
|
||||
if limit and not search_filter:
|
||||
step = limit + 1
|
||||
for i in xrange(0, len(rooms_to_scan), step):
|
||||
# We iterate here because the vast majority of cases we'll stop
|
||||
# at first iteration, but occaisonally _generate_room_entry
|
||||
# at first iteration, but occaisonally _append_room_entry_to_chunk
|
||||
# won't append to the chunk and so we need to loop again.
|
||||
# We don't want to scan over the entire range either as that
|
||||
# would potentially waste a lot of work.
|
||||
yield concurrently_execute(
|
||||
lambda r: self._generate_room_entry(
|
||||
lambda r: self._append_room_entry_to_chunk(
|
||||
r, rooms_to_num_joined[r],
|
||||
chunk, limit, search_filter
|
||||
),
|
||||
|
@ -187,7 +198,7 @@ class RoomListHandler(BaseHandler):
|
|||
break
|
||||
else:
|
||||
yield concurrently_execute(
|
||||
lambda r: self._generate_room_entry(
|
||||
lambda r: self._append_room_entry_to_chunk(
|
||||
r, rooms_to_num_joined[r],
|
||||
chunk, limit, search_filter
|
||||
),
|
||||
|
@ -256,21 +267,35 @@ class RoomListHandler(BaseHandler):
|
|||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
|
||||
search_filter):
|
||||
def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
|
||||
search_filter):
|
||||
"""Generate the entry for a room in the public room list and append it
|
||||
to the `chunk` if it matches the search filter
|
||||
"""
|
||||
if limit and len(chunk) > limit + 1:
|
||||
# We've already got enough, so lets just drop it.
|
||||
return
|
||||
|
||||
result = yield self._generate_room_entry(room_id, num_joined_users)
|
||||
|
||||
if result and _matches_room_entry(result, search_filter):
|
||||
chunk.append(result)
|
||||
|
||||
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||
def _generate_room_entry(self, room_id, num_joined_users, cache_context):
|
||||
"""Returns the entry for a room
|
||||
"""
|
||||
result = {
|
||||
"room_id": room_id,
|
||||
"num_joined_members": num_joined_users,
|
||||
}
|
||||
|
||||
current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
|
||||
current_state_ids = yield self.store.get_current_state_ids(
|
||||
room_id, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
event_map = yield self.store.get_events([
|
||||
event_id for key, event_id in current_state_ids.items()
|
||||
event_id for key, event_id in current_state_ids.iteritems()
|
||||
if key[0] in (
|
||||
EventTypes.JoinRules,
|
||||
EventTypes.Name,
|
||||
|
@ -294,7 +319,9 @@ class RoomListHandler(BaseHandler):
|
|||
if join_rule and join_rule != JoinRules.PUBLIC:
|
||||
defer.returnValue(None)
|
||||
|
||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
aliases = yield self.store.get_aliases_for_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
|
||||
|
@ -334,8 +361,7 @@ class RoomListHandler(BaseHandler):
|
|||
if avatar_url:
|
||||
result["avatar_url"] = avatar_url
|
||||
|
||||
if _matches_room_entry(result, search_filter):
|
||||
chunk.append(result)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
|
||||
|
|
|
@ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func
|
|||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.visibility import filter_events_for_client
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -225,8 +226,7 @@ class SyncHandler(object):
|
|||
with Measure(self.clock, "ephemeral_by_room"):
|
||||
typing_key = since_token.typing_key if since_token else "0"
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||
room_ids = [room.room_id for room in rooms]
|
||||
room_ids = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||
|
||||
typing_source = self.event_sources.sources["typing"]
|
||||
typing, typing_key = yield typing_source.get_new_events(
|
||||
|
@ -568,16 +568,15 @@ class SyncHandler(object):
|
|||
since_token = sync_result_builder.since_token
|
||||
|
||||
if since_token and since_token.device_list_key:
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = set(r.room_id for r in rooms)
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
user_ids_changed = set()
|
||||
changed = yield self.store.get_user_whose_devices_changed(
|
||||
since_token.device_list_key
|
||||
)
|
||||
for other_user_id in changed:
|
||||
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
|
||||
if room_ids.intersection(e.room_id for e in other_rooms):
|
||||
other_room_ids = yield self.store.get_rooms_for_user(other_user_id)
|
||||
if room_ids.intersection(other_room_ids):
|
||||
user_ids_changed.add(other_user_id)
|
||||
|
||||
defer.returnValue(user_ids_changed)
|
||||
|
@ -721,14 +720,14 @@ class SyncHandler(object):
|
|||
extra_users_ids.update(users)
|
||||
extra_users_ids.discard(user.to_string())
|
||||
|
||||
states = yield self.presence_handler.get_states(
|
||||
extra_users_ids,
|
||||
as_event=True,
|
||||
)
|
||||
presence.extend(states)
|
||||
if extra_users_ids:
|
||||
states = yield self.presence_handler.get_states(
|
||||
extra_users_ids,
|
||||
)
|
||||
presence.extend(states)
|
||||
|
||||
# Deduplicate the presence entries so that there's at most one per user
|
||||
presence = {p["content"]["user_id"]: p for p in presence}.values()
|
||||
# Deduplicate the presence entries so that there's at most one per user
|
||||
presence = {p.user_id: p for p in presence}.values()
|
||||
|
||||
presence = sync_config.filter_collection.filter_presence(
|
||||
presence
|
||||
|
@ -765,6 +764,21 @@ class SyncHandler(object):
|
|||
)
|
||||
sync_result_builder.now_token = now_token
|
||||
|
||||
# We check up front if anything has changed, if it hasn't then there is
|
||||
# no point in going futher.
|
||||
since_token = sync_result_builder.since_token
|
||||
if not sync_result_builder.full_state:
|
||||
if since_token and not ephemeral_by_room and not account_data_by_room:
|
||||
have_changed = yield self._have_rooms_changed(sync_result_builder)
|
||||
if not have_changed:
|
||||
tags_by_room = yield self.store.get_updated_tags(
|
||||
user_id,
|
||||
since_token.account_data_key,
|
||||
)
|
||||
if not tags_by_room:
|
||||
logger.debug("no-oping sync")
|
||||
defer.returnValue(([], []))
|
||||
|
||||
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
|
||||
"m.ignored_user_list", user_id=user_id,
|
||||
)
|
||||
|
@ -774,13 +788,12 @@ class SyncHandler(object):
|
|||
else:
|
||||
ignored_users = frozenset()
|
||||
|
||||
if sync_result_builder.since_token:
|
||||
if since_token:
|
||||
res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
|
||||
room_entries, invited, newly_joined_rooms = res
|
||||
|
||||
tags_by_room = yield self.store.get_updated_tags(
|
||||
user_id,
|
||||
sync_result_builder.since_token.account_data_key,
|
||||
user_id, since_token.account_data_key,
|
||||
)
|
||||
else:
|
||||
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
|
||||
|
@ -805,7 +818,7 @@ class SyncHandler(object):
|
|||
|
||||
# Now we want to get any newly joined users
|
||||
newly_joined_users = set()
|
||||
if sync_result_builder.since_token:
|
||||
if since_token:
|
||||
for joined_sync in sync_result_builder.joined:
|
||||
it = itertools.chain(
|
||||
joined_sync.timeline.events, joined_sync.state.values()
|
||||
|
@ -817,6 +830,38 @@ class SyncHandler(object):
|
|||
|
||||
defer.returnValue((newly_joined_rooms, newly_joined_users))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _have_rooms_changed(self, sync_result_builder):
|
||||
"""Returns whether there may be any new events that should be sent down
|
||||
the sync. Returns True if there are.
|
||||
"""
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
since_token = sync_result_builder.since_token
|
||||
now_token = sync_result_builder.now_token
|
||||
|
||||
assert since_token
|
||||
|
||||
# Get a list of membership change events that have happened.
|
||||
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||
user_id, since_token.room_key, now_token.room_key
|
||||
)
|
||||
|
||||
if rooms_changed:
|
||||
defer.returnValue(True)
|
||||
|
||||
app_service = self.store.get_app_service_by_user_id(user_id)
|
||||
if app_service:
|
||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||
joined_room_ids = set(r.room_id for r in rooms)
|
||||
else:
|
||||
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
|
||||
for room_id in joined_room_ids:
|
||||
if self.store.has_room_changed_since(room_id, stream_id):
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_rooms_changed(self, sync_result_builder, ignored_users):
|
||||
"""Gets the the changes that have happened since the last sync.
|
||||
|
@ -841,8 +886,7 @@ class SyncHandler(object):
|
|||
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||
joined_room_ids = set(r.room_id for r in rooms)
|
||||
else:
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
joined_room_ids = set(r.room_id for r in rooms)
|
||||
joined_room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
# Get a list of membership change events that have happened.
|
||||
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||
|
|
|
@ -12,8 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import synapse.util.retryutils
|
||||
from twisted.internet import defer, reactor, protocol
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.web.client import readBody, HTTPConnectionPool, Agent
|
||||
|
@ -22,7 +21,7 @@ from twisted.web._newclient import ResponseDone
|
|||
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util import logcontext
|
||||
import synapse.metrics
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
@ -94,6 +93,7 @@ class MatrixFederationHttpClient(object):
|
|||
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
||||
)
|
||||
self.clock = hs.get_clock()
|
||||
self._store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
self._next_id = 1
|
||||
|
||||
|
@ -103,123 +103,152 @@ class MatrixFederationHttpClient(object):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_request(self, destination, method, path_bytes,
|
||||
body_callback, headers_dict={}, param_bytes=b"",
|
||||
query_bytes=b"", retry_on_dns_fail=True,
|
||||
timeout=None, long_retries=False):
|
||||
""" Creates and sends a request to the given url
|
||||
def _request(self, destination, method, path,
|
||||
body_callback, headers_dict={}, param_bytes=b"",
|
||||
query_bytes=b"", retry_on_dns_fail=True,
|
||||
timeout=None, long_retries=False,
|
||||
ignore_backoff=False,
|
||||
backoff_on_404=False):
|
||||
""" Creates and sends a request to the given server
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request to.
|
||||
method (str): HTTP method
|
||||
path (str): The HTTP path
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
backoff_on_404 (bool): Back off if we get a 404
|
||||
|
||||
Returns:
|
||||
Deferred: resolves with the http response object on success.
|
||||
|
||||
Fails with ``HTTPRequestException``: if we get an HTTP response
|
||||
code >= 300.
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
"""
|
||||
headers_dict[b"User-Agent"] = [self.version_string]
|
||||
headers_dict[b"Host"] = [destination]
|
||||
|
||||
url_bytes = self._create_url(
|
||||
destination, path_bytes, param_bytes, query_bytes
|
||||
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
||||
destination,
|
||||
self.clock,
|
||||
self._store,
|
||||
backoff_on_404=backoff_on_404,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
txn_id = "%s-O-%s" % (method, self._next_id)
|
||||
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
|
||||
destination = destination.encode("ascii")
|
||||
path_bytes = path.encode("ascii")
|
||||
with limiter:
|
||||
headers_dict[b"User-Agent"] = [self.version_string]
|
||||
headers_dict[b"Host"] = [destination]
|
||||
|
||||
outbound_logger.info(
|
||||
"{%s} [%s] Sending request: %s %s",
|
||||
txn_id, destination, method, url_bytes
|
||||
)
|
||||
url_bytes = self._create_url(
|
||||
destination, path_bytes, param_bytes, query_bytes
|
||||
)
|
||||
|
||||
# XXX: Would be much nicer to retry only at the transaction-layer
|
||||
# (once we have reliable transactions in place)
|
||||
if long_retries:
|
||||
retries_left = MAX_LONG_RETRIES
|
||||
else:
|
||||
retries_left = MAX_SHORT_RETRIES
|
||||
txn_id = "%s-O-%s" % (method, self._next_id)
|
||||
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
|
||||
|
||||
http_url_bytes = urlparse.urlunparse(
|
||||
("", "", path_bytes, param_bytes, query_bytes, "")
|
||||
)
|
||||
outbound_logger.info(
|
||||
"{%s} [%s] Sending request: %s %s",
|
||||
txn_id, destination, method, url_bytes
|
||||
)
|
||||
|
||||
log_result = None
|
||||
try:
|
||||
while True:
|
||||
producer = None
|
||||
if body_callback:
|
||||
producer = body_callback(method, http_url_bytes, headers_dict)
|
||||
# XXX: Would be much nicer to retry only at the transaction-layer
|
||||
# (once we have reliable transactions in place)
|
||||
if long_retries:
|
||||
retries_left = MAX_LONG_RETRIES
|
||||
else:
|
||||
retries_left = MAX_SHORT_RETRIES
|
||||
|
||||
try:
|
||||
def send_request():
|
||||
request_deferred = preserve_context_over_fn(
|
||||
self.agent.request,
|
||||
http_url_bytes = urlparse.urlunparse(
|
||||
("", "", path_bytes, param_bytes, query_bytes, "")
|
||||
)
|
||||
|
||||
log_result = None
|
||||
try:
|
||||
while True:
|
||||
producer = None
|
||||
if body_callback:
|
||||
producer = body_callback(method, http_url_bytes, headers_dict)
|
||||
|
||||
try:
|
||||
def send_request():
|
||||
request_deferred = self.agent.request(
|
||||
method,
|
||||
url_bytes,
|
||||
Headers(headers_dict),
|
||||
producer
|
||||
)
|
||||
|
||||
return self.clock.time_bound_deferred(
|
||||
request_deferred,
|
||||
time_out=timeout / 1000. if timeout else 60,
|
||||
)
|
||||
|
||||
with logcontext.PreserveLoggingContext():
|
||||
response = yield send_request()
|
||||
|
||||
log_result = "%d %s" % (response.code, response.phrase,)
|
||||
break
|
||||
except Exception as e:
|
||||
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
|
||||
logger.warn(
|
||||
"DNS Lookup failed to %s with %s",
|
||||
destination,
|
||||
e
|
||||
)
|
||||
log_result = "DNS Lookup failed to %s with %s" % (
|
||||
destination, e
|
||||
)
|
||||
raise
|
||||
|
||||
logger.warn(
|
||||
"{%s} Sending request failed to %s: %s %s: %s - %s",
|
||||
txn_id,
|
||||
destination,
|
||||
method,
|
||||
url_bytes,
|
||||
Headers(headers_dict),
|
||||
producer
|
||||
type(e).__name__,
|
||||
_flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
return self.clock.time_bound_deferred(
|
||||
request_deferred,
|
||||
time_out=timeout / 1000. if timeout else 60,
|
||||
log_result = "%s - %s" % (
|
||||
type(e).__name__, _flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
response = yield preserve_context_over_fn(send_request)
|
||||
if retries_left and not timeout:
|
||||
if long_retries:
|
||||
delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
|
||||
delay = min(delay, 60)
|
||||
delay *= random.uniform(0.8, 1.4)
|
||||
else:
|
||||
delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
|
||||
delay = min(delay, 2)
|
||||
delay *= random.uniform(0.8, 1.4)
|
||||
|
||||
log_result = "%d %s" % (response.code, response.phrase,)
|
||||
break
|
||||
except Exception as e:
|
||||
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
|
||||
logger.warn(
|
||||
"DNS Lookup failed to %s with %s",
|
||||
destination,
|
||||
e
|
||||
)
|
||||
log_result = "DNS Lookup failed to %s with %s" % (
|
||||
destination, e
|
||||
)
|
||||
raise
|
||||
|
||||
logger.warn(
|
||||
"{%s} Sending request failed to %s: %s %s: %s - %s",
|
||||
txn_id,
|
||||
destination,
|
||||
method,
|
||||
url_bytes,
|
||||
type(e).__name__,
|
||||
_flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
log_result = "%s - %s" % (
|
||||
type(e).__name__, _flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
if retries_left and not timeout:
|
||||
if long_retries:
|
||||
delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
|
||||
delay = min(delay, 60)
|
||||
delay *= random.uniform(0.8, 1.4)
|
||||
yield sleep(delay)
|
||||
retries_left -= 1
|
||||
else:
|
||||
delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
|
||||
delay = min(delay, 2)
|
||||
delay *= random.uniform(0.8, 1.4)
|
||||
raise
|
||||
finally:
|
||||
outbound_logger.info(
|
||||
"{%s} [%s] Result: %s",
|
||||
txn_id,
|
||||
destination,
|
||||
log_result,
|
||||
)
|
||||
|
||||
yield sleep(delay)
|
||||
retries_left -= 1
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
outbound_logger.info(
|
||||
"{%s} [%s] Result: %s",
|
||||
txn_id,
|
||||
destination,
|
||||
log_result,
|
||||
)
|
||||
if 200 <= response.code < 300:
|
||||
pass
|
||||
else:
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
with logcontext.PreserveLoggingContext():
|
||||
body = yield readBody(response)
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase, body
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
pass
|
||||
else:
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase, body
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
defer.returnValue(response)
|
||||
|
||||
def sign_request(self, destination, method, url_bytes, headers_dict,
|
||||
content=None):
|
||||
|
@ -248,7 +277,9 @@ class MatrixFederationHttpClient(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, destination, path, data={}, json_data_callback=None,
|
||||
long_retries=False, timeout=None):
|
||||
long_retries=False, timeout=None,
|
||||
ignore_backoff=False,
|
||||
backoff_on_404=False):
|
||||
""" Sends the specifed json data using PUT
|
||||
|
||||
Args:
|
||||
|
@ -263,11 +294,19 @@ class MatrixFederationHttpClient(object):
|
|||
retry for a short or long time.
|
||||
timeout(int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
backoff_on_404 (bool): True if we should count a 404 response as
|
||||
a failure of the server (and should therefore back off future
|
||||
requests)
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body. On a 4xx or 5xx error response a
|
||||
CodeMessageException is raised.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
"""
|
||||
|
||||
if not json_data_callback:
|
||||
|
@ -282,26 +321,29 @@ class MatrixFederationHttpClient(object):
|
|||
producer = _JsonProducer(json_data)
|
||||
return producer
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
response = yield self._request(
|
||||
destination,
|
||||
"PUT",
|
||||
path.encode("ascii"),
|
||||
path,
|
||||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=long_retries,
|
||||
timeout=timeout,
|
||||
ignore_backoff=ignore_backoff,
|
||||
backoff_on_404=backoff_on_404,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
body = yield readBody(response)
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json(self, destination, path, data={}, long_retries=False,
|
||||
timeout=None):
|
||||
timeout=None, ignore_backoff=False):
|
||||
""" Sends the specifed json data using POST
|
||||
|
||||
Args:
|
||||
|
@ -314,11 +356,15 @@ class MatrixFederationHttpClient(object):
|
|||
retry for a short or long time.
|
||||
timeout(int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout.
|
||||
|
||||
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||
try the request anyway.
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body. On a 4xx or 5xx error response a
|
||||
CodeMessageException is raised.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
"""
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
|
@ -327,27 +373,29 @@ class MatrixFederationHttpClient(object):
|
|||
)
|
||||
return _JsonProducer(data)
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
response = yield self._request(
|
||||
destination,
|
||||
"POST",
|
||||
path.encode("ascii"),
|
||||
path,
|
||||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=long_retries,
|
||||
timeout=timeout,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
body = yield readBody(response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
|
||||
timeout=None):
|
||||
timeout=None, ignore_backoff=False):
|
||||
""" GETs some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
|
@ -359,11 +407,16 @@ class MatrixFederationHttpClient(object):
|
|||
timeout (int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout and that the request will
|
||||
be retried.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* HTTP response.
|
||||
|
||||
The result of the deferred is a tuple of `(code, response)`,
|
||||
where `response` is a dict representing the decoded JSON body.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
|
@ -380,36 +433,47 @@ class MatrixFederationHttpClient(object):
|
|||
self.sign_request(destination, method, url_bytes, headers_dict)
|
||||
return None
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
response = yield self._request(
|
||||
destination,
|
||||
"GET",
|
||||
path.encode("ascii"),
|
||||
path,
|
||||
query_bytes=query_bytes,
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
timeout=timeout,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
body = yield preserve_context_over_fn(readBody, response)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
body = yield readBody(response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_file(self, destination, path, output_stream, args={},
|
||||
retry_on_dns_fail=True, max_size=None):
|
||||
retry_on_dns_fail=True, max_size=None,
|
||||
ignore_backoff=False):
|
||||
"""GETs a file from a given homeserver
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request to.
|
||||
path (str): The HTTP path to GET.
|
||||
output_stream (file): File to write the response body to.
|
||||
args (dict): Optional dictionary used to create the query string.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
Returns:
|
||||
A (int,dict) tuple of the file length and a dict of the response
|
||||
headers.
|
||||
Deferred: resolves with an (int,dict) tuple of the file length and
|
||||
a dict of the response headers.
|
||||
|
||||
Fails with ``HTTPRequestException`` if we get an HTTP response code
|
||||
>= 300
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
"""
|
||||
|
||||
encoded_args = {}
|
||||
|
@ -419,28 +483,29 @@ class MatrixFederationHttpClient(object):
|
|||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
self.sign_request(destination, method, url_bytes, headers_dict)
|
||||
return None
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
response = yield self._request(
|
||||
destination,
|
||||
"GET",
|
||||
path.encode("ascii"),
|
||||
path,
|
||||
query_bytes=query_bytes,
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
headers = dict(response.headers.getAllRawHeaders())
|
||||
|
||||
try:
|
||||
length = yield preserve_context_over_fn(
|
||||
_readBodyToFile,
|
||||
response, output_stream, max_size
|
||||
)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
length = yield _readBodyToFile(
|
||||
response, output_stream, max_size
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to download body")
|
||||
raise
|
||||
|
|
|
@ -192,6 +192,16 @@ def parse_json_object_from_request(request):
|
|||
return content
|
||||
|
||||
|
||||
def assert_params_in_request(body, required):
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if len(absent) > 0:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
|
||||
class RestServlet(object):
|
||||
|
||||
""" A Synapse REST Servlet.
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
|
||||
from synapse.util import DeferredTimedOutError
|
||||
from synapse.util.logutils import log_function
|
||||
|
@ -37,6 +38,10 @@ metrics = synapse.metrics.get_metrics_for(__name__)
|
|||
|
||||
notified_events_counter = metrics.register_counter("notified_events")
|
||||
|
||||
users_woken_by_stream_counter = metrics.register_counter(
|
||||
"users_woken_by_stream", labels=["stream"]
|
||||
)
|
||||
|
||||
|
||||
# TODO(paul): Should be shared somewhere
|
||||
def count(func, l):
|
||||
|
@ -73,6 +78,13 @@ class _NotifierUserStream(object):
|
|||
self.user_id = user_id
|
||||
self.rooms = set(rooms)
|
||||
self.current_token = current_token
|
||||
|
||||
# The last token for which we should wake up any streams that have a
|
||||
# token that comes before it. This gets updated everytime we get poked.
|
||||
# We start it at the current token since if we get any streams
|
||||
# that have a token from before we have no idea whether they should be
|
||||
# woken up or not, so lets just wake them up.
|
||||
self.last_notified_token = current_token
|
||||
self.last_notified_ms = time_now_ms
|
||||
|
||||
with PreserveLoggingContext():
|
||||
|
@ -89,9 +101,12 @@ class _NotifierUserStream(object):
|
|||
self.current_token = self.current_token.copy_and_advance(
|
||||
stream_key, stream_id
|
||||
)
|
||||
self.last_notified_token = self.current_token
|
||||
self.last_notified_ms = time_now_ms
|
||||
noify_deferred = self.notify_deferred
|
||||
|
||||
users_woken_by_stream_counter.inc(stream_key)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||
noify_deferred.callback(self.current_token)
|
||||
|
@ -113,8 +128,14 @@ class _NotifierUserStream(object):
|
|||
def new_listener(self, token):
|
||||
"""Returns a deferred that is resolved when there is a new token
|
||||
greater than the given token.
|
||||
|
||||
Args:
|
||||
token: The token from which we are streaming from, i.e. we shouldn't
|
||||
notify for things that happened before this.
|
||||
"""
|
||||
if self.current_token.is_after(token):
|
||||
# Immediately wake up stream if something has already since happened
|
||||
# since their last token.
|
||||
if self.last_notified_token.is_after(token):
|
||||
return _NotificationListener(defer.succeed(self.current_token))
|
||||
else:
|
||||
return _NotificationListener(self.notify_deferred.observe())
|
||||
|
@ -283,8 +304,7 @@ class Notifier(object):
|
|||
if user_stream is None:
|
||||
current_token = yield self.event_sources.get_current_token()
|
||||
if room_ids is None:
|
||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = [room.room_id for room in rooms]
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
user_stream = _NotifierUserStream(
|
||||
user_id=user_id,
|
||||
rooms=room_ids,
|
||||
|
@ -294,40 +314,44 @@ class Notifier(object):
|
|||
self._register_with_keys(user_stream)
|
||||
|
||||
result = None
|
||||
prev_token = from_token
|
||||
if timeout:
|
||||
end_time = self.clock.time_msec() + timeout
|
||||
|
||||
prev_token = from_token
|
||||
while not result:
|
||||
try:
|
||||
current_token = user_stream.current_token
|
||||
|
||||
result = yield callback(prev_token, current_token)
|
||||
if result:
|
||||
break
|
||||
|
||||
now = self.clock.time_msec()
|
||||
if end_time <= now:
|
||||
break
|
||||
|
||||
# Now we wait for the _NotifierUserStream to be told there
|
||||
# is a new token.
|
||||
# We need to supply the token we supplied to callback so
|
||||
# that we don't miss any current_token updates.
|
||||
prev_token = current_token
|
||||
listener = user_stream.new_listener(prev_token)
|
||||
with PreserveLoggingContext():
|
||||
yield self.clock.time_bound_deferred(
|
||||
listener.deferred,
|
||||
time_out=(end_time - now) / 1000.
|
||||
)
|
||||
|
||||
current_token = user_stream.current_token
|
||||
|
||||
result = yield callback(prev_token, current_token)
|
||||
if result:
|
||||
break
|
||||
|
||||
# Update the prev_token to the current_token since nothing
|
||||
# has happened between the old prev_token and the current_token
|
||||
prev_token = current_token
|
||||
except DeferredTimedOutError:
|
||||
break
|
||||
except defer.CancelledError:
|
||||
break
|
||||
else:
|
||||
|
||||
if result is None:
|
||||
# This happened if there was no timeout or if the timeout had
|
||||
# already expired.
|
||||
current_token = user_stream.current_token
|
||||
result = yield callback(from_token, current_token)
|
||||
result = yield callback(prev_token, current_token)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
|
@ -388,6 +412,15 @@ class Notifier(object):
|
|||
new_events,
|
||||
is_peeking=is_peeking,
|
||||
)
|
||||
elif name == "presence":
|
||||
now = self.clock.time_msec()
|
||||
new_events[:] = [
|
||||
{
|
||||
"type": "m.presence",
|
||||
"content": format_user_presence_state(event, now),
|
||||
}
|
||||
for event in new_events
|
||||
]
|
||||
|
||||
events.extend(new_events)
|
||||
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||
|
@ -420,8 +453,7 @@ class Notifier(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _get_room_ids(self, user, explicit_room_id):
|
||||
joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
joined_room_ids = map(lambda r: r.room_id, joined_rooms)
|
||||
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
|
||||
if explicit_room_id:
|
||||
if explicit_room_id in joined_room_ids:
|
||||
defer.returnValue(([explicit_room_id], True))
|
||||
|
|
|
@ -139,7 +139,7 @@ class Mailer(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _fetch_room_state(room_id):
|
||||
room_state = yield self.state_handler.get_current_state_ids(room_id)
|
||||
room_state = yield self.store.get_current_state_ids(room_id)
|
||||
state_by_room[room_id] = room_state
|
||||
|
||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||
|
|
|
@ -17,6 +17,7 @@ import logging
|
|||
import re
|
||||
|
||||
from synapse.types import UserID
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
|
|||
return self._value_cache.get(dotted_key, None)
|
||||
|
||||
|
||||
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
|
||||
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
|
||||
register_cache("regex_push_cache", regex_cache)
|
||||
|
||||
|
||||
def _glob_matches(glob, value, word_boundary=False):
|
||||
"""Tests if value matches glob.
|
||||
|
||||
|
@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
|
|||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
try:
|
||||
if IS_GLOB.search(glob):
|
||||
r = re.escape(glob)
|
||||
|
||||
r = r.replace(r'\*', '.*?')
|
||||
r = r.replace(r'\?', '.')
|
||||
|
||||
# handle [abc], [a-z] and [!a-z] style ranges.
|
||||
r = GLOB_REGEX.sub(
|
||||
lambda x: (
|
||||
'[%s%s]' % (
|
||||
x.group(1) and '^' or '',
|
||||
x.group(2).replace(r'\\\-', '-')
|
||||
)
|
||||
),
|
||||
r,
|
||||
)
|
||||
if word_boundary:
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _compile_regex(r)
|
||||
|
||||
return r.search(value)
|
||||
else:
|
||||
r = r + "$"
|
||||
r = _compile_regex(r)
|
||||
|
||||
return r.match(value)
|
||||
elif word_boundary:
|
||||
r = re.escape(glob)
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _compile_regex(r)
|
||||
|
||||
return r.search(value)
|
||||
else:
|
||||
return value.lower() == glob.lower()
|
||||
r = regex_cache.get((glob, word_boundary), None)
|
||||
if not r:
|
||||
r = _glob_to_re(glob, word_boundary)
|
||||
regex_cache[(glob, word_boundary)] = r
|
||||
return r.search(value)
|
||||
except re.error:
|
||||
logger.warn("Failed to parse glob to regex: %r", glob)
|
||||
return False
|
||||
|
||||
|
||||
def _glob_to_re(glob, word_boundary):
|
||||
"""Generates regex for a given glob.
|
||||
|
||||
Args:
|
||||
glob (string)
|
||||
word_boundary (bool): Whether to match against word boundaries or entire
|
||||
string. Defaults to False.
|
||||
|
||||
Returns:
|
||||
regex object
|
||||
"""
|
||||
if IS_GLOB.search(glob):
|
||||
r = re.escape(glob)
|
||||
|
||||
r = r.replace(r'\*', '.*?')
|
||||
r = r.replace(r'\?', '.')
|
||||
|
||||
# handle [abc], [a-z] and [!a-z] style ranges.
|
||||
r = GLOB_REGEX.sub(
|
||||
lambda x: (
|
||||
'[%s%s]' % (
|
||||
x.group(1) and '^' or '',
|
||||
x.group(2).replace(r'\\\-', '-')
|
||||
)
|
||||
),
|
||||
r,
|
||||
)
|
||||
if word_boundary:
|
||||
r = r"\b%s\b" % (r,)
|
||||
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
else:
|
||||
r = "^" + r + "$"
|
||||
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
elif word_boundary:
|
||||
r = re.escape(glob)
|
||||
r = r"\b%s\b" % (r,)
|
||||
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
else:
|
||||
r = "^" + re.escape(glob) + "$"
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _flatten_dict(d, prefix=[], result={}):
|
||||
for key, value in d.items():
|
||||
if isinstance(value, basestring):
|
||||
|
@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
|
|||
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
regex_cache = LruCache(5000)
|
||||
|
||||
|
||||
def _compile_regex(regex_str):
|
||||
r = regex_cache.get(regex_str, None)
|
||||
if r:
|
||||
return r
|
||||
|
||||
r = re.compile(regex_str, flags=re.IGNORECASE)
|
||||
regex_cache[regex_str] = r
|
||||
return r
|
||||
|
|
|
@ -33,13 +33,13 @@ def get_badge_count(store, user_id):
|
|||
|
||||
badge = len(invites)
|
||||
|
||||
for r in joins:
|
||||
if r.room_id in my_receipts_by_room:
|
||||
last_unread_event_id = my_receipts_by_room[r.room_id]
|
||||
for room_id in joins:
|
||||
if room_id in my_receipts_by_room:
|
||||
last_unread_event_id = my_receipts_by_room[room_id]
|
||||
|
||||
notifs = yield (
|
||||
store.get_unread_event_push_actions_by_room_for_user(
|
||||
r.room_id, user_id, last_unread_event_id
|
||||
room_id, user_id, last_unread_event_id
|
||||
)
|
||||
)
|
||||
# return one badge count per conversation, as count per
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +19,7 @@ from distutils.version import LooseVersion
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
REQUIREMENTS = {
|
||||
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
|
||||
"frozendict>=0.4": ["frozendict"],
|
||||
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
|
||||
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
|
||||
|
@ -37,6 +39,7 @@ REQUIREMENTS = {
|
|||
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
|
||||
"pymacaroons-pynacl": ["pymacaroons"],
|
||||
"msgpack-python>=0.3.0": ["msgpack"],
|
||||
"phonenumbers>=8.2.0": ["phonenumbers"],
|
||||
}
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
"web_client": {
|
||||
|
|
|
@ -283,12 +283,12 @@ class ReplicationResource(Resource):
|
|||
|
||||
if request_events != upto_events_token:
|
||||
writer.write_header_and_rows("events", res.new_forward_events, (
|
||||
"position", "internal", "json", "state_group"
|
||||
"position", "event_id", "room_id", "type", "state_key",
|
||||
), position=upto_events_token)
|
||||
|
||||
if request_backfill != upto_backfill_token:
|
||||
writer.write_header_and_rows("backfill", res.new_backfill_events, (
|
||||
"position", "internal", "json", "state_group",
|
||||
"position", "event_id", "room_id", "type", "state_key", "redacts",
|
||||
), position=upto_backfill_token)
|
||||
|
||||
writer.write_header_and_rows(
|
||||
|
|
|
@ -27,4 +27,9 @@ class SlavedIdTracker(object):
|
|||
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
||||
|
||||
def get_current_token(self):
|
||||
"""
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
return self._current
|
||||
|
|
|
@ -16,7 +16,6 @@ from ._base import BaseSlavedStore
|
|||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
|
@ -25,7 +24,6 @@ from synapse.storage.state import StateStore
|
|||
from synapse.storage.stream import StreamStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
import ujson as json
|
||||
import logging
|
||||
|
||||
|
||||
|
@ -109,6 +107,10 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
get_recent_event_ids_for_room = (
|
||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||
)
|
||||
get_current_state_ids = (
|
||||
StateStore.__dict__["get_current_state_ids"]
|
||||
)
|
||||
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
||||
|
||||
get_unread_push_actions_for_user_in_range_for_http = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
||||
|
@ -165,7 +167,6 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
_get_rooms_for_user_where_membership_is_txn = (
|
||||
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
||||
)
|
||||
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
|
||||
_get_state_for_groups = DataStore._get_state_for_groups.__func__
|
||||
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
|
||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||
|
@ -238,46 +239,32 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
return super(SlavedEventStore, self).process_replication(result)
|
||||
|
||||
def _process_replication_row(self, row, backfilled):
|
||||
internal = json.loads(row[1])
|
||||
event_json = json.loads(row[2])
|
||||
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
||||
stream_ordering = row[0] if not backfilled else -row[0]
|
||||
self.invalidate_caches_for_event(
|
||||
event, backfilled,
|
||||
stream_ordering, row[1], row[2], row[3], row[4], row[5],
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
def invalidate_caches_for_event(self, event, backfilled):
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
|
||||
etype, state_key, redacts, backfilled):
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
|
||||
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
|
||||
self.get_latest_event_ids_in_room.invalidate((room_id,))
|
||||
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
|
||||
(event.room_id,)
|
||||
(room_id,)
|
||||
)
|
||||
|
||||
if not backfilled:
|
||||
self._events_stream_cache.entity_has_changed(
|
||||
event.room_id, event.internal_metadata.stream_ordering
|
||||
room_id, stream_ordering
|
||||
)
|
||||
|
||||
# self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
|
||||
# (event.room_id,)
|
||||
# )
|
||||
if redacts:
|
||||
self._invalidate_get_event_cache(redacts)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
self._invalidate_get_event_cache(event.redacts)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if etype == EventTypes.Member:
|
||||
self._membership_stream_cache.entity_has_changed(
|
||||
event.state_key, event.internal_metadata.stream_ordering
|
||||
state_key, stream_ordering
|
||||
)
|
||||
self.get_invited_rooms_for_user.invalidate((event.state_key,))
|
||||
|
||||
if not event.is_state():
|
||||
return
|
||||
|
||||
if backfilled:
|
||||
return
|
||||
|
||||
if (not event.internal_metadata.is_invite_from_remote()
|
||||
and event.internal_metadata.is_outlier()):
|
||||
return
|
||||
self.get_invited_rooms_for_user.invalidate((state_key,))
|
||||
|
|
|
@ -57,5 +57,6 @@ class SlavedPresenceStore(BaseSlavedStore):
|
|||
self.presence_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
self._get_presence_for_user.invalidate((user_id,))
|
||||
|
||||
return super(SlavedPresenceStore, self).process_replication(result)
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes
|
|||
from synapse.types import UserID
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
||||
|
@ -33,10 +34,55 @@ from saml2.client import Saml2Client
|
|||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def login_submission_legacy_convert(submission):
|
||||
"""
|
||||
If the input login submission is an old style object
|
||||
(ie. with top-level user / medium / address) convert it
|
||||
to a typed object.
|
||||
"""
|
||||
if "user" in submission:
|
||||
submission["identifier"] = {
|
||||
"type": "m.id.user",
|
||||
"user": submission["user"],
|
||||
}
|
||||
del submission["user"]
|
||||
|
||||
if "medium" in submission and "address" in submission:
|
||||
submission["identifier"] = {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": submission["medium"],
|
||||
"address": submission["address"],
|
||||
}
|
||||
del submission["medium"]
|
||||
del submission["address"]
|
||||
|
||||
|
||||
def login_id_thirdparty_from_phone(identifier):
|
||||
"""
|
||||
Convert a phone login identifier type to a generic threepid identifier
|
||||
Args:
|
||||
identifier(dict): Login identifier dict of type 'm.id.phone'
|
||||
|
||||
Returns: Login identifier dict of type 'm.id.threepid'
|
||||
"""
|
||||
if "country" not in identifier or "number" not in identifier:
|
||||
raise SynapseError(400, "Invalid phone-type identifier")
|
||||
|
||||
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
|
||||
|
||||
return {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": "msisdn",
|
||||
"address": msisdn,
|
||||
}
|
||||
|
||||
|
||||
class LoginRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login$")
|
||||
PASS_TYPE = "m.login.password"
|
||||
|
@ -117,20 +163,52 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def do_password_login(self, login_submission):
|
||||
if 'medium' in login_submission and 'address' in login_submission:
|
||||
address = login_submission['address']
|
||||
if login_submission['medium'] == 'email':
|
||||
if "password" not in login_submission:
|
||||
raise SynapseError(400, "Missing parameter: password")
|
||||
|
||||
login_submission_legacy_convert(login_submission)
|
||||
|
||||
if "identifier" not in login_submission:
|
||||
raise SynapseError(400, "Missing param: identifier")
|
||||
|
||||
identifier = login_submission["identifier"]
|
||||
if "type" not in identifier:
|
||||
raise SynapseError(400, "Login identifier has no type")
|
||||
|
||||
# convert phone type identifiers to generic threepids
|
||||
if identifier["type"] == "m.id.phone":
|
||||
identifier = login_id_thirdparty_from_phone(identifier)
|
||||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
if 'medium' not in identifier or 'address' not in identifier:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
address = identifier['address']
|
||||
if identifier['medium'] == 'email':
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
address = address.lower()
|
||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
login_submission['medium'], address
|
||||
identifier['medium'], address
|
||||
)
|
||||
if not user_id:
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
else:
|
||||
user_id = login_submission['user']
|
||||
|
||||
identifier = {
|
||||
"type": "m.id.user",
|
||||
"user": user_id,
|
||||
}
|
||||
|
||||
# by this point, the identifier should be an m.id.user: if it's anything
|
||||
# else, we haven't understood it.
|
||||
if identifier["type"] != "m.id.user":
|
||||
raise SynapseError(400, "Unknown login identifier type")
|
||||
if "user" not in identifier:
|
||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
user_id = identifier["user"]
|
||||
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID.create(
|
||||
|
@ -341,7 +419,12 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
"ticket": request.args["ticket"],
|
||||
"service": self.cas_service_url
|
||||
}
|
||||
body = yield http_client.get_raw(uri, args)
|
||||
try:
|
||||
body = yield http_client.get_raw(uri, args)
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted raises this error if the connection is closed,
|
||||
# even if that's being used old-http style to signal end-of-data
|
||||
body = pde.response
|
||||
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
||||
defer.returnValue(result)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.types import UserID
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
||||
|
@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
|||
def __init__(self, hs):
|
||||
super(PresenceStatusRestServlet, self).__init__(hs)
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
|
@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
|||
raise AuthError(403, "You are not allowed to see their presence.")
|
||||
|
||||
state = yield self.presence_handler.get_state(target_user=user)
|
||||
state = format_user_presence_state(state, self.clock.time_msec())
|
||||
|
||||
defer.returnValue((200, state))
|
||||
|
||||
|
|
|
@ -748,8 +748,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet):
|
|||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
|
||||
room_ids = set(r.room_id for r in rooms) # Ensure they're unique.
|
||||
room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
|
||||
defer.returnValue((200, {"joined_rooms": list(room_ids)}))
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -17,8 +18,11 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import LoginError, SynapseError, Codes
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, assert_params_in_request
|
||||
)
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
|
@ -28,11 +32,11 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordRequestTokenRestServlet(RestServlet):
|
||||
class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PasswordRequestTokenRestServlet, self).__init__()
|
||||
super(EmailPasswordRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
|
@ -40,14 +44,9 @@ class PasswordRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
assert_params_in_request(body, [
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
|
@ -60,6 +59,37 @@ class PasswordRequestTokenRestServlet(RestServlet):
|
|||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.datastore = self.hs.get_datastore()
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_request(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number', 'send_attempt',
|
||||
])
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
|
||||
if existingUid is None:
|
||||
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
||||
ret = yield self.identity_handler.requestMsisdnToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PasswordRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/password$")
|
||||
|
||||
|
@ -68,6 +98,7 @@ class PasswordRestServlet(RestServlet):
|
|||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.datastore = self.hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -77,7 +108,8 @@ class PasswordRestServlet(RestServlet):
|
|||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
[LoginType.EMAIL_IDENTITY]
|
||||
[LoginType.EMAIL_IDENTITY],
|
||||
[LoginType.MSISDN],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
|
||||
if not authed:
|
||||
|
@ -102,7 +134,7 @@ class PasswordRestServlet(RestServlet):
|
|||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
threepid['address'] = threepid['address'].lower()
|
||||
# if using email, we must know about the email they're authing with!
|
||||
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
|
||||
threepid['medium'], threepid['address']
|
||||
)
|
||||
if not threepid_user_id:
|
||||
|
@ -169,13 +201,14 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class ThreepidRequestTokenRestServlet(RestServlet):
|
||||
class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
super(ThreepidRequestTokenRestServlet, self).__init__()
|
||||
super(EmailThreepidRequestTokenRestServlet, self).__init__()
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.datastore = self.hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -190,7 +223,7 @@ class ThreepidRequestTokenRestServlet(RestServlet):
|
|||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
||||
|
@ -201,6 +234,44 @@ class ThreepidRequestTokenRestServlet(RestServlet):
|
|||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.datastore = self.hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number', 'send_attempt',
|
||||
]
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
ret = yield self.identity_handler.requestMsisdnToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class ThreepidRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid$")
|
||||
|
||||
|
@ -210,6 +281,7 @@ class ThreepidRestServlet(RestServlet):
|
|||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.datastore = self.hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
@ -217,7 +289,7 @@ class ThreepidRestServlet(RestServlet):
|
|||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
threepids = yield self.hs.get_datastore().user_get_threepids(
|
||||
threepids = yield self.datastore.user_get_threepids(
|
||||
requester.user.to_string()
|
||||
)
|
||||
|
||||
|
@ -258,7 +330,7 @@ class ThreepidRestServlet(RestServlet):
|
|||
|
||||
if 'bind' in body and body['bind']:
|
||||
logger.debug(
|
||||
"Binding emails %s to %s",
|
||||
"Binding threepid %s to %s",
|
||||
threepid, user_id
|
||||
)
|
||||
yield self.identity_handler.bind_threepid(
|
||||
|
@ -302,9 +374,11 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
PasswordRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
ThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
ThreepidRestServlet(hs).register(http_server)
|
||||
ThreepidDeleteRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -46,6 +46,52 @@ class DevicesRestServlet(servlet.RestServlet):
|
|||
defer.returnValue((200, {"devices": devices}))
|
||||
|
||||
|
||||
class DeleteDevicesRestServlet(servlet.RestServlet):
|
||||
"""
|
||||
API for bulk deletion of devices. Accepts a JSON object with a devices
|
||||
key which lists the device_ids to delete. Requires user interactive auth.
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DeleteDevicesRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
try:
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
except errors.SynapseError as e:
|
||||
if e.errcode == errors.Codes.NOT_JSON:
|
||||
# deal with older clients which didn't pass a J*DELETESON dict
|
||||
# the same as those that pass an empty dict
|
||||
body = {}
|
||||
else:
|
||||
raise e
|
||||
|
||||
if 'devices' not in body:
|
||||
raise errors.SynapseError(
|
||||
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[constants.LoginType.PASSWORD],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
yield self.device_handler.delete_devices(
|
||||
requester.user.to_string(),
|
||||
body['devices'],
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False)
|
||||
|
@ -111,5 +157,6 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
DeleteDevicesRestServlet(hs).register(http_server)
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeviceRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -19,7 +20,10 @@ import synapse
|
|||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, assert_params_in_request
|
||||
)
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
|
@ -43,7 +47,7 @@ else:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterRequestTokenRestServlet(RestServlet):
|
||||
class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -51,7 +55,7 @@ class RegisterRequestTokenRestServlet(RestServlet):
|
|||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RegisterRequestTokenRestServlet, self).__init__()
|
||||
super(EmailRegisterRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
|
@ -59,14 +63,9 @@ class RegisterRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if len(absent) > 0:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
assert_params_in_request(body, [
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
|
@ -79,6 +78,43 @@ class RegisterRequestTokenRestServlet(RestServlet):
|
|||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_request(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number',
|
||||
'send_attempt',
|
||||
])
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(
|
||||
400, "Phone number is already in use", Codes.THREEPID_IN_USE
|
||||
)
|
||||
|
||||
ret = yield self.identity_handler.requestMsisdnToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register$")
|
||||
|
||||
|
@ -200,16 +236,37 @@ class RegisterRestServlet(RestServlet):
|
|||
assigned_user_id=registered_user_id,
|
||||
)
|
||||
|
||||
# Only give msisdn flows if the x_show_msisdn flag is given:
|
||||
# this is a hack to work around the fact that clients were shipped
|
||||
# that use fallback registration if they see any flows that they don't
|
||||
# recognise, which means we break registration for these clients if we
|
||||
# advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
|
||||
# Android <=0.6.9 have fallen below an acceptable threshold, this
|
||||
# parameter should go away and we should always advertise msisdn flows.
|
||||
show_msisdn = False
|
||||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||
show_msisdn = True
|
||||
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
flows = [
|
||||
[LoginType.RECAPTCHA],
|
||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
|
||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||
]
|
||||
if show_msisdn:
|
||||
flows.extend([
|
||||
[LoginType.MSISDN, LoginType.RECAPTCHA],
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||
])
|
||||
else:
|
||||
flows = [
|
||||
[LoginType.DUMMY],
|
||||
[LoginType.EMAIL_IDENTITY]
|
||||
[LoginType.EMAIL_IDENTITY],
|
||||
]
|
||||
if show_msisdn:
|
||||
flows.extend([
|
||||
[LoginType.MSISDN],
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||
])
|
||||
|
||||
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
|
@ -224,8 +281,9 @@ class RegisterRestServlet(RestServlet):
|
|||
"Already registered user ID %r for this session",
|
||||
registered_user_id
|
||||
)
|
||||
# don't re-register the email address
|
||||
# don't re-register the threepids
|
||||
add_email = False
|
||||
add_msisdn = False
|
||||
else:
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
|
@ -250,6 +308,7 @@ class RegisterRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
add_email = True
|
||||
add_msisdn = True
|
||||
|
||||
return_dict = yield self._create_registration_details(
|
||||
registered_user_id, params
|
||||
|
@ -262,6 +321,13 @@ class RegisterRestServlet(RestServlet):
|
|||
params.get("bind_email")
|
||||
)
|
||||
|
||||
if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
|
||||
threepid = auth_result[LoginType.MSISDN]
|
||||
yield self._register_msisdn_threepid(
|
||||
registered_user_id, threepid, return_dict["access_token"],
|
||||
params.get("bind_msisdn")
|
||||
)
|
||||
|
||||
defer.returnValue((200, return_dict))
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
|
@ -323,8 +389,9 @@ class RegisterRestServlet(RestServlet):
|
|||
"""
|
||||
reqd = ('medium', 'address', 'validated_at')
|
||||
if any(x not in threepid for x in reqd):
|
||||
# This will only happen if the ID server returns a malformed response
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
defer.returnValue()
|
||||
return
|
||||
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
|
@ -371,6 +438,43 @@ class RegisterRestServlet(RestServlet):
|
|||
else:
|
||||
logger.info("bind_email not specified: not binding email")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
|
||||
"""Add a phone number as a 3pid identifier
|
||||
|
||||
Also optionally binds msisdn to the given user_id on the identity server
|
||||
|
||||
Args:
|
||||
user_id (str): id of user
|
||||
threepid (object): m.login.msisdn auth response
|
||||
token (str): access_token for the user
|
||||
bind_email (bool): true if the client requested the email to be
|
||||
bound at the identity server
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
reqd = ('medium', 'address', 'validated_at')
|
||||
if any(x not in threepid for x in reqd):
|
||||
# This will only happen if the ID server returns a malformed response
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
defer.returnValue()
|
||||
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
if bind_msisdn:
|
||||
logger.info("bind_msisdn specified: binding")
|
||||
logger.debug("Binding msisdn %s to %s", threepid, user_id)
|
||||
yield self.identity_handler.bind_threepid(
|
||||
threepid['threepid_creds'], user_id
|
||||
)
|
||||
else:
|
||||
logger.info("bind_msisdn not specified: not binding msisdn")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_registration_details(self, user_id, params):
|
||||
"""Complete registration of newly-registered user
|
||||
|
@ -433,7 +537,7 @@ class RegisterRestServlet(RestServlet):
|
|||
# we have nowhere to store it.
|
||||
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
self.device_handler.check_device_registered(
|
||||
yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
|
@ -449,5 +553,6 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||
from synapse.http.servlet import (
|
||||
RestServlet, parse_string, parse_integer, parse_boolean
|
||||
)
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.handlers.sync import SyncConfig
|
||||
from synapse.types import StreamToken
|
||||
from synapse.events.utils import (
|
||||
|
@ -28,7 +29,6 @@ from synapse.api.errors import SynapseError
|
|||
from synapse.api.constants import PresenceState
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
|
@ -194,12 +194,18 @@ class SyncRestServlet(RestServlet):
|
|||
defer.returnValue((200, response_content))
|
||||
|
||||
def encode_presence(self, events, time_now):
|
||||
formatted = []
|
||||
for event in events:
|
||||
event = copy.deepcopy(event)
|
||||
event['sender'] = event['content'].pop('user_id')
|
||||
formatted.append(event)
|
||||
return {"events": formatted}
|
||||
return {
|
||||
"events": [
|
||||
{
|
||||
"type": "m.presence",
|
||||
"sender": event.user_id,
|
||||
"content": format_user_presence_state(
|
||||
event, time_now, include_user_id=False
|
||||
),
|
||||
}
|
||||
for event in events
|
||||
]
|
||||
}
|
||||
|
||||
def encode_joined(self, rooms, time_now, token_id, event_fields):
|
||||
"""
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import synapse.http.servlet
|
||||
|
||||
from ._base import parse_media_id, respond_with_file, respond_404
|
||||
from twisted.web.resource import Resource
|
||||
|
@ -81,6 +82,17 @@ class DownloadResource(Resource):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_file(self, request, server_name, media_id, name):
|
||||
# don't forward requests for remote media if allow_remote is false
|
||||
allow_remote = synapse.http.servlet.parse_boolean(
|
||||
request, "allow_remote", default=True)
|
||||
if not allow_remote:
|
||||
logger.info(
|
||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||
server_name, media_id,
|
||||
)
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
|
|
|
@ -13,22 +13,23 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
import twisted.internet.error
|
||||
import twisted.web.http
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from .upload_resource import UploadResource
|
||||
from .download_resource import DownloadResource
|
||||
from .thumbnail_resource import ThumbnailResource
|
||||
from .identicon_resource import IdenticonResource
|
||||
from .preview_url_resource import PreviewUrlResource
|
||||
from .filepath import MediaFilePaths
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from .thumbnailer import Thumbnailer
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
from synapse.api.errors import SynapseError, HttpResponseException, \
|
||||
NotFoundError
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
@ -157,11 +158,34 @@ class MediaRepository(object):
|
|||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
max_size=self.max_upload_size, args={
|
||||
# tell the remote server to 404 if it doesn't
|
||||
# recognise the server_name, to make sure we don't
|
||||
# end up with a routing loop.
|
||||
"allow_remote": "false",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn("Failed to fetch remoted media %r", e)
|
||||
raise SynapseError(502, "Failed to fetch remoted media")
|
||||
except twisted.internet.error.DNSLookupError as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
||||
server_name, media_id, e)
|
||||
raise NotFoundError()
|
||||
|
||||
except HttpResponseException as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||
server_name, media_id, e.response)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise SynapseError.from_http_response_exception(e)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except SynapseError:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
|
|
@ -177,17 +177,12 @@ class StateHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def compute_event_context(self, event, old_state=None):
|
||||
""" Fills out the context with the `current state` of the graph. The
|
||||
`current state` here is defined to be the state of the event graph
|
||||
just before the event - i.e. it never includes `event`
|
||||
|
||||
If `event` has `auth_events` then this will also fill out the
|
||||
`auth_events` field on `context` from the `current_state`.
|
||||
"""Build an EventContext structure for the event.
|
||||
|
||||
Args:
|
||||
event (EventBase)
|
||||
event (synapse.events.EventBase):
|
||||
Returns:
|
||||
an EventContext
|
||||
synapse.events.snapshot.EventContext:
|
||||
"""
|
||||
context = EventContext()
|
||||
|
||||
|
@ -200,11 +195,11 @@ class StateHandler(object):
|
|||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
if event.is_state():
|
||||
context.current_state_events = dict(context.prev_state_ids)
|
||||
context.current_state_ids = dict(context.prev_state_ids)
|
||||
key = (event.type, event.state_key)
|
||||
context.current_state_events[key] = event.event_id
|
||||
context.current_state_ids[key] = event.event_id
|
||||
else:
|
||||
context.current_state_events = context.prev_state_ids
|
||||
context.current_state_ids = context.prev_state_ids
|
||||
else:
|
||||
context.current_state_ids = {}
|
||||
context.prev_state_ids = {}
|
||||
|
|
|
@ -73,6 +73,9 @@ class LoggingTransaction(object):
|
|||
def __setattr__(self, name, value):
|
||||
setattr(self.txn, name, value)
|
||||
|
||||
def __iter__(self):
|
||||
return self.txn.__iter__()
|
||||
|
||||
def execute(self, sql, *args):
|
||||
self._do_execute(self.txn.execute, sql, *args)
|
||||
|
||||
|
@ -132,7 +135,7 @@ class PerformanceCounters(object):
|
|||
|
||||
def interval(self, interval_duration, limit=3):
|
||||
counters = []
|
||||
for name, (count, cum_time) in self.current_counters.items():
|
||||
for name, (count, cum_time) in self.current_counters.iteritems():
|
||||
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
|
||||
counters.append((
|
||||
(cum_time - prev_time) / interval_duration,
|
||||
|
@ -357,7 +360,7 @@ class SQLBaseStore(object):
|
|||
"""
|
||||
col_headers = list(intern(column[0]) for column in cursor.description)
|
||||
results = list(
|
||||
dict(zip(col_headers, row)) for row in cursor.fetchall()
|
||||
dict(zip(col_headers, row)) for row in cursor
|
||||
)
|
||||
return results
|
||||
|
||||
|
@ -565,7 +568,7 @@ class SQLBaseStore(object):
|
|||
@staticmethod
|
||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
else:
|
||||
where = ""
|
||||
|
||||
|
@ -579,7 +582,7 @@ class SQLBaseStore(object):
|
|||
|
||||
txn.execute(sql, keyvalues.values())
|
||||
|
||||
return [r[0] for r in txn.fetchall()]
|
||||
return [r[0] for r in txn]
|
||||
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||
desc="_simple_select_onecol"):
|
||||
|
@ -712,7 +715,7 @@ class SQLBaseStore(object):
|
|||
)
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in keyvalues.items():
|
||||
for key, value in keyvalues.iteritems():
|
||||
clauses.append("%s = ?" % (key,))
|
||||
values.append(value)
|
||||
|
||||
|
@ -753,7 +756,7 @@ class SQLBaseStore(object):
|
|||
@staticmethod
|
||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
else:
|
||||
where = ""
|
||||
|
||||
|
@ -840,6 +843,47 @@ class SQLBaseStore(object):
|
|||
|
||||
return txn.execute(sql, keyvalues.values())
|
||||
|
||||
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
|
||||
"""Executes a DELETE query on the named table.
|
||||
|
||||
Filters rows by if value of `column` is in `iterable`.
|
||||
|
||||
Args:
|
||||
txn : Transaction object
|
||||
table : string giving the table name
|
||||
column : column name to test for inclusion against `iterable`
|
||||
iterable : list
|
||||
keyvalues : dict of column names and values to select the rows with
|
||||
"""
|
||||
if not iterable:
|
||||
return
|
||||
|
||||
sql = "DELETE FROM %s" % table
|
||||
|
||||
clauses = []
|
||||
values = []
|
||||
clauses.append(
|
||||
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||
)
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in keyvalues.iteritems():
|
||||
clauses.append("%s = ?" % (key,))
|
||||
values.append(value)
|
||||
|
||||
if clauses:
|
||||
sql = "%s WHERE %s" % (
|
||||
sql,
|
||||
" AND ".join(clauses),
|
||||
)
|
||||
return txn.execute(sql, values)
|
||||
|
||||
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
|
||||
max_value, limit=100000):
|
||||
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||
|
@ -860,16 +904,16 @@ class SQLBaseStore(object):
|
|||
|
||||
txn = db_conn.cursor()
|
||||
txn.execute(sql, (int(max_value),))
|
||||
rows = txn.fetchall()
|
||||
txn.close()
|
||||
|
||||
cache = {
|
||||
row[0]: int(row[1])
|
||||
for row in rows
|
||||
for row in txn
|
||||
}
|
||||
|
||||
txn.close()
|
||||
|
||||
if cache:
|
||||
min_val = min(cache.values())
|
||||
min_val = min(cache.itervalues())
|
||||
else:
|
||||
min_val = max_value
|
||||
|
||||
|
|
|
@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore):
|
|||
txn.execute(sql, (user_id, stream_id))
|
||||
|
||||
global_account_data = {
|
||||
row[0]: json.loads(row[1]) for row in txn.fetchall()
|
||||
row[0]: json.loads(row[1]) for row in txn
|
||||
}
|
||||
|
||||
sql = (
|
||||
|
@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore):
|
|||
txn.execute(sql, (user_id, stream_id))
|
||||
|
||||
account_data_by_room = {}
|
||||
for row in txn.fetchall():
|
||||
for row in txn:
|
||||
room_account_data = account_data_by_room.setdefault(row[0], {})
|
||||
room_account_data[row[1]] = json.loads(row[2])
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import synapse.util.async
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from . import engines
|
||||
|
@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
self._background_update_performance = {}
|
||||
self._background_update_queue = []
|
||||
self._background_update_handlers = {}
|
||||
self._background_update_timer = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_doing_background_updates(self):
|
||||
assert self._background_update_timer is None, \
|
||||
"background updates already running"
|
||||
|
||||
logger.info("Starting background schema updates")
|
||||
|
||||
while True:
|
||||
sleep = defer.Deferred()
|
||||
self._background_update_timer = self._clock.call_later(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
||||
)
|
||||
try:
|
||||
yield sleep
|
||||
finally:
|
||||
self._background_update_timer = None
|
||||
yield synapse.util.async.sleep(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
||||
|
||||
try:
|
||||
result = yield self.do_next_background_update(
|
||||
|
|
|
@ -178,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
|||
)
|
||||
txn.execute(sql, (user_id,))
|
||||
message_json = ujson.dumps(messages_by_device["*"])
|
||||
for row in txn.fetchall():
|
||||
for row in txn:
|
||||
# Add the message for all devices for this user on this
|
||||
# server.
|
||||
device = row[0]
|
||||
|
@ -195,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
|||
# TODO: Maybe this needs to be done in batches if there are
|
||||
# too many local devices for a given user.
|
||||
txn.execute(sql, [user_id] + devices)
|
||||
for row in txn.fetchall():
|
||||
for row in txn:
|
||||
# Only insert into the local inbox if the device exists on
|
||||
# this server
|
||||
device = row[0]
|
||||
|
@ -251,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
|||
user_id, device_id, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
messages = []
|
||||
for row in txn.fetchall():
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(ujson.loads(row[1]))
|
||||
if len(messages) < limit:
|
||||
|
@ -340,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
|||
" ORDER BY stream_id ASC"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows.extend(txn.fetchall())
|
||||
rows.extend(txn)
|
||||
|
||||
return rows
|
||||
|
||||
|
@ -357,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
|||
"""
|
||||
Args:
|
||||
destination(str): The name of the remote server.
|
||||
last_stream_id(int): The last position of the device message stream
|
||||
last_stream_id(int|long): The last position of the device message stream
|
||||
that the server sent up to.
|
||||
current_stream_id(int): The current position of the device
|
||||
current_stream_id(int|long): The current position of the device
|
||||
message stream.
|
||||
Returns:
|
||||
Deferred ([dict], int): List of messages for the device and where
|
||||
Deferred ([dict], int|long): List of messages for the device and where
|
||||
in the stream the messages got to.
|
||||
"""
|
||||
|
||||
|
@ -384,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
|||
destination, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
messages = []
|
||||
for row in txn.fetchall():
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(ujson.loads(row[1]))
|
||||
if len(messages) < limit:
|
||||
|
|
|
@ -108,6 +108,23 @@ class DeviceStore(SQLBaseStore):
|
|||
desc="delete_device",
|
||||
)
|
||||
|
||||
def delete_devices(self, user_id, device_ids):
|
||||
"""Deletes several devices.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the devices
|
||||
device_ids (list): The IDs of the devices to delete
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
return self._simple_delete_many(
|
||||
table="devices",
|
||||
column="device_id",
|
||||
iterable=device_ids,
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="delete_devices",
|
||||
)
|
||||
|
||||
def update_device(self, user_id, device_id, new_display_name=None):
|
||||
"""Update a device.
|
||||
|
||||
|
@ -291,7 +308,7 @@ class DeviceStore(SQLBaseStore):
|
|||
"""Get stream of updates to send to remote servers
|
||||
|
||||
Returns:
|
||||
(now_stream_id, [ { updates }, .. ])
|
||||
(int, list[dict]): current stream id and list of updates
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
|
@ -312,17 +329,20 @@ class DeviceStore(SQLBaseStore):
|
|||
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||
GROUP BY user_id, device_id
|
||||
LIMIT 20
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (destination, from_stream_id, now_stream_id, False)
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
|
||||
if not rows:
|
||||
return (now_stream_id, [])
|
||||
|
||||
# maps (user_id, device_id) -> stream_id
|
||||
query_map = {(r[0], r[1]): r[2] for r in rows}
|
||||
query_map = {(r[0], r[1]): r[2] for r in txn}
|
||||
if not query_map:
|
||||
return (now_stream_id, [])
|
||||
|
||||
if len(query_map) >= 20:
|
||||
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
txn, query_map.keys(), include_all_devices=True
|
||||
)
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
import ujson as json
|
||||
|
||||
|
@ -120,24 +122,63 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||
"""Insert some new one time keys for a device.
|
||||
|
||||
Checks if any of the keys are already inserted, if they are then check
|
||||
if they match. If they don't then we raise an error.
|
||||
"""
|
||||
|
||||
# First we check if we have already persisted any of the keys.
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
iterable=[key_id for _, key_id, _ in key_list],
|
||||
retcols=("algorithm", "key_id", "key_json",),
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
desc="add_e2e_one_time_keys_check",
|
||||
)
|
||||
|
||||
existing_key_map = {
|
||||
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
|
||||
}
|
||||
|
||||
new_keys = [] # Keys that we need to insert
|
||||
for algorithm, key_id, json_bytes in key_list:
|
||||
ex_bytes = existing_key_map.get((algorithm, key_id), None)
|
||||
if ex_bytes:
|
||||
if json_bytes != ex_bytes:
|
||||
raise SynapseError(
|
||||
400, "One time key with key_id %r already exists" % (key_id,)
|
||||
)
|
||||
else:
|
||||
new_keys.append((algorithm, key_id, json_bytes))
|
||||
|
||||
def _add_e2e_one_time_keys(txn):
|
||||
for (algorithm, key_id, json_bytes) in key_list:
|
||||
self._simple_upsert_txn(
|
||||
txn, table="e2e_one_time_keys_json",
|
||||
keyvalues={
|
||||
# We are protected from race between lookup and insertion due to
|
||||
# a unique constraint. If there is a race of two calls to
|
||||
# `add_e2e_one_time_keys` then they'll conflict and we will only
|
||||
# insert one set.
|
||||
self._simple_insert_many_txn(
|
||||
txn, table="e2e_one_time_keys_json",
|
||||
values=[
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
"key_id": key_id,
|
||||
},
|
||||
values={
|
||||
"ts_added_ms": time_now,
|
||||
"key_json": json_bytes,
|
||||
}
|
||||
)
|
||||
return self.runInteraction(
|
||||
"add_e2e_one_time_keys", _add_e2e_one_time_keys
|
||||
for algorithm, key_id, json_bytes in new_keys
|
||||
],
|
||||
)
|
||||
yield self.runInteraction(
|
||||
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
||||
)
|
||||
|
||||
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||
|
@ -153,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||
)
|
||||
txn.execute(sql, (user_id, device_id))
|
||||
result = {}
|
||||
for algorithm, key_count in txn.fetchall():
|
||||
for algorithm, key_count in txn:
|
||||
result[algorithm] = key_count
|
||||
return result
|
||||
return self.runInteraction(
|
||||
|
@ -174,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||
user_result = result.setdefault(user_id, {})
|
||||
device_result = user_result.setdefault(device_id, {})
|
||||
txn.execute(sql, (user_id, device_id, algorithm))
|
||||
for key_id, key_json in txn.fetchall():
|
||||
for key_id, key_json in txn:
|
||||
device_result[algorithm + ":" + key_id] = key_json
|
||||
delete.append((user_id, device_id, algorithm, key_id))
|
||||
sql = (
|
||||
|
|
|
@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore):
|
|||
base_sql % (",".join(["?"] * len(chunk)),),
|
||||
chunk
|
||||
)
|
||||
new_front.update([r[0] for r in txn.fetchall()])
|
||||
new_front.update([r[0] for r in txn])
|
||||
|
||||
new_front -= results
|
||||
|
||||
|
@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore):
|
|||
|
||||
txn.execute(sql, (room_id, False,))
|
||||
|
||||
return dict(txn.fetchall())
|
||||
return dict(txn)
|
||||
|
||||
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
||||
return self._simple_select_onecol_txn(
|
||||
|
@ -201,19 +201,19 @@ class EventFederationStore(SQLBaseStore):
|
|||
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
||||
min_depth = self._get_min_depth_interaction(txn, room_id)
|
||||
|
||||
do_insert = depth < min_depth if min_depth else True
|
||||
if min_depth and depth >= min_depth:
|
||||
return
|
||||
|
||||
if do_insert:
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="room_depth",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
values={
|
||||
"min_depth": depth,
|
||||
},
|
||||
)
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="room_depth",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
values={
|
||||
"min_depth": depth,
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_mult_prev_events(self, txn, events):
|
||||
"""
|
||||
|
@ -334,8 +334,7 @@ class EventFederationStore(SQLBaseStore):
|
|||
|
||||
def get_forward_extremeties_for_room_txn(txn):
|
||||
txn.execute(sql, (stream_ordering, room_id))
|
||||
rows = txn.fetchall()
|
||||
return [event_id for event_id, in rows]
|
||||
return [event_id for event_id, in txn]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_forward_extremeties_for_room",
|
||||
|
@ -436,7 +435,7 @@ class EventFederationStore(SQLBaseStore):
|
|||
(room_id, event_id, False, limit - len(event_results))
|
||||
)
|
||||
|
||||
for row in txn.fetchall():
|
||||
for row in txn:
|
||||
if row[1] not in event_results:
|
||||
queue.put((-row[0], row[1]))
|
||||
|
||||
|
@ -482,7 +481,7 @@ class EventFederationStore(SQLBaseStore):
|
|||
(room_id, event_id, False, limit - len(event_results))
|
||||
)
|
||||
|
||||
for e_id, in txn.fetchall():
|
||||
for e_id, in txn:
|
||||
new_front.add(e_id)
|
||||
|
||||
new_front -= earliest_events
|
||||
|
|
|
@ -206,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
" stream_ordering >= ? AND stream_ordering <= ?"
|
||||
)
|
||||
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
|
||||
return [r[0] for r in txn.fetchall()]
|
||||
return [r[0] for r in txn]
|
||||
ret = yield self.runInteraction("get_push_action_users_in_range", f)
|
||||
defer.returnValue(ret)
|
||||
|
||||
|
|
|
@ -34,14 +34,16 @@ from canonicaljson import encode_canonical_json
|
|||
from collections import deque, namedtuple, OrderedDict
|
||||
from functools import wraps
|
||||
|
||||
import synapse
|
||||
import synapse.metrics
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
import ujson as json
|
||||
|
||||
# these are only included to make the type annotations work
|
||||
from synapse.events import EventBase # noqa: F401
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -82,6 +84,11 @@ class _EventPeristenceQueue(object):
|
|||
|
||||
def add_to_queue(self, room_id, events_and_contexts, backfilled):
|
||||
"""Add events to the queue, with the given persist_event options.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
events_and_contexts (list[(EventBase, EventContext)]):
|
||||
backfilled (bool):
|
||||
"""
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
if queue:
|
||||
|
@ -210,14 +217,14 @@ class EventsStore(SQLBaseStore):
|
|||
partitioned.setdefault(event.room_id, []).append((event, ctx))
|
||||
|
||||
deferreds = []
|
||||
for room_id, evs_ctxs in partitioned.items():
|
||||
for room_id, evs_ctxs in partitioned.iteritems():
|
||||
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
||||
room_id, evs_ctxs,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
deferreds.append(d)
|
||||
|
||||
for room_id in partitioned.keys():
|
||||
for room_id in partitioned:
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
return preserve_context_over_deferred(
|
||||
|
@ -227,6 +234,17 @@ class EventsStore(SQLBaseStore):
|
|||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, context, backfilled=False):
|
||||
"""
|
||||
|
||||
Args:
|
||||
event (EventBase):
|
||||
context (EventContext):
|
||||
backfilled (bool):
|
||||
|
||||
Returns:
|
||||
Deferred: resolves to (int, int): the stream ordering of ``event``,
|
||||
and the stream ordering of the latest persisted event
|
||||
"""
|
||||
deferred = self._event_persist_queue.add_to_queue(
|
||||
event.room_id, [(event, context)],
|
||||
backfilled=backfilled,
|
||||
|
@ -253,6 +271,16 @@ class EventsStore(SQLBaseStore):
|
|||
@defer.inlineCallbacks
|
||||
def _persist_events(self, events_and_contexts, backfilled=False,
|
||||
delete_existing=False):
|
||||
"""Persist events to db
|
||||
|
||||
Args:
|
||||
events_and_contexts (list[(EventBase, EventContext)]):
|
||||
backfilled (bool):
|
||||
delete_existing (bool):
|
||||
|
||||
Returns:
|
||||
Deferred: resolves when the events have been persisted
|
||||
"""
|
||||
if not events_and_contexts:
|
||||
return
|
||||
|
||||
|
@ -295,7 +323,7 @@ class EventsStore(SQLBaseStore):
|
|||
(event, context)
|
||||
)
|
||||
|
||||
for room_id, ev_ctx_rm in events_by_room.items():
|
||||
for room_id, ev_ctx_rm in events_by_room.iteritems():
|
||||
# Work out new extremities by recursively adding and removing
|
||||
# the new events.
|
||||
latest_event_ids = yield self.get_latest_event_ids_in_room(
|
||||
|
@ -400,6 +428,7 @@ class EventsStore(SQLBaseStore):
|
|||
# Now we need to work out the different state sets for
|
||||
# each state extremities
|
||||
state_sets = []
|
||||
state_groups = set()
|
||||
missing_event_ids = []
|
||||
was_updated = False
|
||||
for event_id in new_latest_event_ids:
|
||||
|
@ -409,9 +438,17 @@ class EventsStore(SQLBaseStore):
|
|||
if event_id == ev.event_id:
|
||||
if ctx.current_state_ids is None:
|
||||
raise Exception("Unknown current state")
|
||||
state_sets.append(ctx.current_state_ids)
|
||||
if ctx.delta_ids or hasattr(ev, "state_key"):
|
||||
was_updated = True
|
||||
|
||||
# If we've already seen the state group don't bother adding
|
||||
# it to the state sets again
|
||||
if ctx.state_group not in state_groups:
|
||||
state_sets.append(ctx.current_state_ids)
|
||||
if ctx.delta_ids or hasattr(ev, "state_key"):
|
||||
was_updated = True
|
||||
if ctx.state_group:
|
||||
# Add this as a seen state group (if it has a state
|
||||
# group)
|
||||
state_groups.add(ctx.state_group)
|
||||
break
|
||||
else:
|
||||
# If we couldn't find it, then we'll need to pull
|
||||
|
@ -425,31 +462,57 @@ class EventsStore(SQLBaseStore):
|
|||
missing_event_ids,
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self._get_state_for_groups(groups)
|
||||
groups = set(event_to_groups.itervalues()) - state_groups
|
||||
|
||||
state_sets.extend(group_to_state.values())
|
||||
if groups:
|
||||
group_to_state = yield self._get_state_for_groups(groups)
|
||||
state_sets.extend(group_to_state.itervalues())
|
||||
|
||||
if not new_latest_event_ids:
|
||||
current_state = {}
|
||||
elif was_updated:
|
||||
current_state = yield resolve_events(
|
||||
state_sets,
|
||||
state_map_factory=lambda ev_ids: self.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
),
|
||||
)
|
||||
if len(state_sets) == 1:
|
||||
# If there is only one state set, then we know what the current
|
||||
# state is.
|
||||
current_state = state_sets[0]
|
||||
else:
|
||||
# We work out the current state by passing the state sets to the
|
||||
# state resolution algorithm. It may ask for some events, including
|
||||
# the events we have yet to persist, so we need a slightly more
|
||||
# complicated event lookup function than simply looking the events
|
||||
# up in the db.
|
||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events(ev_ids):
|
||||
# We get the events by first looking at the list of events we
|
||||
# are trying to persist, and then fetching the rest from the DB.
|
||||
db = []
|
||||
to_return = {}
|
||||
for ev_id in ev_ids:
|
||||
ev = events_map.get(ev_id, None)
|
||||
if ev:
|
||||
to_return[ev_id] = ev
|
||||
else:
|
||||
db.append(ev_id)
|
||||
|
||||
if db:
|
||||
evs = yield self.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
)
|
||||
to_return.update(evs)
|
||||
defer.returnValue(to_return)
|
||||
|
||||
current_state = yield resolve_events(
|
||||
state_sets,
|
||||
state_map_factory=get_events,
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
existing_state_rows = yield self._simple_select_list(
|
||||
table="current_state_events",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=["event_id", "type", "state_key"],
|
||||
desc="_calculate_state_delta",
|
||||
)
|
||||
existing_state = yield self.get_current_state_ids(room_id)
|
||||
|
||||
existing_events = set(row["event_id"] for row in existing_state_rows)
|
||||
existing_events = set(existing_state.itervalues())
|
||||
new_events = set(ev_id for ev_id in current_state.itervalues())
|
||||
changed_events = existing_events ^ new_events
|
||||
|
||||
|
@ -457,9 +520,8 @@ class EventsStore(SQLBaseStore):
|
|||
return
|
||||
|
||||
to_delete = {
|
||||
(row["type"], row["state_key"]): row["event_id"]
|
||||
for row in existing_state_rows
|
||||
if row["event_id"] in changed_events
|
||||
key: ev_id for key, ev_id in existing_state.iteritems()
|
||||
if ev_id in changed_events
|
||||
}
|
||||
events_to_insert = (new_events - existing_events)
|
||||
to_insert = {
|
||||
|
@ -535,11 +597,91 @@ class EventsStore(SQLBaseStore):
|
|||
and the rejections table. Things reading from those table will need to check
|
||||
whether the event was rejected.
|
||||
|
||||
If delete_existing is True then existing events will be purged from the
|
||||
database before insertion. This is useful when retrying due to IntegrityError.
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]):
|
||||
events to persist
|
||||
backfilled (bool): True if the events were backfilled
|
||||
delete_existing (bool): True to purge existing table rows for the
|
||||
events from the database. This is useful when retrying due to
|
||||
IntegrityError.
|
||||
current_state_for_room (dict[str, (list[str], list[str])]):
|
||||
The current-state delta for each room. For each room, a tuple
|
||||
(to_delete, to_insert), being a list of event ids to be removed
|
||||
from the current state, and a list of event ids to be added to
|
||||
the current state.
|
||||
new_forward_extremeties (dict[str, list[str]]):
|
||||
The new forward extremities for each room. For each room, a
|
||||
list of the event ids which are the forward extremities.
|
||||
|
||||
"""
|
||||
self._update_current_state_txn(txn, current_state_for_room)
|
||||
|
||||
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||
for room_id, current_state_tuple in current_state_for_room.iteritems():
|
||||
self._update_forward_extremities_txn(
|
||||
txn,
|
||||
new_forward_extremities=new_forward_extremeties,
|
||||
max_stream_order=max_stream_order,
|
||||
)
|
||||
|
||||
# Ensure that we don't have the same event twice.
|
||||
events_and_contexts = self._filter_events_and_contexts_for_duplicates(
|
||||
events_and_contexts,
|
||||
)
|
||||
|
||||
self._update_room_depths_txn(
|
||||
txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
# _update_outliers_txn filters out any events which have already been
|
||||
# persisted, and returns the filtered list.
|
||||
events_and_contexts = self._update_outliers_txn(
|
||||
txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
)
|
||||
|
||||
# From this point onwards the events are only events that we haven't
|
||||
# seen before.
|
||||
|
||||
if delete_existing:
|
||||
# For paranoia reasons, we go and delete all the existing entries
|
||||
# for these events so we can reinsert them.
|
||||
# This gets around any problems with some tables already having
|
||||
# entries.
|
||||
self._delete_existing_rows_txn(
|
||||
txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
)
|
||||
|
||||
self._store_event_txn(
|
||||
txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
)
|
||||
|
||||
# Insert into the state_groups, state_groups_state, and
|
||||
# event_to_state_groups tables.
|
||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||
|
||||
# _store_rejected_events_txn filters out any events which were
|
||||
# rejected, and returns the filtered list.
|
||||
events_and_contexts = self._store_rejected_events_txn(
|
||||
txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
)
|
||||
|
||||
# From this point onwards the events are only ones that weren't
|
||||
# rejected.
|
||||
|
||||
self._update_metadata_tables_txn(
|
||||
txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
def _update_current_state_txn(self, txn, state_delta_by_room):
|
||||
for room_id, current_state_tuple in state_delta_by_room.iteritems():
|
||||
to_delete, to_insert = current_state_tuple
|
||||
txn.executemany(
|
||||
"DELETE FROM current_state_events WHERE event_id = ?",
|
||||
|
@ -585,7 +727,13 @@ class EventsStore(SQLBaseStore):
|
|||
txn, self.get_users_in_room, (room_id,)
|
||||
)
|
||||
|
||||
for room_id, new_extrem in new_forward_extremeties.items():
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_current_state_ids, (room_id,)
|
||||
)
|
||||
|
||||
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
|
||||
max_stream_order):
|
||||
for room_id, new_extrem in new_forward_extremities.iteritems():
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="event_forward_extremities",
|
||||
|
@ -603,7 +751,7 @@ class EventsStore(SQLBaseStore):
|
|||
"event_id": ev_id,
|
||||
"room_id": room_id,
|
||||
}
|
||||
for room_id, new_extrem in new_forward_extremeties.items()
|
||||
for room_id, new_extrem in new_forward_extremities.iteritems()
|
||||
for ev_id in new_extrem
|
||||
],
|
||||
)
|
||||
|
@ -620,13 +768,22 @@ class EventsStore(SQLBaseStore):
|
|||
"event_id": event_id,
|
||||
"stream_ordering": max_stream_order,
|
||||
}
|
||||
for room_id, new_extrem in new_forward_extremeties.items()
|
||||
for room_id, new_extrem in new_forward_extremities.iteritems()
|
||||
for event_id in new_extrem
|
||||
]
|
||||
)
|
||||
|
||||
# Ensure that we don't have the same event twice.
|
||||
# Pick the earliest non-outlier if there is one, else the earliest one.
|
||||
@classmethod
|
||||
def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
|
||||
"""Ensure that we don't have the same event twice.
|
||||
|
||||
Pick the earliest non-outlier if there is one, else the earliest one.
|
||||
|
||||
Args:
|
||||
events_and_contexts (list[(EventBase, EventContext)]):
|
||||
Returns:
|
||||
list[(EventBase, EventContext)]: filtered list
|
||||
"""
|
||||
new_events_and_contexts = OrderedDict()
|
||||
for event, context in events_and_contexts:
|
||||
prev_event_context = new_events_and_contexts.get(event.event_id)
|
||||
|
@ -639,9 +796,17 @@ class EventsStore(SQLBaseStore):
|
|||
new_events_and_contexts[event.event_id] = (event, context)
|
||||
else:
|
||||
new_events_and_contexts[event.event_id] = (event, context)
|
||||
return new_events_and_contexts.values()
|
||||
|
||||
events_and_contexts = new_events_and_contexts.values()
|
||||
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
|
||||
"""Update min_depth for each room
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
backfilled (bool): True if the events were backfilled
|
||||
"""
|
||||
depth_updates = {}
|
||||
for event, context in events_and_contexts:
|
||||
# Remove the any existing cache entries for the event_ids
|
||||
|
@ -657,9 +822,24 @@ class EventsStore(SQLBaseStore):
|
|||
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||
)
|
||||
|
||||
for room_id, depth in depth_updates.items():
|
||||
for room_id, depth in depth_updates.iteritems():
|
||||
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||
|
||||
def _update_outliers_txn(self, txn, events_and_contexts):
|
||||
"""Update any outliers with new event info.
|
||||
|
||||
This turns outliers into ex-outliers (unless the new event was
|
||||
rejected).
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
|
||||
Returns:
|
||||
list[(EventBase, EventContext)] new list, without events which
|
||||
are already in the events table.
|
||||
"""
|
||||
txn.execute(
|
||||
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
|
||||
",".join(["?"] * len(events_and_contexts)),
|
||||
|
@ -669,24 +849,21 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
have_persisted = {
|
||||
event_id: outlier
|
||||
for event_id, outlier in txn.fetchall()
|
||||
for event_id, outlier in txn
|
||||
}
|
||||
|
||||
to_remove = set()
|
||||
for event, context in events_and_contexts:
|
||||
if context.rejected:
|
||||
# If the event is rejected then we don't care if the event
|
||||
# was an outlier or not.
|
||||
if event.event_id in have_persisted:
|
||||
# If we have already seen the event then ignore it.
|
||||
to_remove.add(event)
|
||||
continue
|
||||
|
||||
if event.event_id not in have_persisted:
|
||||
continue
|
||||
|
||||
to_remove.add(event)
|
||||
|
||||
if context.rejected:
|
||||
# If the event is rejected then we don't care if the event
|
||||
# was an outlier or not.
|
||||
continue
|
||||
|
||||
outlier_persisted = have_persisted[event.event_id]
|
||||
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
||||
# We received a copy of an event that we had already stored as
|
||||
|
@ -741,37 +918,19 @@ class EventsStore(SQLBaseStore):
|
|||
# event isn't an outlier any more.
|
||||
self._update_backward_extremeties(txn, [event])
|
||||
|
||||
events_and_contexts = [
|
||||
return [
|
||||
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _delete_existing_rows_txn(cls, txn, events_and_contexts):
|
||||
if not events_and_contexts:
|
||||
# Make sure we don't pass an empty list to functions that expect to
|
||||
# be storing at least one element.
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
# From this point onwards the events are only events that we haven't
|
||||
# seen before.
|
||||
logger.info("Deleting existing")
|
||||
|
||||
def event_dict(event):
|
||||
return {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
]
|
||||
}
|
||||
|
||||
if delete_existing:
|
||||
# For paranoia reasons, we go and delete all the existing entries
|
||||
# for these events so we can reinsert them.
|
||||
# This gets around any problems with some tables already having
|
||||
# entries.
|
||||
|
||||
logger.info("Deleting existing")
|
||||
|
||||
for table in (
|
||||
for table in (
|
||||
"events",
|
||||
"event_auth",
|
||||
"event_json",
|
||||
|
@ -794,11 +953,30 @@ class EventsStore(SQLBaseStore):
|
|||
"redactions",
|
||||
"room_memberships",
|
||||
"topics"
|
||||
):
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||
[(ev.event_id,) for ev, _ in events_and_contexts]
|
||||
)
|
||||
):
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||
[(ev.event_id,) for ev, _ in events_and_contexts]
|
||||
)
|
||||
|
||||
def _store_event_txn(self, txn, events_and_contexts):
|
||||
"""Insert new events into the event and event_json tables
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
"""
|
||||
|
||||
if not events_and_contexts:
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
def event_dict(event):
|
||||
d = event.get_dict()
|
||||
d.pop("redacted", None)
|
||||
d.pop("redacted_because", None)
|
||||
return d
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
|
@ -842,6 +1020,19 @@ class EventsStore(SQLBaseStore):
|
|||
],
|
||||
)
|
||||
|
||||
def _store_rejected_events_txn(self, txn, events_and_contexts):
|
||||
"""Add rows to the 'rejections' table for received events which were
|
||||
rejected
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
|
||||
Returns:
|
||||
list[(EventBase, EventContext)] new list, without the rejected
|
||||
events.
|
||||
"""
|
||||
# Remove the rejected events from the list now that we've added them
|
||||
# to the events table and the events_json table.
|
||||
to_remove = set()
|
||||
|
@ -853,16 +1044,23 @@ class EventsStore(SQLBaseStore):
|
|||
)
|
||||
to_remove.add(event)
|
||||
|
||||
events_and_contexts = [
|
||||
return [
|
||||
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||
]
|
||||
|
||||
if not events_and_contexts:
|
||||
# Make sure we don't pass an empty list to functions that expect to
|
||||
# be storing at least one element.
|
||||
return
|
||||
def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
|
||||
"""Update all the miscellaneous tables for new events
|
||||
|
||||
# From this point onwards the events are only ones that weren't rejected.
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
backfilled (bool): True if the events were backfilled
|
||||
"""
|
||||
|
||||
if not events_and_contexts:
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
for event, context in events_and_contexts:
|
||||
# Insert all the push actions into the event_push_actions table.
|
||||
|
@ -892,10 +1090,6 @@ class EventsStore(SQLBaseStore):
|
|||
],
|
||||
)
|
||||
|
||||
# Insert into the state_groups, state_groups_state, and
|
||||
# event_to_state_groups tables.
|
||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||
|
||||
# Update the event_forward_extremities, event_backward_extremities and
|
||||
# event_edges tables.
|
||||
self._handle_mult_prev_events(
|
||||
|
@ -982,13 +1176,6 @@ class EventsStore(SQLBaseStore):
|
|||
# Prefill the event cache
|
||||
self._add_to_cache(txn, events_and_contexts)
|
||||
|
||||
if backfilled:
|
||||
# Backfilled events come before the current state so we don't need
|
||||
# to update the current state table
|
||||
return
|
||||
|
||||
return
|
||||
|
||||
def _add_to_cache(self, txn, events_and_contexts):
|
||||
to_prefill = []
|
||||
|
||||
|
@ -1597,14 +1784,13 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
def get_all_new_events_txn(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
|
||||
" FROM events as e"
|
||||
" JOIN event_json as ej"
|
||||
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
|
||||
" LEFT JOIN event_to_state_groups as eg"
|
||||
" ON e.event_id = eg.event_id"
|
||||
" WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if have_forward_events:
|
||||
|
@ -1630,15 +1816,13 @@ class EventsStore(SQLBaseStore):
|
|||
forward_ex_outliers = []
|
||||
|
||||
sql = (
|
||||
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
|
||||
" eg.state_group"
|
||||
" FROM events as e"
|
||||
" JOIN event_json as ej"
|
||||
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
|
||||
" LEFT JOIN event_to_state_groups as eg"
|
||||
" ON e.event_id = eg.event_id"
|
||||
" WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
|
||||
" ORDER BY e.stream_ordering DESC"
|
||||
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" WHERE ? > stream_ordering AND stream_ordering >= ?"
|
||||
" ORDER BY stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if have_backfill_events:
|
||||
|
@ -1825,7 +2009,7 @@ class EventsStore(SQLBaseStore):
|
|||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in curr_state.items()
|
||||
for key, state_id in curr_state.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
|
|||
key_ids
|
||||
Args:
|
||||
server_name (str): The name of the server.
|
||||
key_ids (list of str): List of key_ids to try and look up.
|
||||
key_ids (iterable[str]): key_ids to try and look up.
|
||||
Returns:
|
||||
(list of VerifyKey): The verification keys.
|
||||
Deferred: resolves to dict[str, VerifyKey]: map from
|
||||
key_id to verification key.
|
||||
"""
|
||||
keys = {}
|
||||
for key_id in key_ids:
|
||||
|
|
|
@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine):
|
|||
),
|
||||
(current_version,)
|
||||
)
|
||||
applied_deltas = [d for d, in txn.fetchall()]
|
||||
applied_deltas = [d for d, in txn]
|
||||
return current_version, applied_deltas, upgraded
|
||||
|
||||
return None
|
||||
|
|
|
@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
|
|||
self.presence_stream_cache.entity_has_changed,
|
||||
state.user_id, stream_id,
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self._get_presence_for_user, (state.user_id,)
|
||||
txn.call_after(
|
||||
self._get_presence_for_user.invalidate, (state.user_id,)
|
||||
)
|
||||
|
||||
# Actually insert new rows
|
||||
|
|
|
@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||
results = txn.fetchall()
|
||||
|
||||
if results and topological_ordering:
|
||||
for to, so, _ in results:
|
||||
if topological_ordering:
|
||||
for to, so, _ in txn:
|
||||
if int(to) > topological_ordering:
|
||||
return False
|
||||
elif int(to) == topological_ordering and int(so) >= stream_ordering:
|
||||
|
|
|
@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
" WHERE lower(name) = lower(?)"
|
||||
)
|
||||
txn.execute(sql, (user_id,))
|
||||
return dict(txn.fetchall())
|
||||
return dict(txn)
|
||||
|
||||
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||
|
||||
|
|
|
@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore):
|
|||
sql % ("AND appservice_id IS NULL",),
|
||||
(stream_id,)
|
||||
)
|
||||
return dict(txn.fetchall())
|
||||
return dict(txn)
|
||||
else:
|
||||
# We want to get from all lists, so we need to aggregate the results
|
||||
|
||||
|
@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore):
|
|||
|
||||
results = {}
|
||||
# A room is visible if its visible on any list.
|
||||
for room_id, visibility in txn.fetchall():
|
||||
for room_id, visibility in txn:
|
||||
results[room_id] = bool(visibility) or results.get(room_id, False)
|
||||
|
||||
return results
|
||||
|
|
|
@ -129,17 +129,30 @@ class RoomMemberStore(SQLBaseStore):
|
|||
with self._stream_id_gen.get_next() as stream_ordering:
|
||||
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
|
||||
def get_hosts_in_room(self, room_id, cache_context):
|
||||
"""Returns the set of all hosts currently in the room
|
||||
"""
|
||||
user_ids = yield self.get_users_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
|
||||
defer.returnValue(hosts)
|
||||
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
def get_users_in_room(self, room_id):
|
||||
def f(txn):
|
||||
|
||||
rows = self._get_members_rows_txn(
|
||||
txn,
|
||||
room_id=room_id,
|
||||
membership=Membership.JOIN,
|
||||
sql = (
|
||||
"SELECT m.user_id FROM room_memberships as m"
|
||||
" INNER JOIN current_state_events as c"
|
||||
" ON m.event_id = c.event_id "
|
||||
" AND m.room_id = c.room_id "
|
||||
" AND m.user_id = c.state_key"
|
||||
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
|
||||
)
|
||||
|
||||
return [r["user_id"] for r in rows]
|
||||
txn.execute(sql, (room_id, Membership.JOIN,))
|
||||
return [r[0] for r in txn]
|
||||
return self.runInteraction("get_users_in_room", f)
|
||||
|
||||
@cached()
|
||||
|
@ -246,52 +259,27 @@ class RoomMemberStore(SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
||||
where_clause = "c.room_id = ?"
|
||||
where_values = [room_id]
|
||||
|
||||
if membership:
|
||||
where_clause += " AND m.membership = ?"
|
||||
where_values.append(membership)
|
||||
|
||||
if user_id:
|
||||
where_clause += " AND m.user_id = ?"
|
||||
where_values.append(user_id)
|
||||
|
||||
sql = (
|
||||
"SELECT m.* FROM room_memberships as m"
|
||||
" INNER JOIN current_state_events as c"
|
||||
" ON m.event_id = c.event_id "
|
||||
" AND m.room_id = c.room_id "
|
||||
" AND m.user_id = c.state_key"
|
||||
" WHERE c.type = 'm.room.member' AND %(where)s"
|
||||
) % {
|
||||
"where": where_clause,
|
||||
}
|
||||
|
||||
txn.execute(sql, where_values)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
return rows
|
||||
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
@cachedInlineCallbacks(max_entries=500000, iterable=True)
|
||||
def get_rooms_for_user(self, user_id):
|
||||
return self.get_rooms_for_user_where_membership_is(
|
||||
"""Returns a set of room_ids the user is currently joined to
|
||||
"""
|
||||
rooms = yield self.get_rooms_for_user_where_membership_is(
|
||||
user_id, membership_list=[Membership.JOIN],
|
||||
)
|
||||
defer.returnValue(frozenset(r.room_id for r in rooms))
|
||||
|
||||
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
||||
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
||||
"""Returns the set of users who share a room with `user_id`
|
||||
"""
|
||||
rooms = yield self.get_rooms_for_user(
|
||||
room_ids = yield self.get_rooms_for_user(
|
||||
user_id, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
|
||||
user_who_share_room = set()
|
||||
for room in rooms:
|
||||
for room_id in room_ids:
|
||||
user_ids = yield self.get_users_in_room(
|
||||
room.room_id, on_invalidate=cache_context.invalidate,
|
||||
room_id, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
user_who_share_room.update(user_ids)
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore):
|
|||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(query, (event_id, ))
|
||||
return {k: v for k, v in txn.fetchall()}
|
||||
return {k: v for k, v in txn}
|
||||
|
||||
def _store_event_reference_hashes_txn(self, txn, events):
|
||||
"""Store a hash for a PDU
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
|
@ -69,6 +69,18 @@ class StateStore(SQLBaseStore):
|
|||
where_clause="type='m.room.member'",
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=100000, iterable=True)
|
||||
def get_current_state_ids(self, room_id):
|
||||
rows = yield self._simple_select_list(
|
||||
table="current_state_events",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=["event_id", "type", "state_key"],
|
||||
desc="_calculate_state_delta",
|
||||
)
|
||||
defer.returnValue({
|
||||
(r["type"], r["state_key"]): r["event_id"] for r in rows
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups_ids(self, room_id, event_ids):
|
||||
if not event_ids:
|
||||
|
@ -78,7 +90,7 @@ class StateStore(SQLBaseStore):
|
|||
event_ids,
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
groups = set(event_to_groups.itervalues())
|
||||
group_to_state = yield self._get_state_for_groups(groups)
|
||||
|
||||
defer.returnValue(group_to_state)
|
||||
|
@ -96,17 +108,18 @@ class StateStore(SQLBaseStore):
|
|||
|
||||
state_event_map = yield self.get_events(
|
||||
[
|
||||
ev_id for group_ids in group_to_ids.values()
|
||||
for ev_id in group_ids.values()
|
||||
ev_id for group_ids in group_to_ids.itervalues()
|
||||
for ev_id in group_ids.itervalues()
|
||||
],
|
||||
get_prev_content=False
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
group: [
|
||||
state_event_map[v] for v in event_id_map.values() if v in state_event_map
|
||||
state_event_map[v] for v in event_id_map.itervalues()
|
||||
if v in state_event_map
|
||||
]
|
||||
for group, event_id_map in group_to_ids.items()
|
||||
for group, event_id_map in group_to_ids.iteritems()
|
||||
})
|
||||
|
||||
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||
|
@ -124,6 +137,16 @@ class StateStore(SQLBaseStore):
|
|||
continue
|
||||
|
||||
if context.current_state_ids is None:
|
||||
# AFAIK, this can never happen
|
||||
logger.error(
|
||||
"Non-outlier event %s had current_state_ids==None",
|
||||
event.event_id)
|
||||
continue
|
||||
|
||||
# if the event was rejected, just give it the same state as its
|
||||
# predecessor.
|
||||
if context.rejected:
|
||||
state_groups[event.event_id] = context.prev_group
|
||||
continue
|
||||
|
||||
state_groups[event.event_id] = context.state_group
|
||||
|
@ -168,7 +191,7 @@ class StateStore(SQLBaseStore):
|
|||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.delta_ids.items()
|
||||
for key, state_id in context.delta_ids.iteritems()
|
||||
],
|
||||
)
|
||||
else:
|
||||
|
@ -183,7 +206,7 @@ class StateStore(SQLBaseStore):
|
|||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.current_state_ids.items()
|
||||
for key, state_id in context.current_state_ids.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -195,7 +218,7 @@ class StateStore(SQLBaseStore):
|
|||
"state_group": state_group_id,
|
||||
"event_id": event_id,
|
||||
}
|
||||
for event_id, state_group_id in state_groups.items()
|
||||
for event_id, state_group_id in state_groups.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -319,10 +342,10 @@ class StateStore(SQLBaseStore):
|
|||
args.extend(where_args)
|
||||
|
||||
txn.execute(sql % (where_clause,), args)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
for row in rows:
|
||||
key = (row["type"], row["state_key"])
|
||||
results[group][key] = row["event_id"]
|
||||
for row in txn:
|
||||
typ, state_key, event_id = row
|
||||
key = (typ, state_key)
|
||||
results[group][key] = event_id
|
||||
else:
|
||||
if types is not None:
|
||||
where_clause = "AND (%s)" % (
|
||||
|
@ -351,12 +374,11 @@ class StateStore(SQLBaseStore):
|
|||
" WHERE state_group = ? %s" % (where_clause,),
|
||||
args
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
results[group].update({
|
||||
(typ, state_key): event_id
|
||||
for typ, state_key, event_id in rows
|
||||
results[group].update(
|
||||
((typ, state_key), event_id)
|
||||
for typ, state_key, event_id in txn
|
||||
if (typ, state_key) not in results[group]
|
||||
})
|
||||
)
|
||||
|
||||
# If the lengths match then we must have all the types,
|
||||
# so no need to go walk further down the tree.
|
||||
|
@ -393,21 +415,21 @@ class StateStore(SQLBaseStore):
|
|||
event_ids,
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
groups = set(event_to_groups.itervalues())
|
||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||
|
||||
state_event_map = yield self.get_events(
|
||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||
[ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
|
||||
get_prev_content=False
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: {
|
||||
k: state_event_map[v]
|
||||
for k, v in group_to_state[group].items()
|
||||
for k, v in group_to_state[group].iteritems()
|
||||
if v in state_event_map
|
||||
}
|
||||
for event_id, group in event_to_groups.items()
|
||||
for event_id, group in event_to_groups.iteritems()
|
||||
}
|
||||
|
||||
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||
|
@ -430,12 +452,12 @@ class StateStore(SQLBaseStore):
|
|||
event_ids,
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
groups = set(event_to_groups.itervalues())
|
||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||
|
||||
event_to_state = {
|
||||
event_id: group_to_state[group]
|
||||
for event_id, group in event_to_groups.items()
|
||||
for event_id, group in event_to_groups.iteritems()
|
||||
}
|
||||
|
||||
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||
|
@ -474,7 +496,7 @@ class StateStore(SQLBaseStore):
|
|||
state_map = yield self.get_state_ids_for_events([event_id], types)
|
||||
defer.returnValue(state_map[event_id])
|
||||
|
||||
@cached(num_args=2, max_entries=10000)
|
||||
@cached(num_args=2, max_entries=100000)
|
||||
def _get_state_group_for_event(self, room_id, event_id):
|
||||
return self._simple_select_one_onecol(
|
||||
table="event_to_state_groups",
|
||||
|
@ -547,7 +569,7 @@ class StateStore(SQLBaseStore):
|
|||
got_all = not (missing_types or types is None)
|
||||
|
||||
return {
|
||||
k: v for k, v in state_dict_ids.items()
|
||||
k: v for k, v in state_dict_ids.iteritems()
|
||||
if include(k[0], k[1])
|
||||
}, missing_types, got_all
|
||||
|
||||
|
@ -606,7 +628,7 @@ class StateStore(SQLBaseStore):
|
|||
|
||||
# Now we want to update the cache with all the things we fetched
|
||||
# from the database.
|
||||
for group, group_state_dict in group_to_state_dict.items():
|
||||
for group, group_state_dict in group_to_state_dict.iteritems():
|
||||
if types:
|
||||
# We delibrately put key -> None mappings into the cache to
|
||||
# cache absence of the key, on the assumption that if we've
|
||||
|
@ -621,10 +643,10 @@ class StateStore(SQLBaseStore):
|
|||
else:
|
||||
state_dict = results[group]
|
||||
|
||||
state_dict.update({
|
||||
(intern_string(k[0]), intern_string(k[1])): v
|
||||
for k, v in group_state_dict.items()
|
||||
})
|
||||
state_dict.update(
|
||||
((intern_string(k[0]), intern_string(k[1])), v)
|
||||
for k, v in group_state_dict.iteritems()
|
||||
)
|
||||
|
||||
self._state_group_cache.update(
|
||||
cache_seq_num,
|
||||
|
@ -635,10 +657,10 @@ class StateStore(SQLBaseStore):
|
|||
|
||||
# Remove all the entries with None values. The None values were just
|
||||
# used for bookkeeping in the cache.
|
||||
for group, state_dict in results.items():
|
||||
for group, state_dict in results.iteritems():
|
||||
results[group] = {
|
||||
key: event_id
|
||||
for key, event_id in state_dict.items()
|
||||
for key, event_id in state_dict.iteritems()
|
||||
if event_id
|
||||
}
|
||||
|
||||
|
@ -727,7 +749,7 @@ class StateStore(SQLBaseStore):
|
|||
# of keys
|
||||
|
||||
delta_state = {
|
||||
key: value for key, value in curr_state.items()
|
||||
key: value for key, value in curr_state.iteritems()
|
||||
if prev_state.get(key, None) != value
|
||||
}
|
||||
|
||||
|
@ -767,7 +789,7 @@ class StateStore(SQLBaseStore):
|
|||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in delta_state.items()
|
||||
for key, state_id in delta_state.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -829,3 +829,6 @@ class StreamStore(SQLBaseStore):
|
|||
updatevalues={"stream_id": stream_id},
|
||||
desc="update_federation_out_pos",
|
||||
)
|
||||
|
||||
def has_room_changed_since(self, room_id, stream_id):
|
||||
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
|
||||
|
|
|
@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore):
|
|||
for stream_id, user_id, room_id in tag_ids:
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
tags = []
|
||||
for tag, content in txn.fetchall():
|
||||
for tag, content in txn:
|
||||
tags.append(json.dumps(tag) + ":" + content)
|
||||
tag_json = "{" + ",".join(tags) + "}"
|
||||
results.append((stream_id, user_id, room_id, tag_json))
|
||||
|
@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore):
|
|||
" WHERE user_id = ? AND stream_id > ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, stream_id))
|
||||
room_ids = [row[0] for row in txn.fetchall()]
|
||||
room_ids = [row[0] for row in txn]
|
||||
return room_ids
|
||||
|
||||
changed = self._account_data_stream_cache.has_entity_changed(
|
||||
|
|
|
@ -30,6 +30,17 @@ class IdGenerator(object):
|
|||
|
||||
|
||||
def _load_current_id(db_conn, table, column, step=1):
|
||||
"""
|
||||
|
||||
Args:
|
||||
db_conn (object):
|
||||
table (str):
|
||||
column (str):
|
||||
step (int):
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
cur = db_conn.cursor()
|
||||
if step == 1:
|
||||
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
|
||||
|
@ -131,6 +142,9 @@ class StreamIdGenerator(object):
|
|||
def get_current_token(self):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
|
|
|
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class DeferredTimedOutError(SynapseError):
|
||||
def __init__(self):
|
||||
super(SynapseError).__init__(504, "Timed out")
|
||||
super(SynapseError, self).__init__(504, "Timed out")
|
||||
|
||||
|
||||
def unwrapFirstError(failure):
|
||||
|
@ -93,8 +93,10 @@ class Clock(object):
|
|||
ret_deferred = defer.Deferred()
|
||||
|
||||
def timed_out_fn():
|
||||
e = DeferredTimedOutError()
|
||||
|
||||
try:
|
||||
ret_deferred.errback(DeferredTimedOutError())
|
||||
ret_deferred.errback(e)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
@ -114,7 +116,7 @@ class Clock(object):
|
|||
|
||||
ret_deferred.addBoth(cancel)
|
||||
|
||||
def sucess(res):
|
||||
def success(res):
|
||||
try:
|
||||
ret_deferred.callback(res)
|
||||
except:
|
||||
|
@ -128,7 +130,7 @@ class Clock(object):
|
|||
except:
|
||||
pass
|
||||
|
||||
given_deferred.addCallbacks(callback=sucess, errback=err)
|
||||
given_deferred.addCallbacks(callback=success, errback=err)
|
||||
|
||||
timer = self.call_later(time_out, timed_out_fn)
|
||||
|
||||
|
|
|
@ -15,12 +15,9 @@
|
|||
import logging
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||
from synapse.util.logcontext import (
|
||||
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
|
||||
)
|
||||
|
||||
from . import DEBUG_CACHES, register_cache
|
||||
|
||||
|
@ -189,7 +186,55 @@ class Cache(object):
|
|||
self.cache.clear()
|
||||
|
||||
|
||||
class CacheDescriptor(object):
|
||||
class _CacheDescriptorBase(object):
|
||||
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
|
||||
self.orig = orig
|
||||
|
||||
if inlineCallbacks:
|
||||
self.function_to_call = defer.inlineCallbacks(orig)
|
||||
else:
|
||||
self.function_to_call = orig
|
||||
|
||||
arg_spec = inspect.getargspec(orig)
|
||||
all_args = arg_spec.args
|
||||
|
||||
if "cache_context" in all_args:
|
||||
if not cache_context:
|
||||
raise ValueError(
|
||||
"Cannot have a 'cache_context' arg without setting"
|
||||
" cache_context=True"
|
||||
)
|
||||
elif cache_context:
|
||||
raise ValueError(
|
||||
"Cannot have cache_context=True without having an arg"
|
||||
" named `cache_context`"
|
||||
)
|
||||
|
||||
if num_args is None:
|
||||
num_args = len(all_args) - 1
|
||||
if cache_context:
|
||||
num_args -= 1
|
||||
|
||||
if len(all_args) < num_args + 1:
|
||||
raise Exception(
|
||||
"Not enough explicit positional arguments to key off for %r: "
|
||||
"got %i args, but wanted %i. (@cached cannot key off *args or "
|
||||
"**kwargs)"
|
||||
% (orig.__name__, len(all_args), num_args)
|
||||
)
|
||||
|
||||
self.num_args = num_args
|
||||
self.arg_names = all_args[1:num_args + 1]
|
||||
|
||||
if "cache_context" in self.arg_names:
|
||||
raise Exception(
|
||||
"cache_context arg cannot be included among the cache keys"
|
||||
)
|
||||
|
||||
self.add_cache_context = cache_context
|
||||
|
||||
|
||||
class CacheDescriptor(_CacheDescriptorBase):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
This caches deferreds, rather than the results themselves. Deferreds that
|
||||
|
@ -217,52 +262,24 @@ class CacheDescriptor(object):
|
|||
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||
defer.returnValue(r1 + r2)
|
||||
|
||||
Args:
|
||||
num_args (int): number of positional arguments (excluding ``self`` and
|
||||
``cache_context``) to use as cache keys. Defaults to all named
|
||||
args of the function.
|
||||
"""
|
||||
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
|
||||
def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
|
||||
inlineCallbacks=False, cache_context=False, iterable=False):
|
||||
|
||||
super(CacheDescriptor, self).__init__(
|
||||
orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
|
||||
cache_context=cache_context)
|
||||
|
||||
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
|
||||
|
||||
self.orig = orig
|
||||
|
||||
if inlineCallbacks:
|
||||
self.function_to_call = defer.inlineCallbacks(orig)
|
||||
else:
|
||||
self.function_to_call = orig
|
||||
|
||||
self.max_entries = max_entries
|
||||
self.num_args = num_args
|
||||
self.tree = tree
|
||||
|
||||
self.iterable = iterable
|
||||
|
||||
all_args = inspect.getargspec(orig)
|
||||
self.arg_names = all_args.args[1:num_args + 1]
|
||||
|
||||
if "cache_context" in all_args.args:
|
||||
if not cache_context:
|
||||
raise ValueError(
|
||||
"Cannot have a 'cache_context' arg without setting"
|
||||
" cache_context=True"
|
||||
)
|
||||
try:
|
||||
self.arg_names.remove("cache_context")
|
||||
except ValueError:
|
||||
pass
|
||||
elif cache_context:
|
||||
raise ValueError(
|
||||
"Cannot have cache_context=True without having an arg"
|
||||
" named `cache_context`"
|
||||
)
|
||||
|
||||
self.add_cache_context = cache_context
|
||||
|
||||
if len(self.arg_names) < self.num_args:
|
||||
raise Exception(
|
||||
"Not enough explicit positional arguments to key off of for %r."
|
||||
" (@cached cannot key off of *args or **kwargs)"
|
||||
% (orig.__name__,)
|
||||
)
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
cache = Cache(
|
||||
name=self.orig.__name__,
|
||||
|
@ -308,11 +325,9 @@ class CacheDescriptor(object):
|
|||
defer.returnValue(cached_result)
|
||||
observer.addCallback(check_result)
|
||||
|
||||
return preserve_context_over_deferred(observer)
|
||||
except KeyError:
|
||||
ret = defer.maybeDeferred(
|
||||
preserve_context_over_fn,
|
||||
self.function_to_call,
|
||||
logcontext.preserve_fn(self.function_to_call),
|
||||
obj, *args, **kwargs
|
||||
)
|
||||
|
||||
|
@ -322,10 +337,11 @@ class CacheDescriptor(object):
|
|||
|
||||
ret.addErrback(onErr)
|
||||
|
||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
||||
cache.set(cache_key, ret, callback=invalidate_callback)
|
||||
result_d = ObservableDeferred(ret, consumeErrors=True)
|
||||
cache.set(cache_key, result_d, callback=invalidate_callback)
|
||||
observer = result_d.observe()
|
||||
|
||||
return preserve_context_over_deferred(ret.observe())
|
||||
return logcontext.make_deferred_yieldable(observer)
|
||||
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.invalidate_all = cache.invalidate_all
|
||||
|
@ -338,48 +354,40 @@ class CacheDescriptor(object):
|
|||
return wrapped
|
||||
|
||||
|
||||
class CacheListDescriptor(object):
|
||||
class CacheListDescriptor(_CacheDescriptorBase):
|
||||
"""Wraps an existing cache to support bulk fetching of keys.
|
||||
|
||||
Given a list of keys it looks in the cache to find any hits, then passes
|
||||
the list of missing keys to the wrapped fucntion.
|
||||
the list of missing keys to the wrapped function.
|
||||
|
||||
Once wrapped, the function returns either a Deferred which resolves to
|
||||
the list of results, or (if all results were cached), just the list of
|
||||
results.
|
||||
"""
|
||||
|
||||
def __init__(self, orig, cached_method_name, list_name, num_args=1,
|
||||
def __init__(self, orig, cached_method_name, list_name, num_args=None,
|
||||
inlineCallbacks=False):
|
||||
"""
|
||||
Args:
|
||||
orig (function)
|
||||
method_name (str); The name of the chached method.
|
||||
cached_method_name (str): The name of the chached method.
|
||||
list_name (str): Name of the argument which is the bulk lookup list
|
||||
num_args (int)
|
||||
num_args (int): number of positional arguments (excluding ``self``,
|
||||
but including list_name) to use as cache keys. Defaults to all
|
||||
named args of the function.
|
||||
inlineCallbacks (bool): Whether orig is a generator that should
|
||||
be wrapped by defer.inlineCallbacks
|
||||
"""
|
||||
self.orig = orig
|
||||
super(CacheListDescriptor, self).__init__(
|
||||
orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
|
||||
|
||||
if inlineCallbacks:
|
||||
self.function_to_call = defer.inlineCallbacks(orig)
|
||||
else:
|
||||
self.function_to_call = orig
|
||||
|
||||
self.num_args = num_args
|
||||
self.list_name = list_name
|
||||
|
||||
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
||||
self.list_pos = self.arg_names.index(self.list_name)
|
||||
|
||||
self.cached_method_name = cached_method_name
|
||||
|
||||
self.sentinel = object()
|
||||
|
||||
if len(self.arg_names) < self.num_args:
|
||||
raise Exception(
|
||||
"Not enough explicit positional arguments to key off of for %r."
|
||||
" (@cached cannot key off of *args or **kwars)"
|
||||
% (orig.__name__,)
|
||||
)
|
||||
|
||||
if self.list_name not in self.arg_names:
|
||||
raise Exception(
|
||||
"Couldn't see arguments %r for %r."
|
||||
|
@ -425,8 +433,7 @@ class CacheListDescriptor(object):
|
|||
args_to_call[self.list_name] = missing
|
||||
|
||||
ret_d = defer.maybeDeferred(
|
||||
preserve_context_over_fn,
|
||||
self.function_to_call,
|
||||
logcontext.preserve_fn(self.function_to_call),
|
||||
**args_to_call
|
||||
)
|
||||
|
||||
|
@ -435,8 +442,7 @@ class CacheListDescriptor(object):
|
|||
# We need to create deferreds for each arg in the list so that
|
||||
# we can insert the new deferred into the cache.
|
||||
for arg in missing:
|
||||
with PreserveLoggingContext():
|
||||
observer = ret_d.observe()
|
||||
observer = ret_d.observe()
|
||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
||||
|
||||
observer = ObservableDeferred(observer)
|
||||
|
@ -463,7 +469,7 @@ class CacheListDescriptor(object):
|
|||
results.update(res)
|
||||
return results
|
||||
|
||||
return preserve_context_over_deferred(defer.gatherResults(
|
||||
return logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
cached_defers.values(),
|
||||
consumeErrors=True,
|
||||
).addCallback(update_results_dict).addErrback(
|
||||
|
@ -487,7 +493,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
|
|||
self.cache.invalidate(self.key)
|
||||
|
||||
|
||||
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
||||
def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
|
||||
iterable=False):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
|
@ -499,8 +505,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
|||
)
|
||||
|
||||
|
||||
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
|
||||
iterable=False):
|
||||
def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
|
||||
cache_context=False, iterable=False):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
|
@ -512,7 +518,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
|
|||
)
|
||||
|
||||
|
||||
def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
|
||||
def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
|
||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||
|
||||
Used to do batch lookups for an already created cache. A single argument
|
||||
|
@ -525,7 +531,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
|
|||
cache (Cache): The underlying cache to use.
|
||||
list_name (str): The name of the argument that is the list to use to
|
||||
do batch lookups in the cache.
|
||||
num_args (int): Number of arguments to use as the key in the cache.
|
||||
num_args (int): Number of arguments to use as the key in the cache
|
||||
(including list_name). Defaults to all named parameters.
|
||||
inlineCallbacks (bool): Should the function be wrapped in an
|
||||
`defer.inlineCallbacks`?
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ class StreamChangeCache(object):
|
|||
def has_entity_changed(self, entity, stream_pos):
|
||||
"""Returns True if the entity may have been updated since stream_pos
|
||||
"""
|
||||
assert type(stream_pos) is int
|
||||
assert type(stream_pos) is int or type(stream_pos) is long
|
||||
|
||||
if stream_pos < self._earliest_known_stream_pos:
|
||||
self.metrics.inc_misses()
|
||||
|
|
|
@ -12,6 +12,16 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" Thread-local-alike tracking of log contexts within synapse
|
||||
|
||||
This module provides objects and utilities for tracking contexts through
|
||||
synapse code, so that log lines can include a request identifier, and so that
|
||||
CPU and database activity can be accounted for against the request that caused
|
||||
them.
|
||||
|
||||
See doc/log_contexts.rst for details on how this works.
|
||||
"""
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import threading
|
||||
|
@ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
|
|||
def preserve_context_over_deferred(deferred, context=None):
|
||||
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||
be invoked with the current context.
|
||||
|
||||
Deprecated: this almost certainly doesn't do want you want, ie make
|
||||
the deferred follow the synapse logcontext rules: try
|
||||
``make_deferred_yieldable`` instead.
|
||||
"""
|
||||
if context is None:
|
||||
context = LoggingContext.current_context()
|
||||
|
@ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None):
|
|||
|
||||
|
||||
def preserve_fn(f):
|
||||
"""Ensures that function is called with correct context and that context is
|
||||
restored after return. Useful for wrapping functions that return a deferred
|
||||
which you don't yield on.
|
||||
"""Wraps a function, to ensure that the current context is restored after
|
||||
return from the function, and that the sentinel context is set once the
|
||||
deferred returned by the funtion completes.
|
||||
|
||||
Useful for wrapping functions that return a deferred which you don't yield
|
||||
on.
|
||||
"""
|
||||
def reset_context(result):
|
||||
LoggingContext.set_current_context(LoggingContext.sentinel)
|
||||
return result
|
||||
|
||||
# XXX: why is this here rather than inside g? surely we want to preserve
|
||||
# the context from the time the function was called, not when it was
|
||||
# wrapped?
|
||||
current = LoggingContext.current_context()
|
||||
|
||||
def g(*args, **kwargs):
|
||||
with PreserveLoggingContext(current):
|
||||
res = f(*args, **kwargs)
|
||||
if isinstance(res, defer.Deferred):
|
||||
return preserve_context_over_deferred(
|
||||
res, context=LoggingContext.sentinel
|
||||
)
|
||||
else:
|
||||
return res
|
||||
res = f(*args, **kwargs)
|
||||
if isinstance(res, defer.Deferred) and not res.called:
|
||||
# The function will have reset the context before returning, so
|
||||
# we need to restore it now.
|
||||
LoggingContext.set_current_context(current)
|
||||
|
||||
# The original context will be restored when the deferred
|
||||
# completes, but there is nothing waiting for it, so it will
|
||||
# get leaked into the reactor or some other function which
|
||||
# wasn't expecting it. We therefore need to reset the context
|
||||
# here.
|
||||
#
|
||||
# (If this feels asymmetric, consider it this way: we are
|
||||
# effectively forking a new thread of execution. We are
|
||||
# probably currently within a ``with LoggingContext()`` block,
|
||||
# which is supposed to have a single entry and exit point. But
|
||||
# by spawning off another deferred, we are effectively
|
||||
# adding a new exit point.)
|
||||
res.addBoth(reset_context)
|
||||
return res
|
||||
return g
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def make_deferred_yieldable(deferred):
|
||||
"""Given a deferred, make it follow the Synapse logcontext rules:
|
||||
|
||||
If the deferred has completed (or is not actually a Deferred), essentially
|
||||
does nothing (just returns another completed deferred with the
|
||||
result/failure).
|
||||
|
||||
If the deferred has not yet completed, resets the logcontext before
|
||||
returning a deferred. Then, when the deferred completes, restores the
|
||||
current logcontext before running callbacks/errbacks.
|
||||
|
||||
(This is more-or-less the opposite operation to preserve_fn.)
|
||||
"""
|
||||
with PreserveLoggingContext():
|
||||
r = yield deferred
|
||||
defer.returnValue(r)
|
||||
|
||||
|
||||
# modules to ignore in `logcontext_tracer`
|
||||
_to_ignore = [
|
||||
"synapse.util.logcontext",
|
||||
|
|
40
synapse/util/msisdn.py
Normal file
40
synapse/util/msisdn.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import phonenumbers
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
|
||||
def phone_number_to_msisdn(country, number):
|
||||
"""
|
||||
Takes an ISO-3166-1 2 letter country code and phone number and
|
||||
returns an msisdn representing the canonical version of that
|
||||
phone number.
|
||||
Args:
|
||||
country (str): ISO-3166-1 2 letter country code
|
||||
number (str): Phone number in a national or international format
|
||||
|
||||
Returns:
|
||||
(str) The canonical form of the phone number, as an msisdn
|
||||
Raises:
|
||||
SynapseError if the number could not be parsed.
|
||||
"""
|
||||
try:
|
||||
phoneNumber = phonenumbers.parse(number, country)
|
||||
except phonenumbers.NumberParseException:
|
||||
raise SynapseError(400, "Unable to parse phone number")
|
||||
return phonenumbers.format_number(
|
||||
phoneNumber, phonenumbers.PhoneNumberFormat.E164
|
||||
)[1:]
|
|
@ -12,7 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse.util.logcontext
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import CodeMessageException
|
||||
|
@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
|
|||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_retry_limiter(destination, clock, store, **kwargs):
|
||||
def get_retry_limiter(destination, clock, store, ignore_backoff=False,
|
||||
**kwargs):
|
||||
"""For a given destination check if we have previously failed to
|
||||
send a request there and are waiting before retrying the destination.
|
||||
If we are not ready to retry the destination, this will raise a
|
||||
|
@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
|
|||
that will mark the destination as down if an exception is thrown (excluding
|
||||
CodeMessageException with code < 500)
|
||||
|
||||
Args:
|
||||
destination (str): name of homeserver
|
||||
clock (synapse.util.clock): timing source
|
||||
store (synapse.storage.transactions.TransactionStore): datastore
|
||||
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||
try the request anyway. We will still update the next
|
||||
retry_interval on success/failure.
|
||||
|
||||
Example usage:
|
||||
|
||||
try:
|
||||
|
@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
|
|||
|
||||
now = int(clock.time_msec())
|
||||
|
||||
if retry_last_ts + retry_interval > now:
|
||||
if not ignore_backoff and retry_last_ts + retry_interval > now:
|
||||
raise NotRetryingDestination(
|
||||
retry_last_ts=retry_last_ts,
|
||||
retry_interval=retry_interval,
|
||||
|
@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
|
|||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
valid_err_code = False
|
||||
if exc_type is not None and issubclass(exc_type, CodeMessageException):
|
||||
if exc_type is None:
|
||||
valid_err_code = True
|
||||
elif not issubclass(exc_type, Exception):
|
||||
# avoid treating exceptions which don't derive from Exception as
|
||||
# failures; this is mostly so as not to catch defer._DefGen.
|
||||
valid_err_code = True
|
||||
elif issubclass(exc_type, CodeMessageException):
|
||||
# Some error codes are perfectly fine for some APIs, whereas other
|
||||
# APIs may expect to never received e.g. a 404. It's important to
|
||||
# handle 404 as some remote servers will return a 404 when the HS
|
||||
|
@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
|
|||
else:
|
||||
valid_err_code = False
|
||||
|
||||
if exc_type is None or valid_err_code:
|
||||
if valid_err_code:
|
||||
# We connected successfully.
|
||||
if not self.retry_interval:
|
||||
return
|
||||
|
||||
logger.debug("Connection to %s was successful; clearing backoff",
|
||||
self.destination)
|
||||
retry_last_ts = 0
|
||||
self.retry_interval = 0
|
||||
else:
|
||||
|
@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
|
|||
else:
|
||||
self.retry_interval = self.min_retry_interval
|
||||
|
||||
logger.debug(
|
||||
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
|
||||
self.destination, exc_type, exc_val, self.retry_interval
|
||||
)
|
||||
retry_last_ts = int(self.clock.time_msec())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -173,4 +194,5 @@ class RetryDestinationLimiter(object):
|
|||
"Failed to store set_destination_retry_timings",
|
||||
)
|
||||
|
||||
store_retry_timings()
|
||||
# we deliberately do this in the background.
|
||||
synapse.util.logcontext.preserve_fn(store_retry_timings)()
|
||||
|
|
|
@ -134,6 +134,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
|||
if prev_membership not in MEMBERSHIP_PRIORITY:
|
||||
prev_membership = "leave"
|
||||
|
||||
# Always allow the user to see their own leave events, otherwise
|
||||
# they won't see the room disappear if they reject the invite
|
||||
if membership == "leave" and (
|
||||
prev_membership == "join" or prev_membership == "invite"
|
||||
):
|
||||
return True
|
||||
|
||||
new_priority = MEMBERSHIP_PRIORITY.index(membership)
|
||||
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
|
||||
if old_priority < new_priority:
|
||||
|
|
|
@ -23,6 +23,9 @@ from tests.utils import (
|
|||
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
import jsonschema
|
||||
|
||||
user_localpart = "test_user"
|
||||
|
||||
|
@ -54,6 +57,70 @@ class FilteringTestCase(unittest.TestCase):
|
|||
|
||||
self.datastore = hs.get_datastore()
|
||||
|
||||
def test_errors_on_invalid_filters(self):
|
||||
invalid_filters = [
|
||||
{"boom": {}},
|
||||
{"account_data": "Hello World"},
|
||||
{"event_fields": ["\\foo"]},
|
||||
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
|
||||
{"event_format": "other"},
|
||||
{"room": {"not_rooms": ["#foo:pik-test"]}},
|
||||
{"presence": {"senders": ["@bar;pik.test.com"]}}
|
||||
]
|
||||
for filter in invalid_filters:
|
||||
with self.assertRaises(SynapseError) as check_filter_error:
|
||||
self.filtering.check_valid_filter(filter)
|
||||
self.assertIsInstance(check_filter_error.exception, SynapseError)
|
||||
|
||||
def test_valid_filters(self):
|
||||
valid_filters = [
|
||||
{
|
||||
"room": {
|
||||
"timeline": {"limit": 20},
|
||||
"state": {"not_types": ["m.room.member"]},
|
||||
"ephemeral": {"limit": 0, "not_types": ["*"]},
|
||||
"include_leave": False,
|
||||
"rooms": ["!dee:pik-test"],
|
||||
"not_rooms": ["!gee:pik-test"],
|
||||
"account_data": {"limit": 0, "types": ["*"]}
|
||||
}
|
||||
},
|
||||
{
|
||||
"room": {
|
||||
"state": {
|
||||
"types": ["m.room.*"],
|
||||
"not_rooms": ["!726s6s6q:example.com"]
|
||||
},
|
||||
"timeline": {
|
||||
"limit": 10,
|
||||
"types": ["m.room.message"],
|
||||
"not_rooms": ["!726s6s6q:example.com"],
|
||||
"not_senders": ["@spam:example.com"]
|
||||
},
|
||||
"ephemeral": {
|
||||
"types": ["m.receipt", "m.typing"],
|
||||
"not_rooms": ["!726s6s6q:example.com"],
|
||||
"not_senders": ["@spam:example.com"]
|
||||
}
|
||||
},
|
||||
"presence": {
|
||||
"types": ["m.presence"],
|
||||
"not_senders": ["@alice:example.com"]
|
||||
},
|
||||
"event_format": "client",
|
||||
"event_fields": ["type", "content", "sender"]
|
||||
}
|
||||
]
|
||||
for filter in valid_filters:
|
||||
try:
|
||||
self.filtering.check_valid_filter(filter)
|
||||
except jsonschema.ValidationError as e:
|
||||
self.fail(e)
|
||||
|
||||
def test_limits_are_applied(self):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def test_definition_types_works_with_literals(self):
|
||||
definition = {
|
||||
"types": ["m.room.message", "org.matrix.foo.bar"]
|
||||
|
|
|
@ -93,6 +93,7 @@ class DirectoryTestCase(unittest.TestCase):
|
|||
"room_alias": "#another:remote",
|
||||
},
|
||||
retry_on_dns_fail=False,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -324,7 +324,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||
state = UserPresenceState.default(user_id)
|
||||
state = state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=now,
|
||||
last_active_ts=0,
|
||||
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
|
||||
)
|
||||
|
||||
|
|
|
@ -119,7 +119,8 @@ class ProfileTestCase(unittest.TestCase):
|
|||
self.mock_federation.make_query.assert_called_with(
|
||||
destination="remote",
|
||||
query_type="profile",
|
||||
args={"user_id": "@alice:remote", "field": "displayname"}
|
||||
args={"user_id": "@alice:remote", "field": "displayname"},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
),
|
||||
json_data_callback=ANY,
|
||||
long_retries=True,
|
||||
backoff_on_404=True,
|
||||
),
|
||||
defer.succeed((200, "OK"))
|
||||
)
|
||||
|
@ -263,6 +264,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
),
|
||||
json_data_callback=ANY,
|
||||
long_retries=True,
|
||||
backoff_on_404=True,
|
||||
),
|
||||
defer.succeed((200, "OK"))
|
||||
)
|
||||
|
|
|
@ -68,7 +68,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
self.assertEquals(body["events"]["field_names"], [
|
||||
"position", "internal", "json", "state_group"
|
||||
"position", "event_id", "room_id", "type", "state_key",
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -33,8 +33,8 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
|
|||
class FilterTestCase(unittest.TestCase):
|
||||
|
||||
USER_ID = "@apple:test"
|
||||
EXAMPLE_FILTER = {"type": ["m.*"]}
|
||||
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
|
||||
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
|
||||
EXAMPLE_FILTER_JSON = '{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||
TO_REGISTER = [filter]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -89,7 +89,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_select_one_1col(self):
|
||||
self.mock_txn.rowcount = 1
|
||||
self.mock_txn.fetchall.return_value = [("Value",)]
|
||||
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
|
||||
|
||||
value = yield self.datastore._simple_select_one_onecol(
|
||||
table="tablename",
|
||||
|
@ -136,7 +136,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_select_list(self):
|
||||
self.mock_txn.rowcount = 3
|
||||
self.mock_txn.fetchall.return_value = ((1,), (2,), (3,))
|
||||
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
|
||||
self.mock_txn.description = (
|
||||
("colA", None, None, None, None, None, None),
|
||||
)
|
||||
|
|
53
tests/storage/test_keys.py
Normal file
53
tests/storage/test_keys.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import signedjson.key
|
||||
from twisted.internet import defer
|
||||
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
||||
|
||||
class KeyStoreTestCase(tests.unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(KeyStoreTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.keys.KeyStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield tests.utils.setup_test_homeserver()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_server_verify_keys(self):
|
||||
key1 = signedjson.key.decode_verify_key_base64(
|
||||
"ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
|
||||
)
|
||||
key2 = signedjson.key.decode_verify_key_base64(
|
||||
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
|
||||
)
|
||||
yield self.store.store_server_verify_key(
|
||||
"server1", "from_server", 0, key1
|
||||
)
|
||||
yield self.store.store_server_verify_key(
|
||||
"server1", "from_server", 0, key2
|
||||
)
|
||||
|
||||
res = yield self.store.get_server_verify_keys(
|
||||
"server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"])
|
||||
|
||||
self.assertEqual(len(res.keys()), 2)
|
||||
self.assertEqual(res["ed25519:key1"].version, "key1")
|
||||
self.assertEqual(res["ed25519:key2"].version, "key2")
|
14
tests/util/caches/__init__.py
Normal file
14
tests/util/caches/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
177
tests/util/caches/test_descriptors.py
Normal file
177
tests/util/caches/test_descriptors.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import mock
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.util import async
|
||||
from synapse.util import logcontext
|
||||
from twisted.internet import defer
|
||||
from synapse.util.caches import descriptors
|
||||
from tests import unittest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DescriptorTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_cache(self):
|
||||
class Cls(object):
|
||||
def __init__(self):
|
||||
self.mock = mock.Mock()
|
||||
|
||||
@descriptors.cached()
|
||||
def fn(self, arg1, arg2):
|
||||
return self.mock(arg1, arg2)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
obj.mock.return_value = 'fish'
|
||||
r = yield obj.fn(1, 2)
|
||||
self.assertEqual(r, 'fish')
|
||||
obj.mock.assert_called_once_with(1, 2)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# a call with different params should call the mock again
|
||||
obj.mock.return_value = 'chips'
|
||||
r = yield obj.fn(1, 3)
|
||||
self.assertEqual(r, 'chips')
|
||||
obj.mock.assert_called_once_with(1, 3)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# the two values should now be cached
|
||||
r = yield obj.fn(1, 2)
|
||||
self.assertEqual(r, 'fish')
|
||||
r = yield obj.fn(1, 3)
|
||||
self.assertEqual(r, 'chips')
|
||||
obj.mock.assert_not_called()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cache_num_args(self):
|
||||
"""Only the first num_args arguments should matter to the cache"""
|
||||
|
||||
class Cls(object):
|
||||
def __init__(self):
|
||||
self.mock = mock.Mock()
|
||||
|
||||
@descriptors.cached(num_args=1)
|
||||
def fn(self, arg1, arg2):
|
||||
return self.mock(arg1, arg2)
|
||||
|
||||
obj = Cls()
|
||||
obj.mock.return_value = 'fish'
|
||||
r = yield obj.fn(1, 2)
|
||||
self.assertEqual(r, 'fish')
|
||||
obj.mock.assert_called_once_with(1, 2)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# a call with different params should call the mock again
|
||||
obj.mock.return_value = 'chips'
|
||||
r = yield obj.fn(2, 3)
|
||||
self.assertEqual(r, 'chips')
|
||||
obj.mock.assert_called_once_with(2, 3)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# the two values should now be cached; we should be able to vary
|
||||
# the second argument and still get the cached result.
|
||||
r = yield obj.fn(1, 4)
|
||||
self.assertEqual(r, 'fish')
|
||||
r = yield obj.fn(2, 5)
|
||||
self.assertEqual(r, 'chips')
|
||||
obj.mock.assert_not_called()
|
||||
|
||||
def test_cache_logcontexts(self):
|
||||
"""Check that logcontexts are set and restored correctly when
|
||||
using the cache."""
|
||||
|
||||
complete_lookup = defer.Deferred()
|
||||
|
||||
class Cls(object):
|
||||
@descriptors.cached()
|
||||
def fn(self, arg1):
|
||||
@defer.inlineCallbacks
|
||||
def inner_fn():
|
||||
with logcontext.PreserveLoggingContext():
|
||||
yield complete_lookup
|
||||
defer.returnValue(1)
|
||||
|
||||
return inner_fn()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_lookup():
|
||||
with logcontext.LoggingContext() as c1:
|
||||
c1.name = "c1"
|
||||
r = yield obj.fn(1)
|
||||
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||
c1)
|
||||
defer.returnValue(r)
|
||||
|
||||
def check_result(r):
|
||||
self.assertEqual(r, 1)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
# set off a deferred which will do a cache lookup
|
||||
d1 = do_lookup()
|
||||
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||
logcontext.LoggingContext.sentinel)
|
||||
d1.addCallback(check_result)
|
||||
|
||||
# and another
|
||||
d2 = do_lookup()
|
||||
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||
logcontext.LoggingContext.sentinel)
|
||||
d2.addCallback(check_result)
|
||||
|
||||
# let the lookup complete
|
||||
complete_lookup.callback(None)
|
||||
|
||||
return defer.gatherResults([d1, d2])
|
||||
|
||||
def test_cache_logcontexts_with_exception(self):
|
||||
"""Check that the cache sets and restores logcontexts correctly when
|
||||
the lookup function throws an exception"""
|
||||
|
||||
class Cls(object):
|
||||
@descriptors.cached()
|
||||
def fn(self, arg1):
|
||||
@defer.inlineCallbacks
|
||||
def inner_fn():
|
||||
yield async.run_on_reactor()
|
||||
raise SynapseError(400, "blah")
|
||||
|
||||
return inner_fn()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_lookup():
|
||||
with logcontext.LoggingContext() as c1:
|
||||
c1.name = "c1"
|
||||
try:
|
||||
yield obj.fn(1)
|
||||
self.fail("No exception thrown")
|
||||
except SynapseError:
|
||||
pass
|
||||
|
||||
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||
c1)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
# set off a deferred which will do a cache lookup
|
||||
d1 = do_lookup()
|
||||
self.assertEqual(logcontext.LoggingContext.current_context(),
|
||||
logcontext.LoggingContext.sentinel)
|
||||
|
||||
return d1
|
33
tests/util/test_clock.py
Normal file
33
tests/util/test_clock.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse import util
|
||||
from twisted.internet import defer
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class ClockTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_time_bound_deferred(self):
|
||||
# just a deferred which never resolves
|
||||
slow_deferred = defer.Deferred()
|
||||
|
||||
clock = util.Clock()
|
||||
time_bound = clock.time_bound_deferred(slow_deferred, 0.001)
|
||||
|
||||
try:
|
||||
yield time_bound
|
||||
self.fail("Expected timedout error, but got nothing")
|
||||
except util.DeferredTimedOutError:
|
||||
pass
|
|
@ -1,8 +1,10 @@
|
|||
import twisted.python.failure
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import reactor
|
||||
from .. import unittest
|
||||
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
|
@ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase):
|
|||
context_one.test_key = "one"
|
||||
yield sleep(0)
|
||||
self._check_test_key("one")
|
||||
|
||||
def _test_preserve_fn(self, function):
|
||||
sentinel_context = LoggingContext.current_context()
|
||||
|
||||
callback_completed = [False]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def cb():
|
||||
context_one.test_key = "one"
|
||||
yield function()
|
||||
self._check_test_key("one")
|
||||
|
||||
callback_completed[0] = True
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.test_key = "one"
|
||||
|
||||
# fire off function, but don't wait on it.
|
||||
logcontext.preserve_fn(cb)()
|
||||
|
||||
self._check_test_key("one")
|
||||
|
||||
# now wait for the function under test to have run, and check that
|
||||
# the logcontext is left in a sane state.
|
||||
d2 = defer.Deferred()
|
||||
|
||||
def check_logcontext():
|
||||
if not callback_completed[0]:
|
||||
reactor.callLater(0.01, check_logcontext)
|
||||
return
|
||||
|
||||
# make sure that the context was reset before it got thrown back
|
||||
# into the reactor
|
||||
try:
|
||||
self.assertIs(LoggingContext.current_context(),
|
||||
sentinel_context)
|
||||
d2.callback(None)
|
||||
except BaseException:
|
||||
d2.errback(twisted.python.failure.Failure())
|
||||
|
||||
reactor.callLater(0.01, check_logcontext)
|
||||
|
||||
# test is done once d2 finishes
|
||||
return d2
|
||||
|
||||
def test_preserve_fn_with_blocking_fn(self):
|
||||
@defer.inlineCallbacks
|
||||
def blocking_function():
|
||||
yield sleep(0)
|
||||
|
||||
return self._test_preserve_fn(blocking_function)
|
||||
|
||||
def test_preserve_fn_with_non_blocking_fn(self):
|
||||
@defer.inlineCallbacks
|
||||
def nonblocking_function():
|
||||
with logcontext.PreserveLoggingContext():
|
||||
yield defer.succeed(None)
|
||||
|
||||
return self._test_preserve_fn(nonblocking_function)
|
||||
|
|
Loading…
Reference in a new issue