mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-16 23:11:34 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/fed_reader
This commit is contained in:
commit
5aa024e501
54 changed files with 1256 additions and 351 deletions
64
CHANGES.rst
64
CHANGES.rst
|
@ -1,3 +1,67 @@
|
|||
Changes in synapse v0.17.0-rc1 (2016-07-28)
|
||||
===========================================
|
||||
|
||||
This release changes the LDAP configuration format in a backwards incompatible
|
||||
way, see PR #843 for details.
|
||||
|
||||
This release contains significant security bug fixes regarding authenticating
|
||||
events received over federation. Please upgrade.
|
||||
|
||||
|
||||
Features:
|
||||
|
||||
* Add purge_media_cache admin API (PR #902)
|
||||
* Add deactivate account admin API (PR #903)
|
||||
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
|
||||
* Add an admin option to shared secret registration (breaks backwards compat)
|
||||
(PR #909)
|
||||
* Add purge local room history API (PR #911, #923, #924)
|
||||
* Add requestToken endpoints (PR #915)
|
||||
* Add an /account/deactivate endpoint (PR #921)
|
||||
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
|
||||
* Add device_id support to /login (PR #929)
|
||||
* Add device_id support to /v2/register flow. (PR #937, #942)
|
||||
* Add GET /devices endpoint (PR #939, #944)
|
||||
* Add GET /device/{deviceId} (PR #943)
|
||||
* Add update and delete APIs for devices (PR #949)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
|
||||
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
|
||||
* Remove the legacy v0 content upload API. (PR #888)
|
||||
* Use similar naming we use in email notifs for push (PR #894)
|
||||
* Optionally include password hash in createUser endpoint (PR #905 by
|
||||
KentShikama)
|
||||
* Use a query that postgresql optimises better for get_events_around (PR #906)
|
||||
* Fall back to 'username' if 'user' is not given for appservice registration.
|
||||
(PR #927 by Half-Shot)
|
||||
* Add metrics for psutil derived memory usage (PR #936)
|
||||
* Record device_id in client_ips (PR #938)
|
||||
* Send the correct host header when fetching keys (PR #941)
|
||||
* Log the hostname the reCAPTCHA was completed on (PR #946)
|
||||
* Make the device id on e2e key upload optional (PR #956)
|
||||
* Add r0.2.0 to the "supported versions" list (PR #960)
|
||||
* Don't include name of room for invites in push (PR #961)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix substitution failure in mail template (PR #887)
|
||||
* Put most recent 20 messages in email notif (PR #892)
|
||||
* Ensure that the guest user is in the database when upgrading accounts
|
||||
(PR #914)
|
||||
* Fix various edge cases in auth handling (PR #919)
|
||||
* Fix 500 ISE when sending alias event without a state_key (PR #925)
|
||||
* Fix bug where we stored rejections in the state_group, persist all
|
||||
rejections (PR #948)
|
||||
* Fix lack of check of if the user is banned when handling 3pid invites
|
||||
(PR #952)
|
||||
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
|
||||
|
||||
|
||||
|
||||
Changes in synapse v0.16.1-r1 (2016-07-08)
|
||||
==========================================
|
||||
|
||||
|
|
|
@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
|
|||
IDs:
|
||||
|
||||
1) Use the machine's own hostname as available on public DNS in the form of
|
||||
its A or AAAA records. This is easier to set up initially, perhaps for
|
||||
its A records. This is easier to set up initially, perhaps for
|
||||
testing, but lacks the flexibility of SRV.
|
||||
|
||||
2) Set up a SRV record for your domain name. This requires you create a SRV
|
||||
|
|
12
docs/admin_api/README.rst
Normal file
12
docs/admin_api/README.rst
Normal file
|
@ -0,0 +1,12 @@
|
|||
Admin APIs
|
||||
==========
|
||||
|
||||
This directory includes documentation for the various synapse specific admin
|
||||
APIs available.
|
||||
|
||||
Only users that are server admins can use these APIs. A user can be marked as a
|
||||
server admin by updating the database directly, e.g.:
|
||||
|
||||
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
|
||||
|
||||
Restarting may be required for the changes to register.
|
15
docs/admin_api/purge_history_api.rst
Normal file
15
docs/admin_api/purge_history_api.rst
Normal file
|
@ -0,0 +1,15 @@
|
|||
Purge History API
|
||||
=================
|
||||
|
||||
The purge history API allows server admins to purge historic events from their
|
||||
database, reclaiming disk space.
|
||||
|
||||
Depending on the amount of history being purged a call to the API may take
|
||||
several minutes or longer. During this period users will not be able to
|
||||
paginate further back in the room from the point being purged from.
|
||||
|
||||
The API is simply:
|
||||
|
||||
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
|
||||
|
||||
including an ``access_token`` of a server admin.
|
19
docs/admin_api/purge_remote_media.rst
Normal file
19
docs/admin_api/purge_remote_media.rst
Normal file
|
@ -0,0 +1,19 @@
|
|||
Purge Remote Media API
|
||||
======================
|
||||
|
||||
The purge remote media API allows server admins to purge old cached remote
|
||||
media.
|
||||
|
||||
The API is::
|
||||
|
||||
POST /_matrix/client/r0/admin/purge_media_cache
|
||||
|
||||
{
|
||||
"before_ts": <unix_timestamp_in_ms>
|
||||
}
|
||||
|
||||
Which will remove all cached media that was last accessed before
|
||||
``<unix_timestamp_in_ms>``.
|
||||
|
||||
If the user re-requests purged remote media, synapse will re-request the media
|
||||
from the originating server.
|
|
@ -16,7 +16,5 @@ ignore =
|
|||
|
||||
[flake8]
|
||||
max-line-length = 90
|
||||
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||
|
||||
[pep8]
|
||||
max-line-length = 90
|
||||
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||
ignore = W503
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.16.1-r1"
|
||||
__version__ = "0.17.0-rc1"
|
||||
|
|
|
@ -13,22 +13,22 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import pymacaroons
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import Requester, UserID, get_domain_from_id
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.metrics import Measure
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
import logging
|
||||
import pymacaroons
|
||||
import synapse.types
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -376,6 +376,10 @@ class Auth(object):
|
|||
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
||||
if not self._verify_third_party_invite(event, auth_events):
|
||||
raise AuthError(403, "You are not invited to this room.")
|
||||
if target_banned:
|
||||
raise AuthError(
|
||||
403, "%s is banned from the room" % (target_user_id,)
|
||||
)
|
||||
return True
|
||||
|
||||
if Membership.JOIN != membership:
|
||||
|
@ -566,8 +570,7 @@ class Auth(object):
|
|||
Args:
|
||||
request - An HTTP request with an access_token query parameter.
|
||||
Returns:
|
||||
defer.Deferred: resolves to a namedtuple including "user" (UserID)
|
||||
"access_token_id" (int), "is_guest" (bool)
|
||||
defer.Deferred: resolves to a ``synapse.types.Requester`` object
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
|
@ -576,9 +579,7 @@ class Auth(object):
|
|||
user_id = yield self._get_appservice_user_id(request.args)
|
||||
if user_id:
|
||||
request.authenticated_entity = user_id
|
||||
defer.returnValue(
|
||||
Requester(UserID.from_string(user_id), "", False)
|
||||
)
|
||||
defer.returnValue(synapse.types.create_requester(user_id))
|
||||
|
||||
access_token = request.args["access_token"][0]
|
||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||
|
@ -612,7 +613,8 @@ class Auth(object):
|
|||
|
||||
request.authenticated_entity = user.to_string()
|
||||
|
||||
defer.returnValue(Requester(user, token_id, is_guest))
|
||||
defer.returnValue(synapse.types.create_requester(
|
||||
user, token_id, is_guest, device_id))
|
||||
except KeyError:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||
|
|
|
@ -16,13 +16,11 @@
|
|||
import sys
|
||||
sys.dont_write_bytecode = True
|
||||
|
||||
from synapse.python_dependencies import (
|
||||
check_requirements, MissingRequirementError
|
||||
) # NOQA
|
||||
from synapse import python_dependencies # noqa: E402
|
||||
|
||||
try:
|
||||
check_requirements()
|
||||
except MissingRequirementError as e:
|
||||
python_dependencies.check_requirements()
|
||||
except python_dependencies.MissingRequirementError as e:
|
||||
message = "\n".join([
|
||||
"Missing Requirement: %s" % (e.message,),
|
||||
"To install run:",
|
||||
|
|
|
@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
|
|||
def __init__(self):
|
||||
self.remote_key = defer.Deferred()
|
||||
self.host = None
|
||||
self._peer = None
|
||||
|
||||
def connectionMade(self):
|
||||
self.host = self.transport.getHost()
|
||||
logger.debug("Connected to %s", self.host)
|
||||
self._peer = self.transport.getPeer()
|
||||
logger.debug("Connected to %s", self._peer)
|
||||
|
||||
self.sendCommand(b"GET", self.path)
|
||||
if self.host:
|
||||
self.sendHeader(b"Host", self.host)
|
||||
|
@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
|
|||
self.timer.cancel()
|
||||
|
||||
def on_timeout(self):
|
||||
logger.debug("Timeout waiting for response from %s", self.host)
|
||||
logger.debug(
|
||||
"Timeout waiting for response from %s: %s",
|
||||
self.host, self._peer,
|
||||
)
|
||||
self.errback(IOError("Timeout waiting for response"))
|
||||
self.transport.abortConnection()
|
||||
|
||||
|
@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
|
|||
def protocol(self):
|
||||
protocol = SynapseKeyClientProtocol()
|
||||
protocol.path = self.path
|
||||
protocol.host = self.host
|
||||
return protocol
|
||||
|
|
|
@ -44,7 +44,21 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
||||
VerifyKeyRequest = namedtuple("VerifyRequest", (
|
||||
"server_name", "key_ids", "json_object", "deferred"
|
||||
))
|
||||
"""
|
||||
A request for a verify key to verify a JSON object.
|
||||
|
||||
Attributes:
|
||||
server_name(str): The name of the server to verify against.
|
||||
key_ids(set(str)): The set of key_ids to that could be used to verify the
|
||||
JSON object
|
||||
json_object(dict): The JSON object to verify.
|
||||
deferred(twisted.internet.defer.Deferred):
|
||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||
a verify key has been fetched
|
||||
"""
|
||||
|
||||
|
||||
class Keyring(object):
|
||||
|
@ -74,39 +88,32 @@ class Keyring(object):
|
|||
list of deferreds indicating success or failure to verify each
|
||||
json object's signature for the given server_name.
|
||||
"""
|
||||
group_id_to_json = {}
|
||||
group_id_to_group = {}
|
||||
group_ids = []
|
||||
|
||||
next_group_id = 0
|
||||
deferreds = {}
|
||||
verify_requests = []
|
||||
|
||||
for server_name, json_object in server_and_json:
|
||||
logger.debug("Verifying for %s", server_name)
|
||||
group_id = next_group_id
|
||||
next_group_id += 1
|
||||
group_ids.append(group_id)
|
||||
|
||||
key_ids = signature_ids(json_object, server_name)
|
||||
if not key_ids:
|
||||
deferreds[group_id] = defer.fail(SynapseError(
|
||||
deferred = defer.fail(SynapseError(
|
||||
400,
|
||||
"Not signed with a supported algorithm",
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
else:
|
||||
deferreds[group_id] = defer.Deferred()
|
||||
deferred = defer.Deferred()
|
||||
|
||||
group = KeyGroup(server_name, group_id, key_ids)
|
||||
verify_request = VerifyKeyRequest(
|
||||
server_name, key_ids, json_object, deferred
|
||||
)
|
||||
|
||||
group_id_to_group[group_id] = group
|
||||
group_id_to_json[group_id] = json_object
|
||||
verify_requests.append(verify_request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_key_deferred(group, deferred):
|
||||
server_name = group.server_name
|
||||
def handle_key_deferred(verify_request):
|
||||
server_name = verify_request.server_name
|
||||
try:
|
||||
_, _, key_id, verify_key = yield deferred
|
||||
_, key_id, verify_key = yield verify_request.deferred
|
||||
except IOError as e:
|
||||
logger.warn(
|
||||
"Got IOError when downloading keys for %s: %s %s",
|
||||
|
@ -128,7 +135,7 @@ class Keyring(object):
|
|||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
json_object = group_id_to_json[group.group_id]
|
||||
json_object = verify_request.json_object
|
||||
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
|
@ -157,36 +164,34 @@ class Keyring(object):
|
|||
|
||||
# Actually start fetching keys.
|
||||
wait_on_deferred.addBoth(
|
||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||
lambda _: self.get_server_verify_keys(verify_requests)
|
||||
)
|
||||
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||
# any lookups waiting will proceed.
|
||||
server_to_gids = {}
|
||||
server_to_request_ids = {}
|
||||
|
||||
def remove_deferreds(res, server_name, group_id):
|
||||
server_to_gids[server_name].discard(group_id)
|
||||
if not server_to_gids[server_name]:
|
||||
def remove_deferreds(res, server_name, verify_request):
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids[server_name].discard(request_id)
|
||||
if not server_to_request_ids[server_name]:
|
||||
d = server_to_deferred.pop(server_name, None)
|
||||
if d:
|
||||
d.callback(None)
|
||||
return res
|
||||
|
||||
for g_id, deferred in deferreds.items():
|
||||
server_name = group_id_to_group[g_id].server_name
|
||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, verify_request)
|
||||
|
||||
# Pass those keys to handle_key_deferred so that the json object
|
||||
# signatures can be verified
|
||||
return [
|
||||
preserve_context_over_fn(
|
||||
handle_key_deferred,
|
||||
group_id_to_group[g_id],
|
||||
deferreds[g_id],
|
||||
)
|
||||
for g_id in group_ids
|
||||
preserve_context_over_fn(handle_key_deferred, verify_request)
|
||||
for verify_request in verify_requests
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -220,7 +225,7 @@ class Keyring(object):
|
|||
|
||||
d.addBoth(rm, server_name)
|
||||
|
||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||
def get_server_verify_keys(self, verify_requests):
|
||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||
each group.
|
||||
"""
|
||||
|
@ -237,62 +242,64 @@ class Keyring(object):
|
|||
merged_results = {}
|
||||
|
||||
missing_keys = {}
|
||||
for group in group_id_to_group.values():
|
||||
missing_keys.setdefault(group.server_name, set()).update(
|
||||
group.key_ids
|
||||
for verify_request in verify_requests:
|
||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||
verify_request.key_ids
|
||||
)
|
||||
|
||||
for fn in key_fetch_fns:
|
||||
results = yield fn(missing_keys.items())
|
||||
merged_results.update(results)
|
||||
|
||||
# We now need to figure out which groups we have keys for
|
||||
# and which we don't
|
||||
missing_groups = {}
|
||||
for group in group_id_to_group.values():
|
||||
for key_id in group.key_ids:
|
||||
if key_id in merged_results[group.server_name]:
|
||||
# We now need to figure out which verify requests we have keys
|
||||
# for and which we don't
|
||||
missing_keys = {}
|
||||
requests_missing_keys = []
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
result_keys = merged_results[server_name]
|
||||
|
||||
if verify_request.deferred.called:
|
||||
# We've already called this deferred, which probably
|
||||
# means that we've already found a key for it.
|
||||
continue
|
||||
|
||||
for key_id in verify_request.key_ids:
|
||||
if key_id in result_keys:
|
||||
with PreserveLoggingContext():
|
||||
group_id_to_deferred[group.group_id].callback((
|
||||
group.group_id,
|
||||
group.server_name,
|
||||
verify_request.deferred.callback((
|
||||
server_name,
|
||||
key_id,
|
||||
merged_results[group.server_name][key_id],
|
||||
result_keys[key_id],
|
||||
))
|
||||
break
|
||||
else:
|
||||
missing_groups.setdefault(
|
||||
group.server_name, []
|
||||
).append(group)
|
||||
# The else block is only reached if the loop above
|
||||
# doesn't break.
|
||||
missing_keys.setdefault(server_name, set()).update(
|
||||
verify_request.key_ids
|
||||
)
|
||||
requests_missing_keys.append(verify_request)
|
||||
|
||||
if not missing_groups:
|
||||
if not missing_keys:
|
||||
break
|
||||
|
||||
missing_keys = {
|
||||
server_name: set(
|
||||
key_id for group in groups for key_id in group.key_ids
|
||||
)
|
||||
for server_name, groups in missing_groups.items()
|
||||
}
|
||||
|
||||
for group in missing_groups.values():
|
||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||
for verify_request in requests_missing_keys.values():
|
||||
verify_request.deferred.errback(SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (
|
||||
group.server_name, group.key_ids,
|
||||
verify_request.server_name, verify_request.key_ids,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
|
||||
def on_err(err):
|
||||
for deferred in group_id_to_deferred.values():
|
||||
if not deferred.called:
|
||||
deferred.errback(err)
|
||||
for verify_request in verify_requests:
|
||||
if not verify_request.deferred.called:
|
||||
verify_request.deferred.errback(err)
|
||||
|
||||
do_iterations().addErrback(on_err)
|
||||
|
||||
return group_id_to_deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
res = yield defer.gatherResults(
|
||||
|
@ -447,7 +454,7 @@ class Keyring(object):
|
|||
)
|
||||
|
||||
processed_response = yield self.process_v2_response(
|
||||
perspective_name, response
|
||||
perspective_name, response, only_from_server=False
|
||||
)
|
||||
|
||||
for server_name, response_keys in processed_response.items():
|
||||
|
@ -527,7 +534,7 @@ class Keyring(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def process_v2_response(self, from_server, response_json,
|
||||
requested_ids=[]):
|
||||
requested_ids=[], only_from_server=True):
|
||||
time_now_ms = self.clock.time_msec()
|
||||
response_keys = {}
|
||||
verify_keys = {}
|
||||
|
@ -551,6 +558,13 @@ class Keyring(object):
|
|||
|
||||
results = {}
|
||||
server_name = response_json["server_name"]
|
||||
if only_from_server:
|
||||
if server_name != from_server:
|
||||
raise ValueError(
|
||||
"Expected a response for server %r not %r" % (
|
||||
from_server, server_name
|
||||
)
|
||||
)
|
||||
for key_id in response_json["signatures"].get(server_name, {}):
|
||||
if key_id not in response_json["verify_keys"]:
|
||||
raise ValueError(
|
||||
|
|
|
@ -13,14 +13,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
import synapse.types
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.types import UserID, Requester
|
||||
|
||||
|
||||
import logging
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -124,7 +124,8 @@ class BaseHandler(object):
|
|||
# and having homeservers have their own users leave keeps more
|
||||
# of that decision-making and control local to the guest-having
|
||||
# homeserver.
|
||||
requester = Requester(target_user, "", True)
|
||||
requester = synapse.types.create_requester(
|
||||
target_user, is_guest=True)
|
||||
handler = self.hs.get_handlers().room_member_handler
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
|
|
|
@ -77,6 +77,7 @@ class AuthHandler(BaseHandler):
|
|||
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
|
@ -279,7 +280,16 @@ class AuthHandler(BaseHandler):
|
|||
data = pde.response
|
||||
resp_body = simplejson.loads(data)
|
||||
|
||||
if 'success' in resp_body and resp_body['success']:
|
||||
if 'success' in resp_body:
|
||||
# Note that we do NOT check the hostname here: we explicitly
|
||||
# intend the CAPTCHA to be presented by whatever client the
|
||||
# user is using, we just care that they have completed a CAPTCHA.
|
||||
logger.info(
|
||||
"%s reCAPTCHA from hostname %s",
|
||||
"Successful" if resp_body['success'] else "Failed",
|
||||
resp_body.get('hostname')
|
||||
)
|
||||
if resp_body['success']:
|
||||
defer.returnValue(True)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
|
@ -365,7 +375,8 @@ class AuthHandler(BaseHandler):
|
|||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_login_tuple_for_user_id(self, user_id, device_id=None):
|
||||
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
"""
|
||||
Gets login tuple for the user with the given user ID.
|
||||
|
||||
|
@ -374,9 +385,15 @@ class AuthHandler(BaseHandler):
|
|||
The user is assumed to have been authenticated by some other
|
||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||
|
||||
The device will be recorded in the table if it is not there already.
|
||||
|
||||
Args:
|
||||
user_id (str): canonical User ID
|
||||
device_id (str): the device ID to associate with the access token
|
||||
device_id (str|None): the device ID to associate with the tokens.
|
||||
None to leave the tokens unassociated with a device (deprecated:
|
||||
we should always have a device ID)
|
||||
initial_display_name (str): display name to associate with the
|
||||
device if it needs re-registering
|
||||
Returns:
|
||||
A tuple of:
|
||||
The access token for the user's session.
|
||||
|
@ -388,6 +405,16 @@ class AuthHandler(BaseHandler):
|
|||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||
access_token = yield self.issue_access_token(user_id, device_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||
|
||||
# the device *should* have been registered before we got here; however,
|
||||
# it's possible we raced against a DELETE operation. The thing we
|
||||
# really don't want is active access_tokens without a record of the
|
||||
# device, so we double-check it here.
|
||||
if device_id is not None:
|
||||
yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
defer.returnValue((access_token, refresh_token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -79,17 +79,17 @@ class DeviceHandler(BaseHandler):
|
|||
Args:
|
||||
user_id (str):
|
||||
Returns:
|
||||
defer.Deferred: dict[str, dict[str, X]]: map from device_id to
|
||||
info on the device
|
||||
defer.Deferred: list[dict[str, X]]: info on each device
|
||||
"""
|
||||
|
||||
devices = yield self.store.get_devices_by_user(user_id)
|
||||
device_map = yield self.store.get_devices_by_user(user_id)
|
||||
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
devices=((user_id, device_id) for device_id in devices.keys())
|
||||
devices=((user_id, device_id) for device_id in device_map.keys())
|
||||
)
|
||||
|
||||
for device in devices.values():
|
||||
devices = device_map.values()
|
||||
for device in devices:
|
||||
_update_device_from_client_ips(device, ips)
|
||||
|
||||
defer.returnValue(devices)
|
||||
|
@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler):
|
|||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str)
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: dict[str, X]: info on the device
|
||||
|
@ -117,6 +117,61 @@ class DeviceHandler(BaseHandler):
|
|||
_update_device_from_client_ips(device, ips)
|
||||
defer.returnValue(device)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_device(self, user_id, device_id):
|
||||
""" Delete the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.delete_device(user_id, device_id)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, device_id=device_id,
|
||||
delete_refresh_tokens=True,
|
||||
)
|
||||
|
||||
yield self.store.delete_e2e_keys_by_device(
|
||||
user_id=user_id, device_id=device_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_device(self, user_id, device_id, content):
|
||||
""" Update the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
content (dict): body of update request
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.update_device(
|
||||
user_id,
|
||||
device_id,
|
||||
new_display_name=content.get("display_name")
|
||||
)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
raise errors.NotFoundError()
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def _update_device_from_client_ips(device, client_ips):
|
||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||
|
|
|
@ -13,15 +13,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.types import UserID, Requester
|
||||
|
||||
from synapse.types import UserID
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
|
|||
try:
|
||||
# Assume the user isn't a guest because we don't let guests set
|
||||
# profile or avatar data.
|
||||
requester = Requester(user, "", False)
|
||||
# XXX why are we recreating `requester` here for each room?
|
||||
# what was wrong with the `requester` we were passed?
|
||||
requester = synapse.types.create_requester(user)
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
user,
|
||||
|
|
|
@ -14,18 +14,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Contains functions for registering clients."""
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID, Requester
|
||||
import synapse.types
|
||||
from synapse.api.errors import (
|
||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler):
|
|||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
requester = synapse.types.create_requester(user)
|
||||
yield profile_handler.set_displayname(
|
||||
user, Requester(user, token, False), displayname
|
||||
user, requester, displayname
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
|
|
@ -14,24 +14,22 @@
|
|||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
from twisted.internet import defer
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.types import UserID, RoomID, Requester
|
||||
import synapse.types
|
||||
from synapse.api.constants import (
|
||||
EventTypes, Membership,
|
||||
)
|
||||
from synapse.api.errors import AuthError, SynapseError, Codes
|
||||
from synapse.types import UserID, RoomID
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.distributor import user_left_room, user_joined_room
|
||||
|
||||
from signedjson.sign import verify_signed_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
import logging
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
)
|
||||
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
||||
else:
|
||||
requester = Requester(target_user, None, False)
|
||||
requester = synapse.types.create_requester(target_user)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||
|
|
|
@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
|
||||
def register_paths(self, method, path_patterns, callback):
|
||||
for path_pattern in path_patterns:
|
||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
self._PathEntry(path_pattern, callback)
|
||||
)
|
||||
|
|
|
@ -140,9 +140,8 @@ class EmailPusher(object):
|
|||
being run.
|
||||
"""
|
||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
||||
self.user_id, start, self.max_stream_ordering
|
||||
)
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
|
||||
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
|
||||
|
||||
soonest_due_at = None
|
||||
|
||||
|
|
|
@ -141,7 +141,8 @@ class HttpPusher(object):
|
|||
run once per pusher.
|
||||
"""
|
||||
|
||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||
unprocessed = yield fn(
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ def get_context_for_event(state_handler, ev, user_id):
|
|||
room_state = yield state_handler.get_current_state(ev.room_id)
|
||||
|
||||
# we no longer bother setting room_alias, and make room_name the
|
||||
# human-readable name instead, be that m.room.namer, an alias or
|
||||
# human-readable name instead, be that m.room.name, an alias or
|
||||
# a list of people in the room
|
||||
name = calculate_room_name(
|
||||
room_state, user_id, fallback_to_single_member=False
|
||||
|
|
|
@ -93,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||
)
|
||||
|
||||
get_unread_push_actions_for_user_in_range = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range.__func__
|
||||
get_unread_push_actions_for_user_in_range_for_http = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
||||
)
|
||||
get_unread_push_actions_for_user_in_range_for_email = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
|
||||
)
|
||||
get_push_action_users_in_range = (
|
||||
DataStore.get_push_action_users_in_range.__func__
|
||||
|
|
|
@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
|
@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id)
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
|
@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id, device_id
|
||||
registered_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
|
|
|
@ -13,19 +13,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http import servlet
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DevicesRestServlet(RestServlet):
|
||||
class DevicesRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
|
|||
defer.returnValue((200, {"devices": devices}))
|
||||
|
||||
|
||||
class DeviceRestServlet(RestServlet):
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False)
|
||||
|
||||
|
@ -70,6 +68,32 @@ class DeviceRestServlet(RestServlet):
|
|||
)
|
||||
defer.returnValue((200, device))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, device_id):
|
||||
# XXX: it's not completely obvious we want to expose this endpoint.
|
||||
# It allows the client to delete access tokens, which feels like a
|
||||
# thing which merits extra auth. But if we want to do the interactive-
|
||||
# auth dance, we should really make it possible to delete more than one
|
||||
# device at a time.
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
yield self.device_handler.delete_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
body
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -13,24 +13,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import simplejson as json
|
||||
from canonicaljson import encode_canonical_json
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.server
|
||||
import synapse.types
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import UserID
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
import simplejson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyUploadServlet(RestServlet):
|
||||
"""
|
||||
POST /keys/upload/<device_id> HTTP/1.1
|
||||
POST /keys/upload HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
|
@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
|
|||
},
|
||||
}
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
||||
releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(KeyUploadServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
# TODO: Check that the device_id matches that in the authentication
|
||||
# or derive the device_id from the authentication instead.
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if device_id is not None:
|
||||
# passing the device_id here is deprecated; however, we allow it
|
||||
# for now for compatibility with older clients.
|
||||
if (requester.device_id is not None and
|
||||
device_id != requester.device_id):
|
||||
logger.warning("Client uploading keys for a different device "
|
||||
"(logged in as %s, uploading for %s)",
|
||||
requester.device_id, device_id)
|
||||
else:
|
||||
device_id = requester.device_id
|
||||
|
||||
if device_id is None:
|
||||
raise synapse.api.errors.SynapseError(
|
||||
400,
|
||||
"To upload keys, you must pass device_id when authenticating"
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
|
@ -102,13 +125,14 @@ class KeyUploadServlet(RestServlet):
|
|||
user_id, device_id, time_now, key_list
|
||||
)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
# the device should have been registered already, but it may have been
|
||||
# deleted due to a race with a DELETE request. Or we may be using an
|
||||
# old access_token without an associated device_id. Either way, we
|
||||
# need to double-check the device is registered to avoid ending up with
|
||||
# keys without a corresponding device.
|
||||
self.device_handler.check_device_registered(
|
||||
user_id, device_id, "unknown device"
|
||||
)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
|
|
@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet):
|
|||
"""
|
||||
device_id = yield self._register_device(user_id, params)
|
||||
|
||||
access_token = yield self.auth_handler.issue_access_token(
|
||||
user_id, device_id=device_id
|
||||
access_token, refresh_token = (
|
||||
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
initial_display_name=params.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
|
||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
||||
user_id, device_id=device_id
|
||||
)
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
|
|
|
@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
|
|||
|
||||
def on_GET(self, request):
|
||||
return (200, {
|
||||
"versions": ["r0.0.1"]
|
||||
"versions": [
|
||||
"r0.0.1",
|
||||
"r0.1.0",
|
||||
"r0.2.0",
|
||||
]
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from . import engines
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def start_doing_background_updates(self):
|
||||
while True:
|
||||
if self._background_update_timer is not None:
|
||||
return
|
||||
assert self._background_update_timer is None, \
|
||||
"background updates already running"
|
||||
|
||||
logger.info("Starting background schema updates")
|
||||
|
||||
while True:
|
||||
sleep = defer.Deferred()
|
||||
self._background_update_timer = self._clock.call_later(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
||||
|
@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
self._background_update_timer = None
|
||||
|
||||
try:
|
||||
result = yield self.do_background_update(
|
||||
result = yield self.do_next_background_update(
|
||||
self.BACKGROUND_UPDATE_DURATION_MS
|
||||
)
|
||||
except:
|
||||
logger.exception("Error doing update")
|
||||
|
||||
else:
|
||||
if result is None:
|
||||
logger.info(
|
||||
"No more background updates to do."
|
||||
" Unscheduling background update task."
|
||||
)
|
||||
return
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_background_update(self, desired_duration_ms):
|
||||
"""Does some amount of work on a background update
|
||||
def do_next_background_update(self, desired_duration_ms):
|
||||
"""Does some amount of work on the next queued background update
|
||||
|
||||
Args:
|
||||
desired_duration_ms(float): How long we want to spend
|
||||
updating.
|
||||
|
@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
self._background_update_queue.append(update['update_name'])
|
||||
|
||||
if not self._background_update_queue:
|
||||
# no work left to do
|
||||
defer.returnValue(None)
|
||||
|
||||
# pop from the front, and add back to the back
|
||||
update_name = self._background_update_queue.pop(0)
|
||||
self._background_update_queue.append(update_name)
|
||||
|
||||
res = yield self._do_background_update(update_name, desired_duration_ms)
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_background_update(self, update_name, desired_duration_ms):
|
||||
logger.info("Starting update batch on background update '%s'",
|
||||
update_name)
|
||||
|
||||
update_handler = self._background_update_handlers[update_name]
|
||||
|
||||
performance = self._background_update_performance.get(update_name)
|
||||
|
@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
"""
|
||||
self._background_update_handlers[update_name] = update_handler
|
||||
|
||||
def register_background_index_update(self, update_name, index_name,
|
||||
table, columns):
|
||||
"""Helper for store classes to do a background index addition
|
||||
|
||||
To use:
|
||||
|
||||
1. use a schema delta file to add a background update. Example:
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('my_new_index', '{}');
|
||||
|
||||
2. In the Store constructor, call this method
|
||||
|
||||
Args:
|
||||
update_name (str): update_name to register for
|
||||
index_name (str): name of index to add
|
||||
table (str): table to add index to
|
||||
columns (list[str]): columns/expressions to include in index
|
||||
"""
|
||||
|
||||
# if this is postgres, we add the indexes concurrently. Otherwise
|
||||
# we fall back to doing it inline
|
||||
if isinstance(self.database_engine, engines.PostgresEngine):
|
||||
conc = True
|
||||
else:
|
||||
conc = False
|
||||
|
||||
sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
|
||||
% {
|
||||
"conc": "CONCURRENTLY" if conc else "",
|
||||
"name": index_name,
|
||||
"table": table,
|
||||
"columns": ", ".join(columns),
|
||||
}
|
||||
|
||||
def create_index_concurrently(conn):
|
||||
conn.rollback()
|
||||
# postgres insists on autocommit for the index
|
||||
conn.set_session(autocommit=True)
|
||||
c = conn.cursor()
|
||||
c.execute(sql)
|
||||
conn.set_session(autocommit=False)
|
||||
|
||||
def create_index(conn):
|
||||
c = conn.cursor()
|
||||
c.execute(sql)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def updater(progress, batch_size):
|
||||
logger.info("Adding index %s to %s", index_name, table)
|
||||
if conc:
|
||||
yield self.runWithConnection(create_index_concurrently)
|
||||
else:
|
||||
yield self.runWithConnection(create_index)
|
||||
yield self._end_background_update(update_name)
|
||||
defer.returnValue(1)
|
||||
|
||||
self.register_background_update_handler(update_name, updater)
|
||||
|
||||
def start_background_update(self, update_name, progress):
|
||||
"""Starts a background update running.
|
||||
|
||||
|
|
|
@ -15,10 +15,11 @@
|
|||
|
||||
import logging
|
||||
|
||||
from ._base import SQLBaseStore, Cache
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import Cache
|
||||
from . import background_updates
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||
|
@ -27,8 +28,7 @@ logger = logging.getLogger(__name__)
|
|||
LAST_SEEN_GRANULARITY = 120 * 1000
|
||||
|
||||
|
||||
class ClientIpStore(SQLBaseStore):
|
||||
|
||||
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def __init__(self, hs):
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen",
|
||||
|
@ -37,6 +37,13 @@ class ClientIpStore(SQLBaseStore):
|
|||
|
||||
super(ClientIpStore, self).__init__(hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
"user_ips_device_index",
|
||||
index_name="user_ips_device_id",
|
||||
table="user_ips",
|
||||
columns=["user_id", "device_id", "last_seen"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
|
||||
now = int(self._clock.time_msec())
|
||||
|
|
|
@ -76,6 +76,46 @@ class DeviceStore(SQLBaseStore):
|
|||
desc="get_device",
|
||||
)
|
||||
|
||||
def delete_device(self, user_id, device_id):
|
||||
"""Delete a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to delete
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
return self._simple_delete_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
desc="delete_device",
|
||||
)
|
||||
|
||||
def update_device(self, user_id, device_id, new_display_name=None):
|
||||
"""Update a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to update
|
||||
new_display_name (str|None): new displayname for device; None
|
||||
to leave unchanged
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
updates = {}
|
||||
if new_display_name is not None:
|
||||
updates["display_name"] = new_display_name
|
||||
if not updates:
|
||||
return defer.succeed(None)
|
||||
return self._simple_update_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
updatevalues=updates,
|
||||
desc="update_device",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""Retrieve all of a user's registered devices.
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import twisted.internet.defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
|
@ -123,3 +125,16 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||
return self.runInteraction(
|
||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||
)
|
||||
|
||||
@twisted.internet.defer.inlineCallbacks
|
||||
def delete_e2e_keys_by_device(self, user_id, device_id):
|
||||
yield self._simple_delete(
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
desc="delete_e2e_device_keys_by_device"
|
||||
)
|
||||
yield self._simple_delete(
|
||||
table="e2e_one_time_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
desc="delete_e2e_one_time_keys_by_device"
|
||||
)
|
||||
|
|
|
@ -117,21 +117,149 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_unread_push_actions_for_user_in_range(self, user_id,
|
||||
min_stream_ordering,
|
||||
max_stream_ordering=None,
|
||||
limit=20):
|
||||
def get_unread_push_actions_for_user_in_range_for_http(
|
||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||
):
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the httppusher.
|
||||
|
||||
Args:
|
||||
user_id (str): The user to fetch push actions for.
|
||||
min_stream_ordering(int): The exclusive lower bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
max_stream_ordering(int): The inclusive upper bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
limit (int): The maximum number of rows to return.
|
||||
Returns:
|
||||
A promise which resolves to a list of dicts with the keys "event_id",
|
||||
"room_id", "stream_ordering", "actions".
|
||||
The list will be ordered by ascending stream_ordering.
|
||||
The list will have between 0~limit entries.
|
||||
"""
|
||||
# find rooms that have a read receipt in them and return the next
|
||||
# push actions
|
||||
def get_after_receipt(txn):
|
||||
# find rooms that have a read receipt in them and return the next
|
||||
# push actions
|
||||
sql = (
|
||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
|
||||
" FROM ("
|
||||
" SELECT room_id,"
|
||||
" MAX(topological_ordering) as topological_ordering,"
|
||||
" MAX(stream_ordering) as stream_ordering"
|
||||
" FROM events"
|
||||
" INNER JOIN receipts_linearized USING (room_id, event_id)"
|
||||
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||
" GROUP BY room_id"
|
||||
") AS rl,"
|
||||
" event_push_actions AS ep"
|
||||
" WHERE"
|
||||
" ep.room_id = rl.room_id"
|
||||
" AND ("
|
||||
" ep.topological_ordering > rl.topological_ordering"
|
||||
" OR ("
|
||||
" ep.topological_ordering = rl.topological_ordering"
|
||||
" AND ep.stream_ordering > rl.stream_ordering"
|
||||
" )"
|
||||
" )"
|
||||
" AND ep.user_id = ?"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
after_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
|
||||
)
|
||||
|
||||
# There are rooms with push actions in them but you don't have a read receipt in
|
||||
# them e.g. rooms you've been invited to, so get push actions for rooms which do
|
||||
# not have read receipts in them too.
|
||||
def get_no_receipt(txn):
|
||||
sql = (
|
||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||
" e.received_ts"
|
||||
" FROM event_push_actions AS ep"
|
||||
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||
" WHERE"
|
||||
" ep.room_id NOT IN ("
|
||||
" SELECT room_id FROM receipts_linearized"
|
||||
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||
" GROUP BY room_id"
|
||||
" )"
|
||||
" AND ep.user_id = ?"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
no_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
|
||||
)
|
||||
|
||||
notifs = [
|
||||
{
|
||||
"event_id": row[0],
|
||||
"room_id": row[1],
|
||||
"stream_ordering": row[2],
|
||||
"actions": json.loads(row[3]),
|
||||
} for row in after_read_receipt + no_read_receipt
|
||||
]
|
||||
|
||||
# Now sort it so it's ordered correctly, since currently it will
|
||||
# contain results from the first query, correctly ordered, followed
|
||||
# by results from the second query, but we want them all ordered
|
||||
# by stream_ordering, oldest first.
|
||||
notifs.sort(key=lambda r: r['stream_ordering'])
|
||||
|
||||
# Take only up to the limit. We have to stop at the limit because
|
||||
# one of the subqueries may have hit the limit.
|
||||
defer.returnValue(notifs[:limit])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_unread_push_actions_for_user_in_range_for_email(
|
||||
self, user_id, min_stream_ordering, max_stream_ordering, limit=20
|
||||
):
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the emailpusher
|
||||
|
||||
Args:
|
||||
user_id (str): The user to fetch push actions for.
|
||||
min_stream_ordering(int): The exclusive lower bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
max_stream_ordering(int): The inclusive upper bound on the
|
||||
stream ordering of event push actions to fetch.
|
||||
limit (int): The maximum number of rows to return.
|
||||
Returns:
|
||||
A promise which resolves to a list of dicts with the keys "event_id",
|
||||
"room_id", "stream_ordering", "actions", "received_ts".
|
||||
The list will be ordered by descending received_ts.
|
||||
The list will have between 0~limit entries.
|
||||
"""
|
||||
# find rooms that have a read receipt in them and return the most recent
|
||||
# push actions
|
||||
def get_after_receipt(txn):
|
||||
sql = (
|
||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||
" e.received_ts"
|
||||
" FROM ("
|
||||
" SELECT room_id, user_id, "
|
||||
" max(topological_ordering) as topological_ordering, "
|
||||
" max(stream_ordering) as stream_ordering "
|
||||
" SELECT room_id,"
|
||||
" MAX(topological_ordering) as topological_ordering,"
|
||||
" MAX(stream_ordering) as stream_ordering"
|
||||
" FROM events"
|
||||
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
|
||||
" GROUP BY room_id, user_id"
|
||||
" INNER JOIN receipts_linearized USING (room_id, event_id)"
|
||||
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||
" GROUP BY room_id"
|
||||
") AS rl,"
|
||||
" event_push_actions AS ep"
|
||||
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||
|
@ -144,44 +272,49 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
" AND ep.stream_ordering > rl.stream_ordering"
|
||||
" )"
|
||||
" )"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.user_id = ?"
|
||||
" AND ep.user_id = rl.user_id"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
)
|
||||
args = [min_stream_ordering, user_id]
|
||||
if max_stream_ordering is not None:
|
||||
sql += " AND ep.stream_ordering <= ?"
|
||||
args.append(max_stream_ordering)
|
||||
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
args.append(limit)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
after_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range", get_after_receipt
|
||||
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
|
||||
)
|
||||
|
||||
# There are rooms with push actions in them but you don't have a read receipt in
|
||||
# them e.g. rooms you've been invited to, so get push actions for rooms which do
|
||||
# not have read receipts in them too.
|
||||
def get_no_receipt(txn):
|
||||
sql = (
|
||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||
" e.received_ts"
|
||||
" FROM event_push_actions AS ep"
|
||||
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
|
||||
" WHERE ep.room_id not in ("
|
||||
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
|
||||
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||
" WHERE"
|
||||
" ep.room_id NOT IN ("
|
||||
" SELECT room_id FROM receipts_linearized"
|
||||
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||
" GROUP BY room_id"
|
||||
") AND ep.user_id = ? AND ep.stream_ordering > ?"
|
||||
" )"
|
||||
" AND ep.user_id = ?"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
)
|
||||
args = [user_id, user_id, min_stream_ordering]
|
||||
if max_stream_ordering is not None:
|
||||
sql += " AND ep.stream_ordering <= ?"
|
||||
args.append(max_stream_ordering)
|
||||
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
args.append(limit)
|
||||
args = [
|
||||
user_id, user_id,
|
||||
min_stream_ordering, max_stream_ordering, limit,
|
||||
]
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
no_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range", get_no_receipt
|
||||
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
|
||||
)
|
||||
|
||||
# Make a list of dicts from the two sets of results.
|
||||
|
@ -198,7 +331,7 @@ class EventPushActionsStore(SQLBaseStore):
|
|||
# Now sort it so it's ordered correctly, since currently it will
|
||||
# contain results from the first query, correctly ordered, followed
|
||||
# by results from the second query, but we want them all ordered
|
||||
# by received_ts
|
||||
# by received_ts (most recent first)
|
||||
notifs.sort(key=lambda r: -(r['received_ts'] or 0))
|
||||
|
||||
# Now return the first `limit`
|
||||
|
|
|
@ -397,6 +397,12 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
@log_function
|
||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
|
||||
"""Insert some number of room events into the necessary database tables.
|
||||
|
||||
Rejected events are only inserted into the events table, the events_json table,
|
||||
and the rejections table. Things reading from those table will need to check
|
||||
whether the event was rejected.
|
||||
"""
|
||||
depth_updates = {}
|
||||
for event, context in events_and_contexts:
|
||||
# Remove the any existing cache entries for the event_ids
|
||||
|
@ -407,21 +413,11 @@ class EventsStore(SQLBaseStore):
|
|||
event.room_id, event.internal_metadata.stream_ordering,
|
||||
)
|
||||
|
||||
if not event.internal_metadata.is_outlier():
|
||||
if not event.internal_metadata.is_outlier() and not context.rejected:
|
||||
depth_updates[event.room_id] = max(
|
||||
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||
)
|
||||
|
||||
if context.push_actions:
|
||||
self._set_push_actions_for_event_and_users_txn(
|
||||
txn, event, context.push_actions
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Redaction and event.redacts is not None:
|
||||
self._remove_push_actions_for_event_id_txn(
|
||||
txn, event.room_id, event.redacts
|
||||
)
|
||||
|
||||
for room_id, depth in depth_updates.items():
|
||||
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||
|
||||
|
@ -431,14 +427,24 @@ class EventsStore(SQLBaseStore):
|
|||
),
|
||||
[event.event_id for event, _ in events_and_contexts]
|
||||
)
|
||||
|
||||
have_persisted = {
|
||||
event_id: outlier
|
||||
for event_id, outlier in txn.fetchall()
|
||||
}
|
||||
|
||||
# Remove the events that we've seen before.
|
||||
event_map = {}
|
||||
to_remove = set()
|
||||
for event, context in events_and_contexts:
|
||||
if context.rejected:
|
||||
# If the event is rejected then we don't care if the event
|
||||
# was an outlier or not.
|
||||
if event.event_id in have_persisted:
|
||||
# If we have already seen the event then ignore it.
|
||||
to_remove.add(event)
|
||||
continue
|
||||
|
||||
# Handle the case of the list including the same event multiple
|
||||
# times. The tricky thing here is when they differ by whether
|
||||
# they are an outlier.
|
||||
|
@ -463,6 +469,12 @@ class EventsStore(SQLBaseStore):
|
|||
|
||||
outlier_persisted = have_persisted[event.event_id]
|
||||
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
||||
# We received a copy of an event that we had already stored as
|
||||
# an outlier in the database. We now have some state at that
|
||||
# so we need to update the state_groups table with that state.
|
||||
|
||||
# insert into the state_group, state_groups_state and
|
||||
# event_to_state_groups tables.
|
||||
self._store_mult_state_groups_txn(txn, ((event, context),))
|
||||
|
||||
metadata_json = encode_json(
|
||||
|
@ -478,6 +490,8 @@ class EventsStore(SQLBaseStore):
|
|||
(metadata_json, event.event_id,)
|
||||
)
|
||||
|
||||
# Add an entry to the ex_outlier_stream table to replicate the
|
||||
# change in outlier status to our workers.
|
||||
stream_order = event.internal_metadata.stream_ordering
|
||||
state_group_id = context.state_group or context.new_state_group_id
|
||||
self._simple_insert_txn(
|
||||
|
@ -499,6 +513,8 @@ class EventsStore(SQLBaseStore):
|
|||
(False, event.event_id,)
|
||||
)
|
||||
|
||||
# Update the event_backward_extremities table now that this
|
||||
# event isn't an outlier any more.
|
||||
self._update_extremeties(txn, [event])
|
||||
|
||||
events_and_contexts = [
|
||||
|
@ -506,38 +522,12 @@ class EventsStore(SQLBaseStore):
|
|||
]
|
||||
|
||||
if not events_and_contexts:
|
||||
# Make sure we don't pass an empty list to functions that expect to
|
||||
# be storing at least one element.
|
||||
return
|
||||
|
||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||
|
||||
self._handle_mult_prev_events(
|
||||
txn,
|
||||
events=[event for event, _ in events_and_contexts],
|
||||
)
|
||||
|
||||
for event, _ in events_and_contexts:
|
||||
if event.type == EventTypes.Name:
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == EventTypes.Message:
|
||||
self._store_room_message_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
elif event.type == EventTypes.RoomHistoryVisibility:
|
||||
self._store_history_visibility_txn(txn, event)
|
||||
elif event.type == EventTypes.GuestAccess:
|
||||
self._store_guest_access_txn(txn, event)
|
||||
|
||||
self._store_room_members_txn(
|
||||
txn,
|
||||
[
|
||||
event
|
||||
for event, _ in events_and_contexts
|
||||
if event.type == EventTypes.Member
|
||||
],
|
||||
backfilled=backfilled,
|
||||
)
|
||||
# From this point onwards the events are only events that we haven't
|
||||
# seen before.
|
||||
|
||||
def event_dict(event):
|
||||
return {
|
||||
|
@ -591,10 +581,41 @@ class EventsStore(SQLBaseStore):
|
|||
],
|
||||
)
|
||||
|
||||
# Remove the rejected events from the list now that we've added them
|
||||
# to the events table and the events_json table.
|
||||
to_remove = set()
|
||||
for event, context in events_and_contexts:
|
||||
if context.rejected:
|
||||
# Insert the event_id into the rejections table
|
||||
self._store_rejections_txn(
|
||||
txn, event.event_id, context.rejected
|
||||
)
|
||||
to_remove.add(event)
|
||||
|
||||
events_and_contexts = [
|
||||
ec for ec in events_and_contexts if ec[0] not in to_remove
|
||||
]
|
||||
|
||||
if not events_and_contexts:
|
||||
# Make sure we don't pass an empty list to functions that expect to
|
||||
# be storing at least one element.
|
||||
return
|
||||
|
||||
# From this point onwards the events are only ones that weren't rejected.
|
||||
|
||||
for event, context in events_and_contexts:
|
||||
# Insert all the push actions into the event_push_actions table.
|
||||
if context.push_actions:
|
||||
self._set_push_actions_for_event_and_users_txn(
|
||||
txn, event, context.push_actions
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Redaction and event.redacts is not None:
|
||||
# Remove the entries in the event_push_actions table for the
|
||||
# redacted event.
|
||||
self._remove_push_actions_for_event_id_txn(
|
||||
txn, event.room_id, event.redacts
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
|
@ -610,6 +631,49 @@ class EventsStore(SQLBaseStore):
|
|||
],
|
||||
)
|
||||
|
||||
# Insert into the state_groups, state_groups_state, and
|
||||
# event_to_state_groups tables.
|
||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||
|
||||
# Update the event_forward_extremities, event_backward_extremities and
|
||||
# event_edges tables.
|
||||
self._handle_mult_prev_events(
|
||||
txn,
|
||||
events=[event for event, _ in events_and_contexts],
|
||||
)
|
||||
|
||||
for event, _ in events_and_contexts:
|
||||
if event.type == EventTypes.Name:
|
||||
# Insert into the room_names and event_search tables.
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == EventTypes.Topic:
|
||||
# Insert into the topics table and event_search table.
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == EventTypes.Message:
|
||||
# Insert into the event_search table.
|
||||
self._store_room_message_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
# Insert into the redactions table.
|
||||
self._store_redaction(txn, event)
|
||||
elif event.type == EventTypes.RoomHistoryVisibility:
|
||||
# Insert into the event_search table.
|
||||
self._store_history_visibility_txn(txn, event)
|
||||
elif event.type == EventTypes.GuestAccess:
|
||||
# Insert into the event_search table.
|
||||
self._store_guest_access_txn(txn, event)
|
||||
|
||||
# Insert into the room_memberships table.
|
||||
self._store_room_members_txn(
|
||||
txn,
|
||||
[
|
||||
event
|
||||
for event, _ in events_and_contexts
|
||||
if event.type == EventTypes.Member
|
||||
],
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
# Insert event_reference_hashes table.
|
||||
self._store_event_reference_hashes_txn(
|
||||
txn, [event for event, _ in events_and_contexts]
|
||||
)
|
||||
|
@ -654,6 +718,7 @@ class EventsStore(SQLBaseStore):
|
|||
],
|
||||
)
|
||||
|
||||
# Prefill the event cache
|
||||
self._add_to_cache(txn, events_and_contexts)
|
||||
|
||||
if backfilled:
|
||||
|
@ -666,11 +731,6 @@ class EventsStore(SQLBaseStore):
|
|||
# Outlier events shouldn't clobber the current state.
|
||||
continue
|
||||
|
||||
if context.rejected:
|
||||
# If the event failed it's auth checks then it shouldn't
|
||||
# clobbler the current state.
|
||||
continue
|
||||
|
||||
txn.call_after(
|
||||
self._get_current_state_for_key.invalidate,
|
||||
(event.room_id, event.type, event.state_key,)
|
||||
|
|
|
@ -18,18 +18,31 @@ import re
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError, Codes
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.storage import background_updates
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
|
||||
|
||||
class RegistrationStore(SQLBaseStore):
|
||||
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RegistrationStore, self).__init__(hs)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self.register_background_index_update(
|
||||
"access_tokens_device_index",
|
||||
index_name="access_tokens_device_id",
|
||||
table="access_tokens",
|
||||
columns=["user_id", "device_id"],
|
||||
)
|
||||
|
||||
self.register_background_index_update(
|
||||
"refresh_tokens_device_index",
|
||||
index_name="refresh_tokens_device_id",
|
||||
table="refresh_tokens",
|
||||
columns=["user_id", "device_id"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_access_token_to_user(self, user_id, token, device_id=None):
|
||||
"""Adds an access token for the given user.
|
||||
|
@ -238,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
|
|||
self.get_user_by_id.invalidate((user_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
|
||||
def f(txn):
|
||||
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
|
||||
def user_delete_access_tokens(self, user_id, except_token_ids=[],
|
||||
device_id=None,
|
||||
delete_refresh_tokens=False):
|
||||
"""
|
||||
Invalidate access/refresh tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user the tokens belong to
|
||||
except_token_ids (list[str]): list of access_tokens which should
|
||||
*not* be deleted
|
||||
device_id (str|None): ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
be deleted
|
||||
delete_refresh_tokens (bool): True to delete refresh tokens as
|
||||
well as access tokens.
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
def f(txn, table, except_tokens, call_after_delete):
|
||||
sql = "SELECT token FROM %s WHERE user_id = ?" % table
|
||||
clauses = [user_id]
|
||||
|
||||
if except_token_ids:
|
||||
if device_id is not None:
|
||||
sql += " AND device_id = ?"
|
||||
clauses.append(device_id)
|
||||
|
||||
if except_tokens:
|
||||
sql += " AND id NOT IN (%s)" % (
|
||||
",".join(["?" for _ in except_token_ids]),
|
||||
",".join(["?" for _ in except_tokens]),
|
||||
)
|
||||
clauses += except_token_ids
|
||||
clauses += except_tokens
|
||||
|
||||
txn.execute(sql, clauses)
|
||||
|
||||
|
@ -256,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
|
|||
n = 100
|
||||
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
|
||||
for chunk in chunks:
|
||||
if call_after_delete:
|
||||
for row in chunk:
|
||||
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
|
||||
txn.call_after(call_after_delete, (row[0],))
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM access_tokens WHERE token in (%s)" % (
|
||||
"DELETE FROM %s WHERE token in (%s)" % (
|
||||
table,
|
||||
",".join(["?" for _ in chunk]),
|
||||
), [r[0] for r in chunk]
|
||||
)
|
||||
|
||||
yield self.runInteraction("user_delete_access_tokens", f)
|
||||
# delete refresh tokens first, to stop new access tokens being
|
||||
# allocated while our backs are turned
|
||||
if delete_refresh_tokens:
|
||||
yield self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
table="refresh_tokens",
|
||||
except_tokens=[],
|
||||
call_after_delete=None,
|
||||
)
|
||||
|
||||
yield self.runInteraction(
|
||||
"user_delete_access_tokens", f,
|
||||
table="access_tokens",
|
||||
except_tokens=except_token_ids,
|
||||
call_after_delete=self.get_user_by_access_token.invalidate,
|
||||
)
|
||||
|
||||
def delete_access_token(self, access_token):
|
||||
def f(txn):
|
||||
|
@ -288,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
|
|||
Args:
|
||||
token (str): The access token of a user.
|
||||
Returns:
|
||||
dict: Including the name (user_id) and the ID of their access token.
|
||||
Raises:
|
||||
StoreError if no user was found.
|
||||
defer.Deferred: None, if the token did not match, otherwise dict
|
||||
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_user_by_access_token",
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('access_tokens_device_index', '{}');
|
19
synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
Normal file
19
synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
Normal file
|
@ -0,0 +1,19 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- make sure that we have a device record for each set of E2E keys, so that the
|
||||
-- user can delete them if they like.
|
||||
INSERT INTO devices
|
||||
SELECT user_id, device_id, 'unknown device' FROM e2e_device_keys_json;
|
|
@ -0,0 +1,17 @@
|
|||
/* Copyright 2016 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('refresh_tokens_device_index', '{}');
|
|
@ -13,4 +13,5 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE INDEX user_ips_device_id ON user_ips(user_id, device_id, last_seen);
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('user_ips_device_index', '{}');
|
||||
|
|
|
@ -24,6 +24,7 @@ from collections import namedtuple
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -101,7 +102,7 @@ class TransactionStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
if result and result["response_code"]:
|
||||
return result["response_code"], result["response_json"]
|
||||
return result["response_code"], json.loads(str(result["response_json"]))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
|
@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError
|
|||
from collections import namedtuple
|
||||
|
||||
|
||||
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
|
||||
Requester = namedtuple("Requester",
|
||||
["user", "access_token_id", "is_guest", "device_id"])
|
||||
"""
|
||||
Represents the user making a request
|
||||
|
||||
Attributes:
|
||||
user (UserID): id of the user making the request
|
||||
access_token_id (int|None): *ID* of the access token used for this
|
||||
request, or None if it came via the appservice API or similar
|
||||
is_guest (bool): True if the user making this request is a guest user
|
||||
device_id (str|None): device_id which was set at authentication time
|
||||
"""
|
||||
|
||||
|
||||
def create_requester(user_id, access_token_id=None, is_guest=False,
|
||||
device_id=None):
|
||||
"""
|
||||
Create a new ``Requester`` object
|
||||
|
||||
Args:
|
||||
user_id (str|UserID): id of the user making the request
|
||||
access_token_id (int|None): *ID* of the access token used for this
|
||||
request, or None if it came via the appservice API or similar
|
||||
is_guest (bool): True if the user making this request is a guest user
|
||||
device_id (str|None): device_id which was set at authentication time
|
||||
|
||||
Returns:
|
||||
Requester
|
||||
"""
|
||||
if not isinstance(user_id, UserID):
|
||||
user_id = UserID.from_string(user_id)
|
||||
return Requester(user_id, access_token_id, is_guest, device_id)
|
||||
|
||||
|
||||
def get_domain_from_id(string):
|
||||
|
|
|
@ -84,7 +84,7 @@ class Measure(object):
|
|||
|
||||
if context != self.start_context:
|
||||
logger.warn(
|
||||
"Context have unexpectedly changed from '%s' to '%s'. (%r)",
|
||||
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
|
||||
context, self.start_context, self.name
|
||||
)
|
||||
return
|
||||
|
|
|
@ -83,7 +83,10 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
|||
):
|
||||
if ("m.room.member", my_member_event.sender) in room_state:
|
||||
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
|
||||
if fallback_to_single_member:
|
||||
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return "Room Invite"
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ class RetryDestinationLimiter(object):
|
|||
)
|
||||
|
||||
valid_err_code = False
|
||||
if exc_type is CodeMessageException:
|
||||
if exc_type is not None and issubclass(exc_type, CodeMessageException):
|
||||
valid_err_code = 0 <= exc_val.code < 500
|
||||
|
||||
if exc_type is None or valid_err_code:
|
||||
|
|
|
@ -12,11 +12,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse import types
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.handlers.device
|
||||
|
||||
import synapse.storage
|
||||
from synapse import types
|
||||
from tests import unittest, utils
|
||||
|
||||
user1 = "@boris:aaa"
|
||||
|
@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super(DeviceTestCase, self).__init__(*args, **kwargs)
|
||||
self.store = None # type: synapse.storage.DataStore
|
||||
self.handler = None # type: device.DeviceHandler
|
||||
self.handler = None # type: synapse.handlers.device.DeviceHandler
|
||||
self.clock = None # type: utils.MockClock
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -84,28 +87,31 @@ class DeviceTestCase(unittest.TestCase):
|
|||
yield self._record_users()
|
||||
|
||||
res = yield self.handler.get_devices_by_user(user1)
|
||||
self.assertEqual(3, len(res.keys()))
|
||||
self.assertEqual(3, len(res))
|
||||
device_map = {
|
||||
d["device_id"]: d for d in res
|
||||
}
|
||||
self.assertDictContainsSubset({
|
||||
"user_id": user1,
|
||||
"device_id": "xyz",
|
||||
"display_name": "display 0",
|
||||
"last_seen_ip": None,
|
||||
"last_seen_ts": None,
|
||||
}, res["xyz"])
|
||||
}, device_map["xyz"])
|
||||
self.assertDictContainsSubset({
|
||||
"user_id": user1,
|
||||
"device_id": "fco",
|
||||
"display_name": "display 1",
|
||||
"last_seen_ip": "ip1",
|
||||
"last_seen_ts": 1000000,
|
||||
}, res["fco"])
|
||||
}, device_map["fco"])
|
||||
self.assertDictContainsSubset({
|
||||
"user_id": user1,
|
||||
"device_id": "abc",
|
||||
"display_name": "display 2",
|
||||
"last_seen_ip": "ip3",
|
||||
"last_seen_ts": 3000000,
|
||||
}, res["abc"])
|
||||
}, device_map["abc"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_device(self):
|
||||
|
@ -120,6 +126,37 @@ class DeviceTestCase(unittest.TestCase):
|
|||
"last_seen_ts": 3000000,
|
||||
}, res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_delete_device(self):
|
||||
yield self._record_users()
|
||||
|
||||
# delete the device
|
||||
yield self.handler.delete_device(user1, "abc")
|
||||
|
||||
# check the device was deleted
|
||||
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||
yield self.handler.get_device(user1, "abc")
|
||||
|
||||
# we'd like to check the access token was invalidated, but that's a
|
||||
# bit of a PITA.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_device(self):
|
||||
yield self._record_users()
|
||||
|
||||
update = {"display_name": "new display"}
|
||||
yield self.handler.update_device(user1, "abc", update)
|
||||
|
||||
res = yield self.handler.get_device(user1, "abc")
|
||||
self.assertEqual(res["display_name"], "new display")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_unknown_device(self):
|
||||
update = {"display_name": "new_display"}
|
||||
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||
yield self.handler.update_device("user_id", "unknown_device_id",
|
||||
update)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _record_users(self):
|
||||
# check this works for both devices which have a recorded client_ip,
|
||||
|
|
|
@ -19,11 +19,12 @@ from twisted.internet import defer
|
|||
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests.utils import setup_test_homeserver, requester_for_user
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
class ProfileHandlers(object):
|
||||
|
@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
def test_set_my_name(self):
|
||||
yield self.handler.set_displayname(
|
||||
self.frank,
|
||||
requester_for_user(self.frank),
|
||||
synapse.types.create_requester(self.frank),
|
||||
"Frank Jr."
|
||||
)
|
||||
|
||||
|
@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
def test_set_my_name_noauth(self):
|
||||
d = self.handler.set_displayname(
|
||||
self.frank,
|
||||
requester_for_user(self.bob),
|
||||
synapse.types.create_requester(self.bob),
|
||||
"Frank Jr."
|
||||
)
|
||||
|
||||
|
@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_set_my_avatar(self):
|
||||
yield self.handler.set_avatar_url(
|
||||
self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
|
||||
self.frank, synapse.types.create_requester(self.frank),
|
||||
"http://my.server/pic.gif"
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
|
|
|
@ -13,15 +13,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.resource import ReplicationResource
|
||||
from synapse.types import Requester, UserID
|
||||
|
||||
from twisted.internet import defer
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver, requester_for_user
|
||||
from mock import Mock, NonCallableMock
|
||||
import json
|
||||
import contextlib
|
||||
import json
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.replication.resource import ReplicationResource
|
||||
from synapse.types import UserID
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
class ReplicationResourceCase(unittest.TestCase):
|
||||
|
@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||
def test_events_and_state(self):
|
||||
get = self.get(events="-1", state="-1", timeout="0")
|
||||
yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||
Requester(self.user, "", False), {}
|
||||
synapse.types.create_requester(self.user), {}
|
||||
)
|
||||
code, body = yield get
|
||||
self.assertEquals(code, 200)
|
||||
|
@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||
def send_text_message(self, room_id, message):
|
||||
handler = self.hs.get_handlers().message_handler
|
||||
event = yield handler.create_and_send_nonmember_event(
|
||||
requester_for_user(self.user),
|
||||
synapse.types.create_requester(self.user),
|
||||
{
|
||||
"type": "m.room.message",
|
||||
"content": {"body": "message", "msgtype": "m.text"},
|
||||
|
@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def create_room(self):
|
||||
result = yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||
Requester(self.user, "", False), {}
|
||||
synapse.types.create_requester(self.user), {}
|
||||
)
|
||||
defer.returnValue(result["room_id"])
|
||||
|
||||
|
|
|
@ -14,17 +14,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
"""Tests REST events for /profile paths."""
|
||||
from tests import unittest
|
||||
from mock import Mock
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.types import Requester, UserID
|
||||
|
||||
from synapse.rest.client.v1 import profile
|
||||
from tests import unittest
|
||||
from ....utils import MockHttpResource, setup_test_homeserver
|
||||
|
||||
myid = "@1234ABCD:test"
|
||||
PATH_PREFIX = "/_matrix/client/api/v1"
|
||||
|
@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
def _get_user_by_req(request=None, allow_guest=False):
|
||||
return Requester(UserID.from_string(myid), "", False)
|
||||
return synapse.types.create_requester(myid)
|
||||
|
||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||
|
||||
|
|
|
@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.registration_handler.appservice_register = Mock(
|
||||
return_value=user_id
|
||||
)
|
||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||
return_value=(token, "kermits_refresh_token")
|
||||
)
|
||||
|
||||
(code, result) = yield self.servlet.on_POST(self.request)
|
||||
self.assertEquals(code, 200)
|
||||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"refresh_token": "kermits_refresh_token",
|
||||
"home_server": self.hs.hostname
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
|
@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
"password": "monkey"
|
||||
}, None)
|
||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
||||
self.auth_handler.issue_access_token = Mock(return_value=token)
|
||||
self.auth_handler.get_login_tuple_for_user_id = Mock(
|
||||
return_value=(token, "kermits_refresh_token")
|
||||
)
|
||||
self.device_handler.check_device_registered = \
|
||||
Mock(return_value=device_id)
|
||||
|
||||
|
@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
det_data = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"refresh_token": "kermits_refresh_token",
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
self.assertDictContainsSubset(det_data, result)
|
||||
self.assertIn("refresh_token", result)
|
||||
self.auth_handler.issue_access_token.assert_called_once_with(
|
||||
user_id, device_id=device_id)
|
||||
self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id, initial_device_display_name=None)
|
||||
|
||||
def test_POST_disabled_registration(self):
|
||||
self.hs.config.enable_registration = False
|
||||
|
|
|
@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield setup_test_homeserver()
|
||||
hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
|
@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
|||
"test_update", self.update_handler
|
||||
)
|
||||
|
||||
# run the real background updates, to get them out the way
|
||||
# (perhaps we should run them as part of the test HS setup, since we
|
||||
# run all of the other schema setup stuff there?)
|
||||
while True:
|
||||
res = yield self.store.do_next_background_update(1000)
|
||||
if res is None:
|
||||
break
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_do_background_update(self):
|
||||
desired_count = 1000
|
||||
duration_ms = 42
|
||||
|
||||
# first step: make a bit of progress
|
||||
@defer.inlineCallbacks
|
||||
def update(progress, count):
|
||||
self.clock.advance_time_msec(count * duration_ms)
|
||||
|
@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
|||
yield self.store.start_background_update("test_update", {"my_key": 1})
|
||||
|
||||
self.update_handler.reset_mock()
|
||||
result = yield self.store.do_background_update(
|
||||
result = yield self.store.do_next_background_update(
|
||||
duration_ms * desired_count
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
|
@ -50,15 +59,15 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
|||
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
|
||||
)
|
||||
|
||||
# second step: complete the update
|
||||
@defer.inlineCallbacks
|
||||
def update(progress, count):
|
||||
yield self.store._end_background_update("test_update")
|
||||
defer.returnValue(count)
|
||||
|
||||
self.update_handler.side_effect = update
|
||||
|
||||
self.update_handler.reset_mock()
|
||||
result = yield self.store.do_background_update(
|
||||
result = yield self.store.do_next_background_update(
|
||||
duration_ms * desired_count
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
|
@ -66,8 +75,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
|
|||
{"my_key": 2}, desired_count
|
||||
)
|
||||
|
||||
# third step: we don't expect to be called any more
|
||||
self.update_handler.reset_mock()
|
||||
result = yield self.store.do_background_update(
|
||||
result = yield self.store.do_next_background_update(
|
||||
duration_ms * desired_count
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
||||
|
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
"device_id": "device2",
|
||||
"display_name": "display_name 2",
|
||||
}, res["device2"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_device(self):
|
||||
yield self.store.store_device(
|
||||
"user_id", "device_id", "display_name 1"
|
||||
)
|
||||
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do a no-op first
|
||||
yield self.store.update_device(
|
||||
"user_id", "device_id",
|
||||
)
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do the update
|
||||
yield self.store.update_device(
|
||||
"user_id", "device_id",
|
||||
new_display_name="display_name 2",
|
||||
)
|
||||
|
||||
# check it worked
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 2", res["display_name"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_unknown_device(self):
|
||||
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||
yield self.store.update_device(
|
||||
"user_id", "unknown_device_id",
|
||||
new_display_name="display_name 2",
|
||||
)
|
||||
self.assertEqual(404, cm.exception.code)
|
||||
|
|
41
tests/storage/test_event_push_actions.py
Normal file
41
tests/storage/test_event_push_actions.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import tests.unittest
|
||||
import tests.utils
|
||||
|
||||
USER_ID = "@user:example.com"
|
||||
|
||||
|
||||
class EventPushActionsStoreTestCase(tests.unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
hs = yield tests.utils.setup_test_homeserver()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_unread_push_actions_for_user_in_range_for_http(self):
|
||||
yield self.store.get_unread_push_actions_for_user_in_range_for_http(
|
||||
USER_ID, 0, 1000, 20
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_unread_push_actions_for_user_in_range_for_email(self):
|
||||
yield self.store.get_unread_push_actions_for_user_in_range_for_email(
|
||||
USER_ID, 0, 1000, 20
|
||||
)
|
|
@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||
with self.assertRaises(StoreError):
|
||||
yield self.store.exchange_refresh_token(last_token, generator.generate)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_delete_access_tokens(self):
|
||||
# add some tokens
|
||||
generator = TokenGenerator()
|
||||
refresh_token = generator.generate(self.user_id)
|
||||
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
|
||||
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
|
||||
self.device_id)
|
||||
yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
|
||||
self.device_id)
|
||||
|
||||
# now delete some
|
||||
yield self.store.user_delete_access_tokens(
|
||||
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
|
||||
|
||||
# check they were deleted
|
||||
user = yield self.store.get_user_by_access_token(self.tokens[1])
|
||||
self.assertIsNone(user, "access token was not deleted by device_id")
|
||||
with self.assertRaises(StoreError):
|
||||
yield self.store.exchange_refresh_token(refresh_token,
|
||||
generator.generate)
|
||||
|
||||
# check the one not associated with the device was not deleted
|
||||
user = yield self.store.get_user_by_access_token(self.tokens[0])
|
||||
self.assertEqual(self.user_id, user["name"])
|
||||
|
||||
# now delete the rest
|
||||
yield self.store.user_delete_access_tokens(
|
||||
self.user_id, delete_refresh_tokens=True)
|
||||
|
||||
user = yield self.store.get_user_by_access_token(self.tokens[0])
|
||||
self.assertIsNone(user,
|
||||
"access token was not deleted without device_id")
|
||||
|
||||
|
||||
class TokenGenerator:
|
||||
def __init__(self):
|
||||
|
|
|
@ -17,13 +17,18 @@ from twisted.trial import unittest
|
|||
|
||||
import logging
|
||||
|
||||
|
||||
# logging doesn't have a "don't log anything at all EVARRRR setting,
|
||||
# but since the highest value is 50, 1000000 should do ;)
|
||||
NEVER = 1000000
|
||||
|
||||
logging.getLogger().addHandler(logging.StreamHandler())
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
|
||||
))
|
||||
logging.getLogger().addHandler(handler)
|
||||
logging.getLogger().setLevel(NEVER)
|
||||
logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
|
||||
logging.getLogger("synapse.storage.txn").setLevel(NEVER)
|
||||
|
||||
|
||||
def around(target):
|
||||
|
@ -70,8 +75,6 @@ class TestCase(unittest.TestCase):
|
|||
return ret
|
||||
|
||||
logging.getLogger().setLevel(level)
|
||||
# Don't set SQL logging
|
||||
logging.getLogger("synapse.storage").setLevel(old_level)
|
||||
return orig()
|
||||
|
||||
def assertObjectHasAttributes(self, attrs, obj):
|
||||
|
|
|
@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
|
|||
from synapse.storage.engines import create_engine
|
||||
from synapse.server import HomeServer
|
||||
from synapse.federation.transport import server
|
||||
from synapse.types import Requester
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
@ -512,7 +511,3 @@ class DeferredMockCallable(object):
|
|||
"call(%s)" % _format_call(c[0], c[1]) for c in calls
|
||||
])
|
||||
)
|
||||
|
||||
|
||||
def requester_for_user(user):
|
||||
return Requester(user, None, False)
|
||||
|
|
Loading…
Reference in a new issue