0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-21 00:51:58 +01:00

Merge branch 'release-v0.19.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2017-02-04 08:28:56 +00:00
commit fad3a84335
102 changed files with 3614 additions and 1783 deletions

View file

@ -1,3 +1,78 @@
Changes in synapse v0.19.0 (2017-02-04)
=======================================
No changes since RC 4.
Changes in synapse v0.19.0-rc4 (2017-02-02)
===========================================
* Bump cache sizes for common membership queries (PR #1879)
Changes in synapse v0.19.0-rc3 (2017-02-02)
===========================================
* Fix email push in pusher worker (PR #1875)
* Make presence.get_new_events a bit faster (PR #1876)
* Make /keys/changes a bit more performant (PR #1877)
Changes in synapse v0.19.0-rc2 (2017-02-02)
===========================================
* Include newly joined users in /keys/changes API (PR #1872)
Changes in synapse v0.19.0-rc1 (2017-02-02)
===========================================
Features:
* Add support for specifying multiple bind addresses (PR #1709, #1712, #1795,
#1835). Thanks to @kyrias!
* Add /account/3pid/delete endpoint (PR #1714)
* Add config option to configure the Riot URL used in notification emails (PR
#1811). Thanks to @aperezdc!
* Add username and password config options for turn server (PR #1832). Thanks
to @xsteadfastx!
* Implement device lists updates over federation (PR #1857, #1861, #1864)
* Implement /keys/changes (PR #1869, #1872)
Changes:
* Improve IPv6 support (PR #1696). Thanks to @kyrias and @glyph!
* Log which files we saved attachments to in the media_repository (PR #1791)
* Linearize updates to membership via PUT /state/ to better handle multiple
joins (PR #1787)
* Limit number of entries to prefill from cache on startup (PR #1792)
* Remove full_twisted_stacktraces option (PR #1802)
* Measure size of some caches by sum of the size of cached values (PR #1815)
* Measure metrics of string_cache (PR #1821)
* Reduce logging verbosity (PR #1822, #1823, #1824)
* Don't clobber a displayname or avatar_url if provided by an m.room.member
event (PR #1852)
* Better handle 401/404 response for federation /send/ (PR #1866, #1871)
Fixes:
* Fix ability to change password to a non-ascii one (PR #1711)
* Fix push getting stuck due to looking at the wrong view of state (PR #1820)
* Fix email address comparison to be case insensitive (PR #1827)
* Fix occasional inconsistencies of room membership (PR #1836, #1840)
Performance:
* Don't block messages sending on bumping presence (PR #1789)
* Change device_inbox stream index to include user (PR #1793)
* Optimise state resolution (PR #1818)
* Use DB cache of joined users for presence (PR #1862)
* Add an index to make membership queries faster (PR #1867)
Changes in synapse v0.18.7 (2017-01-09) Changes in synapse v0.18.7 (2017-01-09)
======================================= =======================================
@ -30,6 +105,7 @@ Changes in synapse v0.18.6 (2017-01-06)
Bug fixes: Bug fixes:
* Fix bug when checking if a guest user is allowed to join a room (PR #1772) * Fix bug when checking if a guest user is allowed to join a room (PR #1772)
Thanks to Patrik Oldsberg for diagnosing and the fix!
Changes in synapse v0.18.6-rc3 (2017-01-05) Changes in synapse v0.18.6-rc3 (2017-01-05)

View file

@ -138,6 +138,7 @@ Installing prerequisites on openSUSE::
python-devel libffi-devel libopenssl-devel libjpeg62-devel python-devel libffi-devel libopenssl-devel libjpeg62-devel
Installing prerequisites on OpenBSD:: Installing prerequisites on OpenBSD::
doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \ doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
libxslt libxslt
@ -658,7 +659,7 @@ configuration might look like::
} }
} }
You will also want to set ``bind_address: 127.0.0.1`` and ``x_forwarded: true`` You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true``
for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are
recorded correctly. recorded correctly.

View file

@ -6,11 +6,9 @@ media.
The API is:: The API is::
POST /_matrix/client/r0/admin/purge_media_cache POST /_matrix/client/r0/admin/purge_media_cache?before_ts=<unix_timestamp_in_ms>&access_token=<access_token>
{ {}
"before_ts": <unix_timestamp_in_ms>
}
Which will remove all cached media that was last accessed before Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``. ``<unix_timestamp_in_ms>``.

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.18.7" __version__ = "0.19.0"

View file

@ -16,18 +16,14 @@
import logging import logging
import pymacaroons import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer from twisted.internet import defer
from unpaddedbase64 import decode_base64
import synapse.types import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -78,147 +74,7 @@ class Auth(object):
True if the auth checks pass. True if the auth checks pass.
""" """
with Measure(self.clock, "auth.check"): with Measure(self.clock, "auth.check"):
self.check_size_limits(event) event_auth.check(event, auth_events, do_sig_check=do_sig_check)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self.check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = self._get_user_power_level(event.user_id, auth_events)
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
self._can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
self._check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def check_size_limits(self, event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None): def check_joined_room(self, room_id, user_id, current_state=None):
@ -290,7 +146,7 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"): with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from check_host_in_room") logger.debug("calling resolve_state_groups from check_host_in_room")
entry = yield self.state.resolve_state_groups( entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids room_id, latest_event_ids
) )
@ -300,16 +156,6 @@ class Auth(object):
) )
defer.returnValue(ret) defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return self._check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(self, member, user_id, room_id): def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN: if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % ( raise AuthError(403, "User %s not in room %s (%s)" % (
@ -321,267 +167,8 @@ class Auth(object):
return creation_event.content.get("m.federate", True) is True return creation_event.content.get("m.federate", True) is True
@log_function
def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50)
logger.debug(
"is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _verify_third_party_invite(self, event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in self.get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(self, invite_event): def get_public_keys(self, invite_event):
public_keys = [] return event_auth.get_public_keys(invite_event)
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def _get_user_power_level(self, user_id, auth_events):
power_level_event = self._get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(self, auth_events, name, default):
power_level_event = self._get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False, rights="access"): def get_user_by_req(self, request, allow_guest=False, rights="access"):
@ -974,56 +561,6 @@ class Auth(object):
defer.returnValue(auth_ids) defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
@log_function
def _can_send_event(self, event, auth_events):
send_level = self._get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = self._get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(self, event, auth_events): def check_redaction(self, event, auth_events):
"""Check whether the event sender is allowed to redact the target event. """Check whether the event sender is allowed to redact the target event.
@ -1037,107 +574,7 @@ class Auth(object):
AuthError if the event sender is definitely not allowed to redact AuthError if the event sender is definitely not allowed to redact
the target event. the target event.
""" """
user_level = self._get_user_power_level(event.user_id, auth_events) return event_auth.check_redaction(event, auth_events)
redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = self._get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user): def check_can_change_room_list(self, room_id, user):
@ -1167,10 +604,10 @@ class Auth(object):
if power_level_event: if power_level_event:
auth_events[(EventTypes.PowerLevels, "")] = power_level_event auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = self._get_send_level( send_level = event_auth.get_send_level(
EventTypes.Aliases, "", auth_events EventTypes.Aliases, "", auth_events
) )
user_level = self._get_user_power_level(user_id, auth_events) user_level = event_auth.get_user_power_level(user_id, auth_events)
if user_level < send_level: if user_level < send_level:
raise AuthError( raise AuthError(

View file

@ -76,7 +76,7 @@ class AppserviceServer(HomeServer):
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
@ -85,6 +85,8 @@ class AppserviceServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -93,8 +95,9 @@ class AppserviceServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse appservice now listening on port %d", port) logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners): def start_listening(self, listeners):
@ -102,6 +105,9 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -109,7 +115,7 @@ class AppserviceServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -90,7 +90,7 @@ class ClientReaderServer(HomeServer):
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
@ -108,6 +108,8 @@ class ClientReaderServer(HomeServer):
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -116,8 +118,9 @@ class ClientReaderServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse client reader now listening on port %d", port) logger.info("Synapse client reader now listening on port %d", port)
def start_listening(self, listeners): def start_listening(self, listeners):
@ -125,6 +128,9 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -132,7 +138,7 @@ class ClientReaderServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -86,7 +86,7 @@ class FederationReaderServer(HomeServer):
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
@ -99,6 +99,8 @@ class FederationReaderServer(HomeServer):
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -107,8 +109,9 @@ class FederationReaderServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse federation reader now listening on port %d", port) logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners): def start_listening(self, listeners):
@ -116,6 +119,9 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -123,7 +129,7 @@ class FederationReaderServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep from synapse.util.async import sleep
@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore( class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedRegistrationStore, SlavedDeviceStore,
): ):
pass pass
@ -82,7 +83,7 @@ class FederationSenderServer(HomeServer):
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
@ -91,6 +92,8 @@ class FederationSenderServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -99,8 +102,9 @@ class FederationSenderServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse federation_sender now listening on port %d", port) logger.info("Synapse federation_sender now listening on port %d", port)
def start_listening(self, listeners): def start_listening(self, listeners):
@ -108,6 +112,9 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -115,7 +122,7 @@ class FederationSenderServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -107,7 +107,7 @@ def build_resource_for_web_client(hs):
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config): def _listener_http(self, config, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
tls = listener_config.get("tls", False) tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
@ -173,7 +173,9 @@ class SynapseHomeServer(HomeServer):
root_resource = Resource() root_resource = Resource()
root_resource = create_resource_tree(resources, root_resource) root_resource = create_resource_tree(resources, root_resource)
if tls: if tls:
for address in bind_addresses:
reactor.listenSSL( reactor.listenSSL(
port, port,
SynapseSite( SynapseSite(
@ -183,9 +185,10 @@ class SynapseHomeServer(HomeServer):
root_resource, root_resource,
), ),
self.tls_server_context_factory, self.tls_server_context_factory,
interface=bind_address interface=address
) )
else: else:
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -194,7 +197,7 @@ class SynapseHomeServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse now listening on port %d", port) logger.info("Synapse now listening on port %d", port)
@ -205,6 +208,9 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listener_http(config, listener) self._listener_http(config, listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -212,7 +218,7 @@ class SynapseHomeServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -87,7 +87,7 @@ class MediaRepositoryServer(HomeServer):
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
@ -105,6 +105,8 @@ class MediaRepositoryServer(HomeServer):
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -113,8 +115,9 @@ class MediaRepositoryServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse media repository now listening on port %d", port) logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners): def start_listening(self, listeners):
@ -122,6 +125,9 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -129,7 +135,7 @@ class MediaRepositoryServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -121,7 +121,7 @@ class PusherServer(HomeServer):
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
@ -130,6 +130,8 @@ class PusherServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -138,8 +140,9 @@ class PusherServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse pusher now listening on port %d", port) logger.info("Synapse pusher now listening on port %d", port)
def start_listening(self, listeners): def start_listening(self, listeners):
@ -147,6 +150,9 @@ class PusherServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -154,7 +160,7 @@ class PusherServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore,
RoomStore, RoomStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
@ -289,7 +291,7 @@ class SynchrotronServer(HomeServer):
def _listen_http(self, listener_config): def _listen_http(self, listener_config):
port = listener_config["port"] port = listener_config["port"]
bind_address = listener_config.get("bind_address", "") bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port) site_tag = listener_config.get("tag", port)
resources = {} resources = {}
for res in listener_config["resources"]: for res in listener_config["resources"]:
@ -310,6 +312,8 @@ class SynchrotronServer(HomeServer):
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
port, port,
SynapseSite( SynapseSite(
@ -318,8 +322,9 @@ class SynchrotronServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
interface=bind_address interface=address
) )
logger.info("Synapse synchrotron now listening on port %d", port) logger.info("Synapse synchrotron now listening on port %d", port)
def start_listening(self, listeners): def start_listening(self, listeners):
@ -327,6 +332,9 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP( reactor.listenTCP(
listener["port"], listener["port"],
manhole( manhole(
@ -334,7 +342,7 @@ class SynchrotronServer(HomeServer):
password="rabbithole", password="rabbithole",
globals={"hs": self}, globals={"hs": self},
), ),
interface=listener.get("bind_address", '127.0.0.1') interface=address
) )
else: else:
logger.warn("Unrecognized listener type: %s", listener["type"]) logger.warn("Unrecognized listener type: %s", listener["type"])
@ -374,6 +382,27 @@ class SynchrotronServer(HomeServer):
stream_key, position, users=users, rooms=rooms stream_key, position, users=users, rooms=rooms
) )
@defer.inlineCallbacks
def notify_device_list_update(result):
stream = result.get("device_lists")
if not stream:
return
position_index = stream["field_names"].index("position")
user_index = stream["field_names"].index("user_id")
for row in stream["rows"]:
position = row[position_index]
user_id = row[user_index]
rooms = yield store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
@defer.inlineCallbacks
def notify(result): def notify(result):
stream = result.get("events") stream = result.get("events")
if stream: if stream:
@ -411,6 +440,7 @@ class SynchrotronServer(HomeServer):
notify_from_stream( notify_from_stream(
result, "to_device", "to_device_key", user="user_id" result, "to_device", "to_device_key", user="user_id"
) )
yield notify_device_list_update(result)
while True: while True:
try: try:
@ -421,7 +451,7 @@ class SynchrotronServer(HomeServer):
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
yield presence_handler.process_replication(result) yield presence_handler.process_replication(result)
notify(result) yield notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)
yield sleep(5) yield sleep(5)

View file

@ -68,6 +68,9 @@ class EmailConfig(Config):
self.email_notif_for_new_users = email_config.get( self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True "notif_for_new_users", True
) )
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
if "app_name" in email_config: if "app_name" in email_config:
self.email_app_name = email_config["app_name"] self.email_app_name = email_config["app_name"]
else: else:
@ -85,6 +88,9 @@ class EmailConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable sending emails for notification events # Enable sending emails for notification events
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
#email: #email:
# enable_notifs: false # enable_notifs: false
# smtp_host: "localhost" # smtp_host: "localhost"
@ -95,4 +101,5 @@ class EmailConfig(Config):
# notif_template_html: notif_mail.html # notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt # notif_template_text: notif_mail.txt
# notif_for_new_users: True # notif_for_new_users: True
# riot_base_url: "http://localhost/riot"
""" """

View file

@ -22,7 +22,6 @@ import yaml
from string import Template from string import Template
import os import os
import signal import signal
from synapse.util.debug import debug_deferreds
DEFAULT_LOG_CONFIG = Template(""" DEFAULT_LOG_CONFIG = Template("""
@ -71,8 +70,6 @@ class LoggingConfig(Config):
self.verbosity = config.get("verbose", 0) self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
if config.get("full_twisted_stacktraces"):
debug_deferreds()
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log") log_file = self.abspath("homeserver.log")
@ -88,11 +85,6 @@ class LoggingConfig(Config):
# A yaml python logging config file # A yaml python logging config file
log_config: "%(log_config)s" log_config: "%(log_config)s"
# Stop twisted from discarding the stack traces of exceptions in
# deferreds by waiting a reactor tick before running a deferred's
# callbacks.
# full_twisted_stacktraces: true
""" % locals() """ % locals()
def read_arguments(self, args): def read_arguments(self, args):

View file

@ -42,6 +42,15 @@ class ServerConfig(Config):
self.listeners = config.get("listeners", []) self.listeners = config.get("listeners", [])
for listener in self.listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
bind_port = config.get("bind_port") bind_port = config.get("bind_port")
@ -54,7 +63,7 @@ class ServerConfig(Config):
self.listeners.append({ self.listeners.append({
"port": bind_port, "port": bind_port,
"bind_address": bind_host, "bind_addresses": [bind_host],
"tls": True, "tls": True,
"type": "http", "type": "http",
"resources": [ "resources": [
@ -73,7 +82,7 @@ class ServerConfig(Config):
if unsecure_port: if unsecure_port:
self.listeners.append({ self.listeners.append({
"port": unsecure_port, "port": unsecure_port,
"bind_address": bind_host, "bind_addresses": [bind_host],
"tls": False, "tls": False,
"type": "http", "type": "http",
"resources": [ "resources": [
@ -92,7 +101,7 @@ class ServerConfig(Config):
if manhole: if manhole:
self.listeners.append({ self.listeners.append({
"port": manhole, "port": manhole,
"bind_address": "127.0.0.1", "bind_addresses": ["127.0.0.1"],
"type": "manhole", "type": "manhole",
}) })
@ -100,7 +109,7 @@ class ServerConfig(Config):
if metrics_port: if metrics_port:
self.listeners.append({ self.listeners.append({
"port": metrics_port, "port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"), "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
"tls": False, "tls": False,
"type": "http", "type": "http",
"resources": [ "resources": [
@ -155,9 +164,14 @@ class ServerConfig(Config):
# The port to listen for HTTPS requests on. # The port to listen for HTTPS requests on.
port: %(bind_port)s port: %(bind_port)s
# Local interface to listen on. # Local addresses to listen on.
# The empty string will cause synapse to listen on all interfaces. # This will listen on all IPv4 addresses by default.
bind_address: '' bind_addresses:
- '0.0.0.0'
# Uncomment to listen on all IPv6 interfaces
# N.B: On at least Linux this will also listen on all IPv4
# addresses, so you will need to comment out the line above.
# - '::'
# This is a 'http' listener, allows us to specify 'resources'. # This is a 'http' listener, allows us to specify 'resources'.
type: http type: http
@ -188,7 +202,7 @@ class ServerConfig(Config):
# For when matrix traffic passes through loadbalancer that unwraps TLS. # For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s - port: %(unsecure_port)s
tls: false tls: false
bind_address: '' bind_addresses: ['0.0.0.0']
type: http type: http
x_forwarded: false x_forwarded: false

View file

@ -19,7 +19,9 @@ class VoipConfig(Config):
def read_config(self, config): def read_config(self, config):
self.turn_uris = config.get("turn_uris", []) self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config["turn_shared_secret"] self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, **kwargs): def default_config(self, **kwargs):
@ -32,6 +34,11 @@ class VoipConfig(Config):
# The shared secret used to compute passwords for the TURN server # The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET" turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and
# does not use a token
#turn_username: "TURNSERVER_USERNAME"
#turn_password: "TURNSERVER_PASSWORD"
# How long generated TURN credentials last # How long generated TURN credentials last
turn_user_lifetime: "1h" turn_user_lifetime: "1h"
""" """

View file

@ -29,3 +29,13 @@ class WorkerConfig(Config):
self.worker_log_file = config.get("worker_log_file") self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config") self.worker_log_config = config.get("worker_log_config")
self.worker_replication_url = config.get("worker_replication_url") self.worker_replication_url = config.get("worker_replication_url")
if self.worker_listeners:
for listener in self.worker_listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')

678
synapse/event_auth.py Normal file
View file

@ -0,0 +1,678 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from unpaddedbase64 import decode_base64
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = _is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
_check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = get_user_power_level(event.user_id, auth_events)
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
_can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def _check_size_limits(event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
def _can_federate(event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
def _is_membership_change_allowed(event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = _get_named_level(auth_events, "ban", 50)
logger.debug(
"_is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not _verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return _check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
def get_send_level(etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
def _can_send_event(event, auth_events):
send_level = get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
True if the the sender is allowed to redact the target event if the
target event was created by them.
False if the sender is allowed to redact the target event with no
further checks.
Raises:
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = get_user_power_level(event.user_id, auth_events)
redact_level = _get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
def _get_power_level_event(auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def get_user_power_level(user_id, auth_events):
power_level_event = _get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(auth_events, name, default):
power_level_event = _get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
def _verify_third_party_invite(event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def auth_types_for_event(event):
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
Used to limit the number of events to fetch from the database to
actually auth the event.
"""
if event.type == EventTypes.Create:
return []
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.user_id, ))
auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, "", ))
auth_types.append((EventTypes.Member, event.state_key, ))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
auth_types.append(key)
return auth_types

View file

@ -79,7 +79,6 @@ class EventBase(object):
auth_events = _event_dict_property("auth_events") auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth") depth = _event_dict_property("depth")
content = _event_dict_property("content") content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes") hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin") origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts") origin_server_ts = _event_dict_property("origin_server_ts")
@ -88,8 +87,6 @@ class EventBase(object):
redacts = _event_dict_property("redacts") redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id") room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender") sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender") user_id = _event_dict_property("sender")
@property @property
@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
else: else:
frozen_dict = event_dict frozen_dict = event_dict
self.event_id = event_dict["event_id"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEvent, self).__init__( super(FrozenEvent, self).__init__(
frozen_dict, frozen_dict,
signatures=signatures, signatures=signatures,

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import EventBase, FrozenEvent from . import EventBase, FrozenEvent, _event_dict_property
from synapse.types import EventID from synapse.types import EventID
@ -34,6 +34,10 @@ class EventBuilder(EventBase):
internal_metadata_dict=internal_metadata_dict, internal_metadata_dict=internal_metadata_dict,
) )
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self): def build(self):
return FrozenEvent.from_event(self) return FrozenEvent.from_event(self)

View file

@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent from synapse.events import FrozenEvent, builder
import synapse.metrics import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@ -126,6 +126,16 @@ class FederationClient(FederationBase):
destination, content, timeout destination, content, timeout
) )
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.inc("user_devices")
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function @log_function
def claim_client_keys(self, destination, content, timeout): def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
@ -499,8 +509,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict: if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = [] pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
defer.returnValue( defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict)) (destination, ev)
) )
break break
except CodeMessageException as e: except CodeMessageException as e:

View file

@ -52,8 +52,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._room_pdu_linearizer = Linearizer() self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._server_linearizer = Linearizer() self._server_linearizer = Linearizer("fed_server")
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
@ -416,6 +416,9 @@ class FederationServer(FederationBase):
def on_query_client_keys(self, origin, content): def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content) return self.on_query_request("client_keys", content)
def on_query_user_devices(self, origin, user_id):
return self.on_query_request("user_devices", user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_claim_client_keys(self, origin, content): def on_claim_client_keys(self, origin, content):

View file

@ -100,6 +100,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
self.last_device_stream_id_by_dest = {} self.last_device_stream_id_by_dest = {}
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
@ -318,9 +319,10 @@ class TransactionQueue(object):
destination, destination,
self.clock, self.clock,
self.store, self.store,
backoff_on_404=True, # If we get a 404 the other side has gone
) )
device_message_edus, device_stream_id = ( device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination) yield self._get_new_device_messages(destination)
) )
@ -355,14 +357,26 @@ class TransactionQueue(object):
success = yield self._send_new_transaction( success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures, destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus),
limiter=limiter, limiter=limiter,
) )
if not success: if success:
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if device_message_edus:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
logger.info("Marking as sent %r %r", destination, dev_list_id)
yield self.store.mark_as_sent_devices_by_remote(
destination, dev_list_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break break
except NotRetryingDestination: except NotRetryingDestination:
logger.info( logger.debug(
"TX [%s] not ready for retry yet - " "TX [%s] not ready for retry yet - "
"dropping transaction for now", "dropping transaction for now",
destination, destination,
@ -387,13 +401,26 @@ class TransactionQueue(object):
) )
for content in contents for content in contents
] ]
defer.returnValue((edus, stream_id))
last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
now_stream_id, results = yield self.store.get_devices_by_remote(
destination, last_device_list
)
edus.extend(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.device_list_update",
content=content,
)
for content in results
)
defer.returnValue((edus, stream_id, now_stream_id))
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, device_stream_id, pending_failures, limiter):
should_delete_from_device_stream, limiter):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
@ -477,7 +504,7 @@ class TransactionQueue(object):
code = e.code code = e.code
response = e.response response = e.response
if e.code == 429 or 500 <= e.code: if e.code in (401, 404, 429) or 500 <= e.code:
logger.info( logger.info(
"TX [%s] {%s} got %d response", "TX [%s] {%s} got %d response",
destination, txn_id, code destination, txn_id, code
@ -504,13 +531,6 @@ class TransactionQueue(object):
"Failed to send event %s to %s", p.event_id, destination "Failed to send event %s to %s", p.event_id, destination
) )
success = False success = False
else:
# Remove the acknowledged device messages from the database
if should_delete_from_device_stream:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
except RuntimeError as e: except RuntimeError as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.

View file

@ -346,6 +346,32 @@ class TransportLayerClient(object):
) )
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server.
Response:
{
"stream_id": "...",
"devices": [ { ... } ]
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/devices/" + user_id
content = yield self.client.get_json(
destination=destination,
path=path,
timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def claim_client_keys(self, destination, query_content, timeout): def claim_client_keys(self, destination, query_content, timeout):

View file

@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
return self.handler.on_query_client_keys(origin, content) return self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
def on_GET(self, origin, content, query, user_id):
return self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet): class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim" PATH = "/user/keys/claim"
@ -613,6 +620,7 @@ SERVLET_CLASSES = (
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet, FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,

View file

@ -88,9 +88,13 @@ class BaseHandler(object):
current_state = yield self.store.get_events( current_state = yield self.store.get_events(
context.current_state_ids.values() context.current_state_ids.values()
) )
current_state = current_state.values()
else: else:
current_state = yield self.store.get_current_state(event.room_id) current_state = yield self.state_handler.get_current_state(
event.room_id
)
current_state = current_state.values()
logger.info("maybe_kick_guest_users %r", current_state) logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state) yield self.kick_guest_users(current_state)

View file

@ -65,6 +65,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -529,37 +530,11 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None): def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id) access_token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token, yield self.store.add_access_token_to_user(user_id, access_token,
device_id) device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
try: try:
@ -570,15 +545,6 @@ class AuthHandler(BaseHandler):
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
@ -617,6 +583,17 @@ class AuthHandler(BaseHandler):
self.hs.get_clock().time_msec() self.hs.get_clock().time_msec()
) )
@defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address):
# 'Canonicalise' email addresses as per above
if medium == 'email':
address = address.lower()
ret = yield self.store.user_delete_threepid(
user_id, medium, address,
)
defer.returnValue(ret)
def _save_session(self, session): def _save_session(self, session):
# TODO: Persistent storage # TODO: Persistent storage
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)
@ -656,12 +633,54 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool). Whether self.hash(password) == stored_hash (bool).
""" """
if stored_hash: if stored_hash:
return bcrypt.hashpw(password + self.hs.config.password_pepper, return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf-8')) == stored_hash stored_hash.encode('utf8')) == stored_hash
else: else:
return False return False
class MacaroonGeneartor(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.macaroon_secret_key = hs.config.macaroon_secret_key
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.server_name,
identifier="key",
key=self.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
class _AccountHandler(object): class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they """A proxy object that gets passed to password auth providers so they
can register new users etc if necessary. can register new users etc if necessary.

View file

@ -14,7 +14,11 @@
# limitations under the License. # limitations under the License.
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id, RoomStreamToken
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
@ -27,6 +31,21 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(DeviceHandler, self).__init__(hs) super(DeviceHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
self.federation.register_edu_handler(
"m.device_list_update", self._incoming_device_list_update,
)
self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_device_registered(self, user_id, device_id, def check_device_registered(self, user_id, device_id,
initial_device_display_name=None): initial_device_display_name=None):
@ -45,28 +64,28 @@ class DeviceHandler(BaseHandler):
str: device id (generated if none was supplied) str: device id (generated if none was supplied)
""" """
if device_id is not None: if device_id is not None:
yield self.store.store_device( new_device = yield self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
) )
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id) defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few # if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash. # times in case of a clash.
attempts = 0 attempts = 0
while attempts < 5: while attempts < 5:
try:
device_id = stringutils.random_string(10).upper() device_id = stringutils.random_string(10).upper()
yield self.store.store_device( new_device = yield self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
) )
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id) defer.returnValue(device_id)
except errors.StoreError:
attempts += 1 attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.") raise errors.StoreError(500, "Couldn't generate a device ID.")
@ -147,6 +166,8 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id user_id=user_id, device_id=device_id
) )
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks @defer.inlineCallbacks
def update_device(self, user_id, device_id, content): def update_device(self, user_id, device_id, content):
""" Update the given device """ Update the given device
@ -166,12 +187,168 @@ class DeviceHandler(BaseHandler):
device_id, device_id,
new_display_name=content.get("display_name") new_display_name=content.get("display_name")
) )
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e: except errors.StoreError, e:
if e.code == 404: if e.code == 404:
raise errors.NotFoundError() raise errors.NotFoundError()
else: else:
raise raise
@measure_func("notify_device_update")
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
hosts = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
if hosts:
logger.info("Sending device list update notif to: %r", hosts)
for host in hosts:
self.federation_sender.send_device_messages(host)
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
from_token.device_list_key
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
possibly_changed = set(changed)
for room_id in rooms_changed:
# Fetch the current state at the time.
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key)
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
except:
prev_state_ids = {}
current_state_ids = yield self.state.get_current_state_ids(room_id)
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in current_state_ids.iteritems():
etype, state_key = key
if etype == EventTypes.Member:
prev_event_id = prev_state_ids.get(key, None)
if not prev_event_id or prev_event_id != event_id:
possibly_changed.add(state_key)
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
# Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
defer.returnValue(users_who_share_room & possibly_changed)
@measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
stream_id = edu_content["stream_id"]
prev_ids = edu_content.get("prev_id", [])
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
with (yield self._remote_edue_linearizer.queue(user_id)):
# If the prev id matches whats in our cache table, then we don't need
# to resync the users device list, otherwise we do.
resync = True
if len(prev_ids) == 1:
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
if str(extremity) == str(prev_ids[0]):
resync = False
if resync:
# Fetch all devices for the user.
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
content = dict(edu_content)
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
content.pop(key, None)
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})

View file

@ -73,10 +73,9 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id): if self.is_mine_id(user_id):
local_query[user_id] = device_ids local_query[user_id] = device_ids
else: else:
domain = get_domain_from_id(user_id) remote_queries[user_id] = device_ids
remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries # Firt get local devices.
failures = {} failures = {}
results = {} results = {}
if local_query: if local_query:
@ -85,9 +84,42 @@ class E2eKeysHandler(object):
if user_id in local_query: if user_id in local_query:
results[user_id] = keys results[user_id] = keys
# Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
for user_id, device_ids in remote_queries.iteritems():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
)
for user_id, devices in remote_results.iteritems():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.iteritems():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@defer.inlineCallbacks @defer.inlineCallbacks
def do_remote_query(destination): def do_remote_query(destination):
destination_query = remote_queries[destination] destination_query = remote_queries_not_in_cache[destination]
try: try:
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, self.clock, self.store destination, self.clock, self.store
@ -119,7 +151,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([ yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination) preserve_fn(do_remote_query)(destination)
for destination in remote_queries for destination in remote_queries_not_in_cache
])) ]))
defer.returnValue({ defer.returnValue({
@ -162,7 +194,7 @@ class E2eKeysHandler(object):
# "unsigned" section # "unsigned" section
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items(): for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"]) r = dict(device_info["keys"])
r["unsigned"] = {} r["unsigned"] = {}
display_name = device_info["device_display_name"] display_name = device_info["device_display_name"]
if display_name is not None: if display_name is not None:
@ -255,10 +287,12 @@ class E2eKeysHandler(object):
device_id, user_id, time_now device_id, user_id, time_now
) )
# TODO: Sign the JSON with the server key # TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys( changed = yield self.store.set_e2e_device_keys(
user_id, device_id, time_now, user_id, device_id, time_now, device_keys,
encode_canonical_json(device_keys)
) )
if changed:
# Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None) one_time_keys = keys.get("one_time_keys", None)
if one_time_keys: if one_time_keys:

View file

@ -591,12 +591,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
logger.info("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
states = yield preserve_context_over_deferred(defer.gatherResults([ states = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids for e in event_ids
])) ]))
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids], [e_id for ids in states.values() for e_id in ids],
@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
current_state=state,
) )
defer.returnValue((event_stream_id, max_stream_id)) defer.returnValue((event_stream_id, max_stream_id))
@ -1530,7 +1529,7 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
new_state, prev_state = self.state_handler.resolve_events( new_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()], [local_view.values(), remote_view.values()],
event event
) )

View file

@ -208,7 +208,9 @@ class MessageHandler(BaseHandler):
content = builder.content content = builder.content
try: try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e: except Exception as e:
logger.info( logger.info(
@ -279,7 +281,9 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler() presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user) # We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks @defer.inlineCallbacks
def deduplicate_state_event(self, event, context): def deduplicate_state_event(self, event, context):

View file

@ -574,7 +574,7 @@ class PresenceHandler(object):
if not local_states: if not local_states:
continue continue
users = yield self.state.get_current_user_in_room(room_id) users = yield self.store.get_users_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users) hosts = set(get_domain_from_id(u) for u in users)
for host in hosts: for host in hosts:
@ -766,7 +766,7 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part # don't need to send to local clients here, as that is done as part
# of the event stream/sync. # of the event stream/sync.
# TODO: Only send to servers not already in the room. # TODO: Only send to servers not already in the room.
user_ids = yield self.state.get_current_user_in_room(room_id) user_ids = yield self.store.get_users_in_room(room_id)
if self.is_mine(user): if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string()) state = yield self.current_state_for_user(user.to_string())
@ -1011,7 +1011,7 @@ class PresenceEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events(self, user, from_key, room_ids=None, include_offline=True, def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
**kwargs): explicit_room_id=None, **kwargs):
# The process for getting presence events are: # The process for getting presence events are:
# 1. Get the rooms the user is in. # 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms. # 2. Get the list of user in the rooms.
@ -1028,22 +1028,24 @@ class PresenceEventSource(object):
user_id = user.to_string() user_id = user.to_string()
if from_key is not None: if from_key is not None:
from_key = int(from_key) from_key = int(from_key)
room_ids = room_ids or []
presence = self.get_presence_handler() presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache stream_change_cache = self.store.presence_stream_cache
if not room_ids:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(e.room_id for e in rooms)
else:
room_ids = set(room_ids)
max_token = self.store.get_current_presence_token() max_token = self.store.get_current_presence_token()
plist = yield self.store.get_presence_list_accepted(user.localpart) plist = yield self.store.get_presence_list_accepted(user.localpart)
friends = set(row["observed_user_id"] for row in plist) users_interested_in = set(row["observed_user_id"] for row in plist)
friends.add(user_id) # So that we receive our own presence users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(explicit_room_id)
users_interested_in.update(user_ids)
user_ids_changed = set() user_ids_changed = set()
changed = None changed = None
@ -1055,35 +1057,19 @@ class PresenceEventSource(object):
# work out if we share a room or they're in our presence list # work out if we share a room or they're in our presence list
get_updates_counter.inc("stream") get_updates_counter.inc("stream")
for other_user_id in changed: for other_user_id in changed:
if other_user_id in friends: if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id) user_ids_changed.add(other_user_id)
continue
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
continue
else: else:
# Too many possible updates. Find all users we can see and check # Too many possible updates. Find all users we can see and check
# if any of them have changed. # if any of them have changed.
get_updates_counter.inc("full") get_updates_counter.inc("full")
user_ids_to_check = set()
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
user_ids_to_check.update(users)
user_ids_to_check.update(friends)
# Always include yourself. Only really matters for when the user is
# not in any rooms, but still.
user_ids_to_check.add(user_id)
if from_key: if from_key:
user_ids_changed = stream_change_cache.get_entities_changed( user_ids_changed = stream_change_cache.get_entities_changed(
user_ids_to_check, from_key, users_interested_in, from_key,
) )
else: else:
user_ids_changed = user_ids_to_check user_ids_changed = users_interested_in
updates = yield presence.current_state_for_users(user_ids_changed) updates = yield presence.current_state_for_users(user_ids_changed)

View file

@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id = None self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None): assigned_user_id=None):
@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler):
token = None token = None
if generate_token: if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token: if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self.auth_handler().generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
if need_register: if need_register:
yield self.store.register( yield self.store.register(

View file

@ -437,6 +437,7 @@ class RoomEventSource(object):
limit, limit,
room_ids, room_ids,
is_guest, is_guest,
explicit_room_id=None,
): ):
# We just ignore the key for now. # We just ignore the key for now.

View file

@ -62,17 +62,18 @@ class RoomListHandler(BaseHandler):
appservice and network id to use an appservice specific one. appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists. Setting to None returns all public rooms across all lists.
""" """
if search_filter or (network_tuple and network_tuple.appservice_id is not None): if search_filter:
# We explicitly don't bother caching searches or requests for # We explicitly don't bother caching searches or requests for
# appservice specific lists. # appservice specific lists.
return self._get_public_room_list( return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple, limit, since_token, search_filter, network_tuple=network_tuple,
) )
result = self.response_cache.get((limit, since_token)) key = (limit, since_token, network_tuple)
result = self.response_cache.get(key)
if not result: if not result:
result = self.response_cache.set( result = self.response_cache.set(
(limit, since_token), key,
self._get_public_room_list( self._get_public_room_list(
limit, since_token, network_tuple=network_tuple limit, since_token, network_tuple=network_tuple
) )

View file

@ -45,7 +45,7 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs) super(RoomMemberHandler, self).__init__(hs)
self.member_linearizer = Linearizer() self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -89,7 +89,7 @@ class RoomMemberHandler(BaseHandler):
duplicate = yield msg_handler.deduplicate_state_event(event, context) duplicate = yield msg_handler.deduplicate_state_event(event, context)
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
return defer.returnValue(duplicate)
yield msg_handler.handle_new_client_event( yield msg_handler.handle_new_client_event(
requester, requester,
@ -120,6 +120,8 @@ class RoomMemberHandler(BaseHandler):
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id) user_left_room(self.distributor, target, room_id)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content): def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0: if len(remote_room_hosts) == 0:
@ -187,6 +189,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=True, ratelimit=True,
content=None, content=None,
): ):
content_specified = bool(content)
if content is None: if content is None:
content = {} content = {}
@ -229,6 +232,13 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE errcode=Codes.BAD_STATE
) )
if old_state:
same_content = content == old_state.content
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
is_host_in_room = yield self._is_host_in_room(current_state_ids) is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN: if effective_membership_state == Membership.JOIN:
@ -247,6 +257,7 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler profile = self.hs.get_handlers().profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = yield profile.get_avatar_url(target)
@ -290,7 +301,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({}) defer.returnValue({})
yield self._local_membership_update( res = yield self._local_membership_update(
requester=requester, requester=requester,
target=target, target=target,
room_id=room_id, room_id=room_id,
@ -300,6 +311,7 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids=latest_event_ids, prev_event_ids=latest_event_ids,
content=content, content=content,
) )
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_membership_event( def send_membership_event(

View file

@ -16,7 +16,7 @@
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure from synapse.util.metrics import Measure, measure_func
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device. "to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
])): ])):
__slots__ = [] __slots__ = []
@ -129,7 +130,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.invited or self.invited or
self.archived or self.archived or
self.account_data or self.account_data or
self.to_device self.to_device or
self.device_lists
) )
@ -544,6 +546,10 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder) yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
@ -551,9 +557,33 @@ class SyncHandler(object):
invited=sync_result_builder.invited, invited=sync_result_builder.invited,
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device, to_device=sync_result_builder.to_device,
device_lists=device_lists,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
else:
defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder): def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates """Generates the portion of the sync response. Populates

View file

@ -25,7 +25,7 @@ from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor, ssl, protocol, task from twisted.internet import defer, reactor, ssl, protocol, task
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import ( from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError, readBody, PartialDownloadError,
@ -386,26 +386,23 @@ class SpiderEndpointFactory(object):
def endpointForURI(self, uri): def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes()) logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == "http": if uri.scheme == "http":
return SpiderEndpoint( endpoint_factory = HostnameEndpoint
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=TCP4ClientEndpoint,
endpoint_kw_args={
'timeout': 15
},
)
elif uri.scheme == "https": elif uri.scheme == "https":
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist, def endpoint_factory(reactor, host, port, **kw):
endpoint=SSL4ClientEndpoint, return wrapClientTLS(
endpoint_kw_args={ tlsCreator,
'sslContextFactory': tlsPolicy, HostnameEndpoint(reactor, host, port, **kw))
'timeout': 15
},
)
else: else:
logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
return None
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
)
class SpiderHttpClient(SimpleHttpClient): class SpiderHttpClient(SimpleHttpClient):

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
from twisted.names import client, dns from twisted.names import client, dns
@ -58,11 +58,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
endpoint_kw_args.update(timeout=timeout) endpoint_kw_args.update(timeout=timeout)
if ssl_context_factory is None: if ssl_context_factory is None:
transport_endpoint = TCP4ClientEndpoint transport_endpoint = HostnameEndpoint
default_port = 8008 default_port = 8008
else: else:
transport_endpoint = SSL4ClientEndpoint def transport_endpoint(reactor, host, port, timeout):
endpoint_kw_args.update(sslContextFactory=ssl_context_factory) return wrapClientTLS(
ssl_context_factory,
HostnameEndpoint(reactor, host, port, timeout=timeout))
default_port = 8448 default_port = 8448
if port is None: if port is None:
@ -142,7 +144,7 @@ class SpiderEndpoint(object):
Implements twisted.internet.interfaces.IStreamClientEndpoint. Implements twisted.internet.interfaces.IStreamClientEndpoint.
""" """
def __init__(self, reactor, host, port, blacklist, whitelist, def __init__(self, reactor, host, port, blacklist, whitelist,
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): endpoint=HostnameEndpoint, endpoint_kw_args={}):
self.reactor = reactor self.reactor = reactor
self.host = host self.host = host
self.port = port self.port = port
@ -180,7 +182,7 @@ class SRVClientEndpoint(object):
""" """
def __init__(self, reactor, service, domain, protocol="tcp", def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=TCP4ClientEndpoint, default_port=None, endpoint=HostnameEndpoint,
endpoint_kw_args={}): endpoint_kw_args={}):
self.reactor = reactor self.reactor = reactor
self.service_name = "_%s._%s.%s" % (service, protocol, domain) self.service_name = "_%s._%s.%s" % (service, protocol, domain)

View file

@ -378,6 +378,7 @@ class Notifier(object):
limit=limit, limit=limit,
is_guest=is_peeking, is_guest=is_peeking,
room_ids=room_ids, room_ids=room_ids,
explicit_room_id=explicit_room_id,
) )
if name == "room": if name == "room":

View file

@ -81,7 +81,7 @@ class Mailer(object):
def __init__(self, hs, app_name): def __init__(self, hs, app_name):
self.hs = hs self.hs = hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler() self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name self.app_name = app_name
@ -439,15 +439,23 @@ class Mailer(object):
}) })
def make_room_link(self, room_id): def make_room_link(self, room_id):
if self.hs.config.email_riot_base_url:
base_url = self.hs.config.email_riot_base_url
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS # need /beta for Universal Links to work on iOS
if self.app_name == "Vector": base_url = "https://vector.im/beta/#/room"
return "https://vector.im/beta/#/room/%s" % (room_id,)
else: else:
return "https://matrix.to/#/%s" % (room_id,) base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
def make_notif_link(self, notif): def make_notif_link(self, notif):
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
notif['room_id'], notif['event_id']
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS # need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s/%s" % ( return "https://vector.im/beta/#/room/%s/%s" % (
notif['room_id'], notif['event_id'] notif['room_id'], notif['event_id']
) )
@ -458,7 +466,7 @@ class Mailer(object):
def make_unsubscribe_link(self, user_id, app_id, email_address): def make_unsubscribe_link(self, user_id, app_id, email_address):
params = { params = {
"access_token": self.auth_handler.generate_delete_pusher_token(user_id), "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id, "app_id": app_id,
"pushkey": email_address, "pushkey": email_address,
} }

View file

@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
def get_context_for_event(store, state_handler, ev, user_id): def get_context_for_event(store, state_handler, ev, user_id):
ctx = {} ctx = {}
room_state_ids = yield state_handler.get_current_state_ids(ev.room_id) room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the # we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or # human-readable name instead, be that m.room.name, an alias or

View file

@ -24,7 +24,7 @@ REQUIREMENTS = {
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"], "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"], "Twisted>=16.0.0": ["twisted>=16.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
"pyasn1": ["pyasn1"], "pyasn1": ["pyasn1"],

View file

@ -46,6 +46,7 @@ STREAM_NAMES = (
("to_device",), ("to_device",),
("public_rooms",), ("public_rooms",),
("federation",), ("federation",),
("device_lists",),
) )
@ -140,6 +141,7 @@ class ReplicationResource(Resource):
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id() public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token() federation_token = self.federation_sender.get_current_token()
device_list_token = self.store.get_device_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -155,6 +157,7 @@ class ReplicationResource(Resource):
int(stream_token.to_device_key), int(stream_token.to_device_key),
int(public_rooms_token), int(public_rooms_token),
int(federation_token), int(federation_token),
int(device_list_token),
)) ))
@request_handler() @request_handler()
@ -214,6 +217,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams)
yield self.device_lists(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack) self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
@ -295,9 +299,6 @@ class ReplicationResource(Resource):
"backward_ex_outliers", res.backward_ex_outliers, "backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group"), ("position", "event_id", "state_group"),
) )
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def presence(self, writer, current_token, request_streams): def presence(self, writer, current_token, request_streams):
@ -495,6 +496,20 @@ class ReplicationResource(Resource):
"position", "type", "content", "position", "type", "content",
), position=upto_token) ), position=upto_token)
@defer.inlineCallbacks
def device_lists(self, writer, current_token, limit, request_streams):
current_position = current_token.device_lists
device_lists = request_streams.get("device_lists")
if device_lists is not None and device_lists != current_position:
changes = yield self.store.get_all_device_list_changes_for_remotes(
device_lists,
)
writer.write_header_and_rows("device_lists", changes, (
"position", "user_id", "destination",
), position=current_position)
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -527,7 +542,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
"federation", "federation", "device_lists",
))): ))):
__slots__ = [] __slots__ = []

View file

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id",
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = DataStore.get_device_stream_token.__func__
get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
get_devices_by_remote = DataStore.get_devices_by_remote.__func__
_get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
_get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("device_lists")
if stream:
self._device_list_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
self._device_list_stream_cache.entity_has_changed(
user_id, stream_id
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id
)
return super(SlavedDeviceStore, self).process_replication(result)

View file

@ -73,12 +73,12 @@ class SlavedEventStore(BaseSlavedStore):
# to reach inside the __dict__ to extract them. # to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_users_who_share_room_with_user = (
RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
)
get_latest_event_ids_in_room = EventFederationStore.__dict__[ get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room" "get_latest_event_ids_in_room"
] ]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[ get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user" "get_invited_rooms_for_user"
] ]
@ -115,8 +115,6 @@ class SlavedEventStore(BaseSlavedStore):
) )
get_event = DataStore.get_event.__func__ get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__ get_events = DataStore.get_events.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = ( get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__ DataStore.get_rooms_for_user_where_membership_is.__func__
) )
@ -197,10 +195,6 @@ class SlavedEventStore(BaseSlavedStore):
return result return result
def process_replication(self, result): def process_replication(self, result):
state_resets = set(
r[0] for r in result.get("state_resets", {"rows": []})["rows"]
)
stream = result.get("events") stream = result.get("events")
if stream: if stream:
self._stream_id_gen.advance(int(stream["position"])) self._stream_id_gen.advance(int(stream["position"]))
@ -210,7 +204,7 @@ class SlavedEventStore(BaseSlavedStore):
for row in stream["rows"]: for row in stream["rows"]:
self._process_replication_row( self._process_replication_row(
row, backfilled=False, state_resets=state_resets row, backfilled=False,
) )
stream = result.get("backfill") stream = result.get("backfill")
@ -218,7 +212,7 @@ class SlavedEventStore(BaseSlavedStore):
self._backfill_id_gen.advance(-int(stream["position"])) self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]: for row in stream["rows"]:
self._process_replication_row( self._process_replication_row(
row, backfilled=True, state_resets=state_resets row, backfilled=True,
) )
stream = result.get("forward_ex_outliers") stream = result.get("forward_ex_outliers")
@ -237,21 +231,15 @@ class SlavedEventStore(BaseSlavedStore):
return super(SlavedEventStore, self).process_replication(result) return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled, state_resets): def _process_replication_row(self, row, backfilled):
position = row[0]
internal = json.loads(row[1]) internal = json.loads(row[1])
event_json = json.loads(row[2]) event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal) event = FrozenEvent(event_json, internal_metadata_dict=internal)
self.invalidate_caches_for_event( self.invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets event, backfilled,
) )
def invalidate_caches_for_event(self, event, backfilled, reset_state): def invalidate_caches_for_event(self, event, backfilled):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id) self._invalidate_get_event_cache(event.event_id)
self.get_latest_event_ids_in_room.invalidate((event.room_id,)) self.get_latest_event_ids_in_room.invalidate((event.room_id,))
@ -273,8 +261,6 @@ class SlavedEventStore(BaseSlavedStore):
self._invalidate_get_event_cache(event.redacts) self._invalidate_get_event_cache(event.redacts)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed( self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering event.state_key, event.internal_metadata.stream_ordering
) )
@ -289,7 +275,3 @@ class SlavedEventStore(BaseSlavedStore):
if (not event.internal_metadata.is_invite_from_remote() if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()): and event.internal_metadata.is_outlier()):
return return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))

View file

@ -86,7 +86,11 @@ class HttpTransactionCache(object):
pass # execute the function instead. pass # execute the function instead.
deferred = fn(*args, **kwargs) deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
# We don't add an errback to the raw deferred, so we ask ObservableDeferred
# to swallow the error. This is fine as the error will still be reported
# to the observers.
observable = ObservableDeferred(deferred, consumeErrors=True)
self.transactions[txn_key] = (observable, self.clock.time_msec()) self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe() return observable.observe()

View file

@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_password_login(self, login_submission): def do_password_login(self, login_submission):
if 'medium' in login_submission and 'address' in login_submission: if 'medium' in login_submission and 'address' in login_submission:
address = login_submission['address']
if login_submission['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], login_submission['address'] login_submission['medium'], address
) )
if not user_id: if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@ -324,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -362,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet):
yield self.handlers.registration_handler.register(localpart=user) yield self.handlers.registration_handler.register(localpart=user)
) )
login_token = auth_handler.generate_short_term_login_token(registered_user_id) login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token) login_token)
request.redirect(redirect_url) request.redirect(redirect_url)

View file

@ -152,6 +152,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if state_key is not None: if state_key is not None:
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
if event_type == EventTypes.Member:
membership = content.get("membership", None)
event = yield self.handlers.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
action=membership,
content=content,
)
else:
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event, context = yield msg_handler.create_event( event, context = yield msg_handler.create_event(
event_dict, event_dict,
@ -159,16 +169,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
txn_id=txn_id, txn_id=txn_id,
) )
if event_type == EventTypes.Member:
yield self.handlers.room_member_handler.send_membership_event(
requester,
event,
context,
)
else:
yield msg_handler.send_nonmember_event(requester, event, context) yield msg_handler.send_nonmember_event(requester, event, context)
defer.returnValue((200, {"event_id": event.event_id})) ret = {}
if event:
ret = {"event_id": event.event_id}
defer.returnValue((200, ret))
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback

View file

@ -32,10 +32,11 @@ class VoipRestServlet(ClientV1RestServlet):
turnUris = self.hs.config.turn_uris turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret turnSecret = self.hs.config.turn_shared_secret
turnUsername = self.hs.config.turn_username
turnPassword = self.hs.config.turn_password
userLifetime = self.hs.config.turn_user_lifetime userLifetime = self.hs.config.turn_user_lifetime
if not turnUris or not turnSecret or not userLifetime:
defer.returnValue((200, {}))
if turnUris and turnSecret and userLifetime:
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string()) username = "%d:%s" % (expiry, requester.user.to_string())
@ -45,6 +46,13 @@ class VoipRestServlet(ClientV1RestServlet):
# same result as the TURN server. # same result as the TURN server.
password = base64.b64encode(mac.digest()) password = base64.b64encode(mac.digest())
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
password = turnPassword
else:
defer.returnValue((200, {}))
defer.returnValue((200, { defer.returnValue((200, {
'username': username, 'username': username,
'password': password, 'password': password,

View file

@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
threepid = result[LoginType.EMAIL_IDENTITY] threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid: if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid") raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with! # if using email, we must know about the email they're authing with!
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
threepid['medium'], threepid['address'] threepid['medium'], threepid['address']
@ -241,7 +246,7 @@ class ThreepidRestServlet(RestServlet):
for reqd in ['medium', 'address', 'validated_at']: for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid: if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID sevrer") logger.warn("Couldn't add 3pid: invalid response from ID server")
raise SynapseError(500, "Invalid response from ID Server") raise SynapseError(500, "Invalid response from ID Server")
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
@ -263,9 +268,43 @@ class ThreepidRestServlet(RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_object_from_request(request)
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
yield self.auth_handler.delete_threepid(
user_id, body['medium'], body['address']
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server) PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)

View file

@ -21,6 +21,8 @@ from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer RestServlet, parse_json_object_from_request, parse_integer
) )
from synapse.http.servlet import parse_string
from synapse.types import StreamToken
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -149,6 +151,52 @@ class KeyQueryServlet(RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
class KeyChangesServlet(RestServlet):
"""Returns the list of changes of keys between two stream tokens (may return
spurious extra results, since we currently ignore the `to` param).
GET /keys/changes?from=...&to=...
200 OK
{ "changed": ["@foo:example.com"] }
"""
PATTERNS = client_v2_patterns(
"/keys/changes$",
releases=()
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(KeyChangesServlet, self).__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from")
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
parse_string(request, "to")
from_token = StreamToken.from_string(from_token_string)
user_id = requester.user.to_string()
changed = yield self.device_handler.get_user_ids_changed(
user_id, from_token,
)
defer.returnValue((200, {
"changed": list(changed),
}))
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):
""" """
POST /keys/claim HTTP/1.1 POST /keys/claim HTTP/1.1
@ -192,4 +240,5 @@ class OneTimeKeyServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
KeyUploadServlet(hs).register(http_server) KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server)

View file

@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet):
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
access_token = self.auth_handler.generate_access_token( access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"] user_id, ["guest = true"]
) )
defer.returnValue((200, { defer.returnValue((200, {

View file

@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet):
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, time_now, requester.access_token_id, filter.event_fields sync_result.archived, time_now, requester.access_token_id,
filter.event_fields,
) )
response_content = { response_content = {
"account_data": {"events": sync_result.account_data}, "account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device}, "to_device": {"events": sync_result.to_device},
"device_lists": {
"changed": list(sync_result.device_lists),
},
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, time_now sync_result.presence, time_now
), ),

View file

@ -61,7 +61,7 @@ class MediaRepository(object):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer() self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set() self.recently_accessed_remotes = set()
@ -98,6 +98,8 @@ class MediaRepository(object):
with open(fname, "wb") as f: with open(fname, "wb") as f:
f.write(content) f.write(content)
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media( yield self.store.store_local_media(
media_id=media_id, media_id=media_id,
media_type=media_type, media_type=media_type,
@ -190,6 +192,8 @@ class MediaRepository(object):
else: else:
upload_name = None upload_name = None
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media( yield self.store.store_cached_remote_media(
origin=server_name, origin=server_name,
media_id=media_id, media_id=media_id,

View file

@ -16,6 +16,10 @@
import PIL.Image as Image import PIL.Image as Image
from io import BytesIO from io import BytesIO
import logging
logger = logging.getLogger(__name__)
class Thumbnailer(object): class Thumbnailer(object):
@ -86,4 +90,5 @@ class Thumbnailer(object):
output_bytes = output_bytes_io.getvalue() output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file: with open(output_path, "wb") as output_file:
output_file.write(output_bytes) output_file.write(output_bytes)
logger.info("Stored thumbnail in file %r", output_path)
return len(output_bytes) return len(output_bytes)

View file

@ -97,6 +97,8 @@ class UploadResource(Resource):
content_length, requester.user content_length, requester.user
) )
logger.info("Uploaded content with URI %r", content_uri)
respond_with_json( respond_with_json(
request, 200, {"content_uri": content_uri}, send_cors=True request, 200, {"content_uri": content_uri}, send_cors=True
) )

View file

@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
@ -131,6 +131,7 @@ class HomeServer(object):
'federation_transport_client', 'federation_transport_client',
'federation_sender', 'federation_sender',
'receipts_handler', 'receipts_handler',
'macaroon_generator',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -213,6 +214,9 @@ class HomeServer(object):
def build_auth_handler(self): def build_auth_handler(self):
return AuthHandler(self) return AuthHandler(self)
def build_macaroon_generator(self):
return MacaroonGeneartor(self)
def build_device_handler(self): def build_device_handler(self):
return DeviceHandler(self) return DeviceHandler(self)

View file

@ -16,12 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from synapse import event_auth
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1 _NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
def _gen_state_id(): def _gen_state_id():
global _NEXT_STATE_ID global _NEXT_STATE_ID
@ -77,6 +79,9 @@ class _StateCacheEntry(object):
else: else:
self.state_id = _gen_state_id() self.state_id = _gen_state_id()
def __len__(self):
return len(self.state)
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
@ -89,7 +94,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry. # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None self._state_cache = None
self.resolve_linearizer = Linearizer() self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self): def start_caching(self):
logger.debug("start_caching") logger.debug("start_caching")
@ -99,6 +104,7 @@ class StateHandler(object):
clock=self.clock, clock=self.clock,
max_len=SIZE_OF_CACHE, max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True, reset_expiry_on_get=True,
) )
@ -123,7 +129,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_state") logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state state = ret.state
@ -148,7 +154,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state state = ret.state
@ -162,7 +168,7 @@ class StateHandler(object):
def get_current_user_in_room(self, room_id, latest_event_ids=None): def get_current_user_in_room(self, room_id, latest_event_ids=None):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_user_in_room") logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state( joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state room_id, entry.state_id, entry.state
@ -226,7 +232,7 @@ class StateHandler(object):
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
logger.info("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
if event.is_state(): if event.is_state():
entry = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
@ -327,20 +333,13 @@ class StateHandler(object):
if conflicted_state: if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events( with Measure(self.clock, "state._resolve_events"):
[e_id for st in state_groups_ids.values() for e_id in st.values()], new_state = yield resolve_events(
get_prev_content=False state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
) )
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.items()
@ -388,68 +387,180 @@ class StateHandler(object):
logger.info( logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets) "Resolving state for %s with %d groups", event.room_id, len(state_sets)
) )
if event.is_state(): state_set_ids = [{
return self._resolve_events( (ev.type, ev.state_key): ev.event_id
state_sets, event.type, event.state_key for ev in st
) } for st in state_sets]
else:
return self._resolve_events(state_sets) state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
(new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
"""
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
state = {} new_state = resolve_events(state_set_ids, state_map)
for st in state_sets:
for e in st:
state.setdefault(
(e.type, e.state_key),
{}
)[e.event_id] = e
unconflicted_state = { new_state = {
k: v.values()[0] for k, v in state.items() key: state_map[ev_id] for key, ev_id in new_state.items()
if len(v.values()) == 1
} }
conflicted_state = { return new_state
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
if event_type:
prev_states_events = conflicted_state.get( def _ordered_events(events):
(event_type, state_key), [] def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return sorted(events, key=key_func)
def resolve_events(state_sets, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
Returns
dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event.
"""
if len(state_sets) == 1:
return state_sets[0]
unconflicted_state, conflicted_state = _seperate(
state_sets,
) )
prev_states = [s.event_id for s in prev_states_events]
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
state_map = state_map_factory
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else: else:
prev_states = [] # This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
state_map_factory):
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
logger.info("Asking for %d conflicted events", len(needed_events))
state_map = yield state_map_factory(needed_events)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = { auth_events = {
k: e for k, e in unconflicted_state.items() key: state_map[ev_id]
if k[0] in AuthEventTypes for key, ev_id in auth_event_ids.items()
if ev_id in state_map
} }
try: try:
resolved_state = self._resolve_state_events( resolved_state = _resolve_state_events(
conflicted_state, auth_events conflicted_state, auth_events
) )
except: except:
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
new_state = unconflicted_state new_state = unconflicted_state_ids
new_state.update(resolved_state) for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state, prev_states return new_state
@log_function
def _resolve_state_events(self, conflicted_state, auth_events): def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to """ This is where we actually decide which of the conflicted state to
use. use.
@ -460,11 +571,10 @@ class StateHandler(object):
4. other events. 4. other events.
""" """
resolved_state = {} resolved_state = {}
power_key = (EventTypes.PowerLevels, "") if POWER_KEY in conflicted_state:
if power_key in conflicted_state: events = conflicted_state[POWER_KEY]
events = conflicted_state[power_key]
logger.debug("Resolving conflicted power levels %r", events) logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = self._resolve_auth_events( resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events) events, auth_events)
auth_events.update(resolved_state) auth_events.update(resolved_state)
@ -472,7 +582,7 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events) logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = _resolve_auth_events(
events, events,
auth_events auth_events
) )
@ -482,7 +592,7 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events) logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = _resolve_auth_events(
events, events,
auth_events auth_events
) )
@ -492,38 +602,48 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key not in resolved_state: if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events) logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = self._resolve_normal_events( resolved_state[key] = _resolve_normal_events(
events, auth_events events, auth_events
) )
return resolved_state return resolved_state
def _resolve_auth_events(self, events, auth_events):
reverse = [i for i in reversed(self._ordered_events(events))]
auth_events = dict(auth_events) def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0] prev_event = reverse[0]
for event in reverse[1:]: for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try: try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point # The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False) event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event prev_event = event
except AuthError: except AuthError:
return prev_event return prev_event
return event return event
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events): def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try: try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point # The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False) event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event return event
except AuthError: except AuthError:
pass pass
@ -531,9 +651,3 @@ class StateHandler(object):
# Use the last event (the one with the least depth) if they all fail # Use the last event (the one with the least depth) if they all fail
# the auth check. # the auth check.
return event return event
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return sorted(events, key=key_func)

View file

@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._public_room_id_gen = StreamIdGenerator( self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"
) )
self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id",
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@ -189,7 +192,8 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "device_inbox", db_conn, "device_inbox",
entity_column="user_id", entity_column="user_id",
stream_column="stream_id", stream_column="stream_id",
max_value=max_device_inbox_id max_value=max_device_inbox_id,
limit=1000,
) )
self._device_inbox_stream_cache = StreamChangeCache( self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache", min_device_inbox_id, "DeviceInboxStreamChangeCache", min_device_inbox_id,
@ -202,12 +206,21 @@ class DataStore(RoomMemberStore, RoomStore,
entity_column="destination", entity_column="destination",
stream_column="stream_id", stream_column="stream_id",
max_value=max_device_inbox_id, max_value=max_device_inbox_id,
limit=1000,
) )
self._device_federation_outbox_stream_cache = StreamChangeCache( self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View file

@ -169,7 +169,7 @@ class SQLBaseStore(object):
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
) )
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
@ -387,6 +387,10 @@ class SQLBaseStore(object):
Args: Args:
table : string giving the table name table : string giving the table name
values : dict of new column names and values for them values : dict of new column names and values for them
Returns:
bool: Whether the row was inserted or not. Only useful when
`or_ignore` is True
""" """
try: try:
yield self.runInteraction( yield self.runInteraction(
@ -398,6 +402,8 @@ class SQLBaseStore(object):
# a cursor after we receive an error from the db. # a cursor after we receive an error from the db.
if not or_ignore: if not or_ignore:
raise raise
defer.returnValue(False)
defer.returnValue(True)
@staticmethod @staticmethod
def _simple_insert_txn(txn, table, values): def _simple_insert_txn(txn, table, values):
@ -838,18 +844,19 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values()) return txn.execute(sql, keyvalues.values())
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value): max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will # It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache. # do the right thing to ensure it respects the max size of cache.
sql = ( sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000" " WHERE %(stream)s > ? - %(limit)s"
" GROUP BY %(entity)s" " GROUP BY %(entity)s"
) % { ) % {
"table": table, "table": table,
"entity": entity_column, "entity": entity_column,
"stream": stream_column, "stream": stream_column,
"limit": limit,
} }
sql = self.database_engine.convert_param_style(sql) sql = self.database_engine.convert_param_style(sql)

View file

@ -18,13 +18,29 @@ import ujson
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore): class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, hs):
super(DeviceInboxStore, self).__init__(hs)
self.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
self.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID,
self._background_drop_index_device_inbox,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device, def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore):
"delete_device_msgs_for_remote", "delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn delete_messages_for_remote_destination_txn
) )
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS device_inbox_stream_id"
)
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
defer.returnValue(1)

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import ujson as json
from twisted.internet import defer from twisted.internet import defer
@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore): class DeviceStore(SQLBaseStore):
def __init__(self, hs):
super(DeviceStore, self).__init__(hs)
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_device(self, user_id, device_id, def store_device(self, user_id, device_id,
initial_device_display_name, initial_device_display_name):
ignore_if_known=True):
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
Args: Args:
user_id (str): id of user associated with the device user_id (str): id of user associated with the device
device_id (str): id of device device_id (str): id of device
initial_device_display_name (str): initial displayname of the initial_device_display_name (str): initial displayname of the
device device. Ignored if device exists.
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
Returns: Returns:
defer.Deferred defer.Deferred: boolean whether the device was inserted or an
Raises: existing device existed with that ID.
StoreError: if ignore_if_known is False and the device was already
known
""" """
try: try:
yield self._simple_insert( inserted = yield self._simple_insert(
"devices", "devices",
values={ values={
"user_id": user_id, "user_id": user_id,
@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
"display_name": initial_device_display_name "display_name": initial_device_display_name
}, },
desc="store_device", desc="store_device",
or_ignore=ignore_if_known, or_ignore=True,
) )
defer.returnValue(inserted)
except Exception as e: except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s", " display_name=%s(%r) failed: %s",
@ -139,3 +143,451 @@ class DeviceStore(SQLBaseStore):
) )
defer.returnValue({d["device_id"]: d for d in devices}) defer.returnValue({d["device_id"]: d for d in devices})
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
return self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
desc="mark_remote_user_device_list_as_unsubscribed",
)
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
"""Updates a single user's device in the cache.
"""
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id,
)
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def update_remote_device_list_cache(self, user_id, devices, stream_id):
"""Replace the cache of the remote user's devices.
"""
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id, devices, stream_id,
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
stream_id):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
)
self._simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
"content": json.dumps(content),
}
for content in devices
]
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(now_stream_id, [ { updates }, .. ])
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows}
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in devices.iteritems():
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return (now_stream_id, results)
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
return self.runInteraction(
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
query_list,
)
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
content = self._simple_select_one_onecol_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
)
results.setdefault(user_id, {})[device_id] = json.loads(content)
else:
devices = self._simple_select_list_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
)
results[user_id] = {
device["device_id"]: json.loads(device["content"])
for device in devices
}
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.iteritems():
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# First we DELETE all rows such that only the latest row for each
# (destination, user_id is left. We do this by selecting first and
# deleting.
sql = """
SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
GROUP BY user_id
HAVING count(*) > 1
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows)
)
# Mark everything that is left as sent
sql = """
UPDATE device_lists_outbound_pokes SET sent = ?
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (True, destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE stream_id > ?
"""
return self._execute(
"get_users_and_hosts_device_list", None,
sql, from_key,
)
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host, stream_id,
)
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
for device_id in device_ids
]
)
self._simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
{
"destination": destination,
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
"sent": False,
"ts": now,
}
for destination in hosts
for device_id in device_ids
]
)
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
(destination, user_id) tuple to ensure that the prev_ids remain correct
if the server does come back.
"""
yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
def _prune_txn(txn):
select_sql = """
SELECT destination, user_id, max(stream_id) as stream_id
FROM device_lists_outbound_pokes
GROUP BY destination, user_id
HAVING min(ts) < ? AND count(*) > 1
"""
txn.execute(select_sql, (yesterday,))
rows = txn.fetchall()
if not rows:
return
delete_sql = """
DELETE FROM device_lists_outbound_pokes
WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
delete_sql,
(
(yesterday, row[0], row[1], row[2])
for row in rows
)
)
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction(
"_prune_old_outbound_device_pokes", _prune_txn
)

View file

@ -12,16 +12,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections from twisted.internet import defer
import twisted.internet.defer from canonicaljson import encode_canonical_json
import ujson as json
from ._base import SQLBaseStore from ._base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore): class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
return self._simple_upsert( """Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
new_key_json = encode_canonical_json(device_keys)
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -29,57 +50,73 @@ class EndToEndKeyStore(SQLBaseStore):
}, },
values={ values={
"ts_added_ms": time_now, "ts_added_ms": time_now,
"key_json": json_bytes, "key_json": new_key_json,
} }
) )
def get_e2e_device_keys(self, query_list): return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
@defer.inlineCallbacks
def get_e2e_device_keys(self, query_list, include_all_devices=False):
"""Fetch a list of device keys. """Fetch a list of device keys.
Args: Args:
query_list(list): List of pairs of user_ids and device_ids. query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name". dict containing "key_json", "device_display_name".
""" """
if not query_list: if not query_list:
return {} defer.returnValue({})
return self.runInteraction( results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list "get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices,
) )
def _get_e2e_device_keys_txn(self, txn, query_list): for user_id, device_keys in results.iteritems():
for device_id, device_info in device_keys.iteritems():
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
query_clauses = [] query_clauses = []
query_params = [] query_params = []
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
if device_id: if device_id:
query_clause += " AND k.device_id = ?" query_clause += " AND device_id = ?"
query_params.append(device_id) query_params.append(device_id)
query_clauses.append(query_clause) query_clauses.append(query_clause)
sql = ( sql = (
"SELECT k.user_id, k.device_id, " "SELECT user_id, device_id, "
" d.display_name AS device_display_name, " " d.display_name AS device_display_name, "
" k.key_json" " k.key_json"
" FROM e2e_device_keys_json k" " FROM devices d"
" LEFT JOIN devices d ON d.user_id = k.user_id" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" AND d.device_id = k.device_id"
" WHERE %s" " WHERE %s"
) % ( ) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses) " OR ".join("(" + q + ")" for q in query_clauses)
) )
txn.execute(sql, query_params) txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict) result = {}
for row in rows: for row in rows:
result[row["user_id"]][row["device_id"]] = row result.setdefault(row["user_id"], {})[row["device_id"]] = row
return result return result
@ -152,7 +189,7 @@ class EndToEndKeyStore(SQLBaseStore):
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
) )
@twisted.internet.defer.inlineCallbacks @defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete( yield self._simple_delete(
table="e2e_device_keys_json", table="e2e_device_keys_json",

View file

@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore):
room_id, room_id,
) )
@cached() @cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id): def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol( return self._simple_select_onecol(
table="event_forward_extremities", table="event_forward_extremities",
@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore):
], ],
) )
self._update_extremeties(txn, events) self._update_backward_extremeties(txn, events)
def _update_extremeties(self, txn, events): def _update_backward_extremeties(self, txn, events):
"""Updates the event_*_extremities tables based on the new/updated """Updates the event_backward_extremities tables based on the new/updated
events being persisted. events being persisted.
This is called for new events *and* for events that were outliers, but This is called for new events *and* for events that were outliers, but
are are now being persisted as non-outliers. are now being persisted as non-outliers.
Forward extremities are handled when we first start persisting the events.
""" """
events_by_room = {} events_by_room = {}
for ev in events: for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev) events_by_room.setdefault(ev.room_id, []).append(ev)
for room_id, room_events in events_by_room.items():
prevs = [
e_id for ev in room_events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
]
if prevs:
txn.execute(
"DELETE FROM event_forward_extremities"
" WHERE room_id = ?"
" AND event_id in (%s)" % (
",".join(["?"] * len(prevs)),
),
[room_id] + prevs,
)
query = (
"INSERT INTO event_forward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_edges WHERE prev_event_id = ?"
" )"
)
txn.executemany(
query,
[
(ev.event_id, ev.room_id, ev.event_id) for ev in events
if not ev.internal_metadata.is_outlier()
]
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
max_stream_ord = max(
ev.internal_metadata.stream_ordering for ev in events
)
new_extrem = {}
for room_id in events_by_room:
event_ids = self._simple_select_onecol_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
new_extrem[room_id] = event_ids
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_ord,
}
for room_id, extrem_evs in new_extrem.items()
for event_id in extrem_evs
]
)
query = ( query = (
"INSERT INTO event_backward_extremities (event_id, room_id)" "INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS (" " SELECT ?, ? WHERE NOT EXISTS ("
@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore):
] ]
) )
for room_id in events_by_room:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_forward_extremeties_for_room(self, room_id, stream_ordering): def get_forward_extremeties_for_room(self, room_id, stream_ordering):
# We want to make the cache more effective, so we clamp to the last # We want to make the cache more effective, so we clamp to the last
# change before the given ordering. # change before the given ordering.

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, _RollbackButIsFineException from ._base import SQLBaseStore
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -27,6 +27,8 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.state import resolve_events
from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict from collections import deque, namedtuple, OrderedDict
@ -71,22 +73,19 @@ class _EventPeristenceQueue(object):
""" """
_EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
"events_and_contexts", "current_state", "backfilled", "deferred", "events_and_contexts", "backfilled", "deferred",
)) ))
def __init__(self): def __init__(self):
self._event_persist_queues = {} self._event_persist_queues = {}
self._currently_persisting_rooms = set() self._currently_persisting_rooms = set()
def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options. """Add events to the queue, with the given persist_event options.
""" """
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())
if queue: if queue:
end_item = queue[-1] end_item = queue[-1]
if end_item.current_state or current_state:
# We perist events with current_state set to True one at a time
pass
if end_item.backfilled == backfilled: if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts) end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe() return end_item.deferred.observe()
@ -96,7 +95,6 @@ class _EventPeristenceQueue(object):
queue.append(self._EventPersistQueueItem( queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts, events_and_contexts=events_and_contexts,
backfilled=backfilled, backfilled=backfilled,
current_state=current_state,
deferred=deferred, deferred=deferred,
)) ))
@ -216,7 +214,6 @@ class EventsStore(SQLBaseStore):
d = preserve_fn(self._event_persist_queue.add_to_queue)( d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
current_state=None,
) )
deferreds.append(d) deferreds.append(d)
@ -229,11 +226,10 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, current_state=None, backfilled=False): def persist_event(self, event, context, backfilled=False):
deferred = self._event_persist_queue.add_to_queue( deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], event.room_id, [(event, context)],
backfilled=backfilled, backfilled=backfilled,
current_state=current_state,
) )
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
@ -246,17 +242,6 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id): def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def persisting_queue(item): def persisting_queue(item):
if item.current_state:
for event, context in item.events_and_contexts:
# There should only ever be one item in
# events_and_contexts when current_state is
# not None
yield self._persist_event(
event, context,
current_state=item.current_state,
backfilled=item.backfilled,
)
else:
yield self._persist_events( yield self._persist_events(
item.events_and_contexts, item.events_and_contexts,
backfilled=item.backfilled, backfilled=item.backfilled,
@ -294,35 +279,183 @@ class EventsStore(SQLBaseStore):
for chunk in chunks: for chunk in chunks:
# We can't easily parallelize these since different chunks # We can't easily parallelize these since different chunks
# might contain the same event. :( # might contain the same event. :(
# NB: Assumes that we are only persisting events for one room
# at a time.
new_forward_extremeties = {}
current_state_for_room = {}
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
events_by_room = {}
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
)
for room_id, ev_ctx_rm in events_by_room.items():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = yield self._calculate_new_extremeties(
room_id, [ev for ev, _ in ev_ctx_rm]
)
if new_latest_event_ids == set(latest_event_ids):
# No change in extremities, so no change in state
continue
new_forward_extremeties[room_id] = new_latest_event_ids
state = yield self._calculate_state_delta(
room_id, ev_ctx_rm, new_latest_event_ids
)
if state:
current_state_for_room[room_id] = state
yield self.runInteraction( yield self.runInteraction(
"persist_events", "persist_events",
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=chunk, events_and_contexts=chunk,
backfilled=backfilled, backfilled=backfilled,
delete_existing=delete_existing, delete_existing=delete_existing,
current_state_for_room=current_state_for_room,
new_forward_extremeties=new_forward_extremeties,
) )
persist_event_counter.inc_by(len(chunk)) persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function def _calculate_new_extremeties(self, room_id, events):
def _persist_event(self, event, context, current_state=None, backfilled=False, """Calculates the new forward extremeties for a room given events to
delete_existing=False): persist.
try:
with self._stream_id_gen.get_next() as stream_ordering: Assumes that we are only persisting events for one room at a time.
event.internal_metadata.stream_ordering = stream_ordering """
yield self.runInteraction( latest_event_ids = yield self.get_latest_event_ids_in_room(
"persist_event", room_id
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
) )
persist_event_counter.inc() new_latest_event_ids = set(latest_event_ids)
except _RollbackButIsFineException: # First, add all the new events to the list
pass new_latest_event_ids.update(
event.event_id for event in events
if not event.internal_metadata.is_outlier()
)
# Now remove all events that are referenced by the to-be-added events
new_latest_event_ids.difference_update(
e_id
for event in events
for e_id, _ in event.prev_events
if not event.internal_metadata.is_outlier()
)
# And finally remove any events that are referenced by previously added
# events.
rows = yield self._simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=list(new_latest_event_ids),
retcols=["prev_event_id"],
keyvalues={
"room_id": room_id,
"is_state": False,
},
desc="_calculate_new_extremeties",
)
new_latest_event_ids.difference_update(
row["prev_event_id"] for row in rows
)
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
"""Calculate the new state deltas for a room.
Assumes that we are only persisting events for one room at a time.
Returns:
2-tuple (to_delete, to_insert) where both are state dicts, i.e.
(type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries
to insert.
May return None if there are no changes to be applied.
"""
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
# First search in the list of new events we're adding,
# and then use the current state from that
for ev, ctx in events_context:
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
was_updated = True
missing_event_ids.append(event_id)
if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.values())
if not new_latest_event_ids:
current_state = {}
elif was_updated:
current_state = yield resolve_events(
state_sets,
state_map_factory=lambda ev_ids: self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
else:
return
existing_state_rows = yield self._simple_select_list(
table="current_state_events",
keyvalues={"room_id": room_id},
retcols=["event_id", "type", "state_key"],
desc="_calculate_state_delta",
)
existing_events = set(row["event_id"] for row in existing_state_rows)
new_events = set(ev_id for ev_id in current_state.itervalues())
changed_events = existing_events ^ new_events
if not changed_events:
return
to_delete = {
(row["type"], row["state_key"]): row["event_id"]
for row in existing_state_rows
if row["event_id"] in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
key: ev_id for key, ev_id in current_state.iteritems()
if ev_id in events_to_insert
}
defer.returnValue((to_delete, to_insert))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
@ -380,53 +513,10 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events}) defer.returnValue({e.event_id: e for e in events})
@log_function
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = event.internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn(
txn,
table="current_state_events",
keyvalues={"room_id": event.room_id},
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
}
)
return self._persist_events_txn(
txn,
[(event, context)],
backfilled=backfilled,
delete_existing=delete_existing,
)
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False): delete_existing=False, current_state_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables. """Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table, Rejected events are only inserted into the events table, the events_json table,
@ -436,6 +526,93 @@ class EventsStore(SQLBaseStore):
If delete_existing is True then existing events will be purged from the If delete_existing is True then existing events will be purged from the
database before insertion. This is useful when retrying due to IntegrityError. database before insertion. This is useful when retrying due to IntegrityError.
""" """
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
for room_id, current_state_tuple in current_state_for_room.iteritems():
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()],
)
self._simple_insert_many_txn(
txn,
table="current_state_events",
values=[
{
"event_id": ev_id,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
}
for key, ev_id in to_insert.iteritems()
],
)
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
state_key for ev_type, state_key in to_delete.iterkeys()
if ev_type == EventTypes.Member
)
members_changed.update(
state_key for ev_type, state_key in to_insert.iterkeys()
if ev_type == EventTypes.Member
)
for member in members_changed:
self._invalidate_cache_and_stream(
txn, self.get_rooms_for_user, (member,)
)
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
for room_id, new_extrem in new_forward_extremeties.items():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
)
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
self._simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
{
"event_id": ev_id,
"room_id": room_id,
}
for room_id, new_extrem in new_forward_extremeties.items()
for ev_id in new_extrem
],
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_order,
}
for room_id, new_extrem in new_forward_extremeties.items()
for event_id in new_extrem
]
)
# Ensure that we don't have the same event twice. # Ensure that we don't have the same event twice.
# Pick the earliest non-outlier if there is one, else the earliest one. # Pick the earliest non-outlier if there is one, else the earliest one.
new_events_and_contexts = OrderedDict() new_events_and_contexts = OrderedDict()
@ -550,7 +727,7 @@ class EventsStore(SQLBaseStore):
# Update the event_backward_extremities table now that this # Update the event_backward_extremities table now that this
# event isn't an outlier any more. # event isn't an outlier any more.
self._update_extremeties(txn, [event]) self._update_backward_extremeties(txn, [event])
events_and_contexts = [ events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove ec for ec in events_and_contexts if ec[0] not in to_remove
@ -798,29 +975,6 @@ class EventsStore(SQLBaseStore):
# to update the current state table # to update the current state table
return return
for event, _ in state_events_and_contexts:
if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return return
def _add_to_cache(self, txn, events_and_contexts): def _add_to_cache(self, txn, events_and_contexts):
@ -1084,10 +1238,10 @@ class EventsStore(SQLBaseStore):
self._do_fetch self._do_fetch
) )
logger.info("Loading %d events", len(events)) logger.debug("Loading %d events", len(events))
with PreserveLoggingContext(): with PreserveLoggingContext():
rows = yield events_d rows = yield events_d
logger.info("Loaded %d events (%d rows)", len(events), len(rows)) logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected: if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]] rows[:] = [r for r in rows if not r["rejects"]]
@ -1418,6 +1572,7 @@ class EventsStore(SQLBaseStore):
"""The current minimum token that backfilled events have reached""" """The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token() return -self._backfill_id_gen.get_current_token()
@cached(num_args=5, max_entries=10)
def get_all_new_events(self, last_backfill_id, last_forward_id, def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit): current_backfill_id, current_forward_id, limit):
"""Get all the new events that have arrived at the server either as """Get all the new events that have arrived at the server either as
@ -1449,15 +1604,6 @@ class EventsStore(SQLBaseStore):
else: else:
upper_bound = current_forward_id upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering FROM current_state_resets"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_forward_id, upper_bound))
state_resets = txn.fetchall()
sql = ( sql = (
"SELECT event_stream_ordering, event_id, state_group" "SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream" " FROM ex_outlier_stream"
@ -1469,7 +1615,6 @@ class EventsStore(SQLBaseStore):
forward_ex_outliers = txn.fetchall() forward_ex_outliers = txn.fetchall()
else: else:
new_forward_events = [] new_forward_events = []
state_resets = []
forward_ex_outliers = [] forward_ex_outliers = []
sql = ( sql = (
@ -1509,7 +1654,6 @@ class EventsStore(SQLBaseStore):
return AllNewEventsResult( return AllNewEventsResult(
new_forward_events, new_backfill_events, new_forward_events, new_backfill_events,
forward_ex_outliers, backward_ex_outliers, forward_ex_outliers, backward_ex_outliers,
state_resets,
) )
return self.runInteraction("get_all_new_events", get_all_new_events_txn) return self.runInteraction("get_all_new_events", get_all_new_events_txn)
@ -1735,5 +1879,4 @@ class EventsStore(SQLBaseStore):
AllNewEventsResult = namedtuple("AllNewEventsResult", [ AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events", "new_forward_events", "new_backfill_events",
"forward_ex_outliers", "backward_ex_outliers", "forward_ex_outliers", "backward_ex_outliers",
"state_resets"
]) ])

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 39 SCHEMA_VERSION = 40
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -413,6 +413,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="user_delete_threepids", desc="user_delete_threepids",
) )
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
},
desc="user_delete_threepids",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def count_all_users(self): def count_all_users(self):
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""

View file

@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore):
) )
for event in events: for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after( txn.call_after(
self._membership_stream_cache.entity_has_changed, self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering event.state_key, event.internal_metadata.stream_ordering
@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering) yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cached(max_entries=5000) @cached(max_entries=500000, iterable=True)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
@ -220,7 +218,7 @@ class RoomMemberStore(SQLBaseStore):
" ON e.event_id = c.event_id" " ON e.event_id = c.event_id"
" AND m.room_id = c.room_id" " AND m.room_id = c.room_id"
" AND m.user_id = c.state_key" " AND m.user_id = c.state_key"
" WHERE %s" " WHERE c.type = 'm.room.member' AND %s"
) % (where_clause,) ) % (where_clause,)
txn.execute(sql, args) txn.execute(sql, args)
@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore):
" ON m.event_id = c.event_id " " ON m.event_id = c.event_id "
" AND m.room_id = c.room_id " " AND m.room_id = c.room_id "
" AND m.user_id = c.state_key" " AND m.user_id = c.state_key"
" WHERE %(where)s" " WHERE c.type = 'm.room.member' AND %(where)s"
) % { ) % {
"where": where_clause, "where": where_clause,
} }
@ -276,12 +274,29 @@ class RoomMemberStore(SQLBaseStore):
return rows return rows
@cached(max_entries=5000) @cached(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is( return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id`
"""
rooms = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room = set()
for room in rooms:
user_ids = yield self.get_users_in_room(
room.room_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room.update(user_ids)
defer.returnValue(user_who_share_room)
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""
def f(txn): def f(txn):
@ -390,7 +405,8 @@ class RoomMemberStore(SQLBaseStore):
room_id, state_group, state_ids, room_id, state_group, state_ids,
) )
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
max_entries=100000)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context, event=None): cache_context, event=None):
# We don't use `state_group`, it's there so that we can cache based # We don't use `state_group`, it's there so that we can cache based

View file

@ -0,0 +1,17 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('current_state_members_idx', '{}');

View file

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- turn the pre-fill startup query into a index-only scan on postgresql.
INSERT into background_updates (update_name, progress_json)
VALUES ('device_inbox_stream_index', '{}');
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');

View file

@ -0,0 +1,59 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Cache of remote devices.
CREATE TABLE device_lists_remote_cache (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
-- The last update we got for a user. Empty if we're not receiving updates for
-- that user.
CREATE TABLE device_lists_remote_extremeties (
user_id TEXT NOT NULL,
stream_id TEXT NOT NULL
);
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
-- Stream of device lists updates. Includes both local and remotes
CREATE TABLE device_lists_stream (
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL
);
CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
-- The stream of updates to send to other servers. We keep at least one row
-- per user that was sent so that the prev_id for any new updates can be
-- calculated
CREATE TABLE device_lists_outbound_pokes (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
sent BOOLEAN NOT NULL,
ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers
);
CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);

View file

@ -49,6 +49,7 @@ class StateStore(SQLBaseStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, hs): def __init__(self, hs):
super(StateStore, self).__init__(hs) super(StateStore, self).__init__(hs)
@ -60,6 +61,13 @@ class StateStore(SQLBaseStore):
self.STATE_GROUP_INDEX_UPDATE_NAME, self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state, self._background_index_state,
) )
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
@ -232,59 +240,7 @@ class StateStore(SQLBaseStore):
return count return count
@defer.inlineCallbacks @cached(num_args=2, max_entries=100000, iterable=True)
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? "
)
if event_type and state_key is not None:
sql += " AND type = ? AND state_key = ? "
args = (room_id, event_type, state_key)
elif event_type:
sql += " AND type = ?"
args = (room_id, event_type)
else:
args = (room_id, )
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
def _get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)
args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, max_entries=1000)
def _get_state_group_from_group(self, group, types): def _get_state_group_from_group(self, group, types):
raise NotImplementedError() raise NotImplementedError()

View file

@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def get_rooms_that_changed(self, room_ids, from_key):
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids (list)
from_key (str): The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
return set(
room_id for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'): order='DESC'):

View file

@ -44,6 +44,7 @@ class EventSources(object):
def get_current_token(self): def get_current_token(self):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -63,6 +64,7 @@ class EventSources(object):
), ),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key,
) )
defer.returnValue(token) defer.returnValue(token)
@ -70,6 +72,7 @@ class EventSources(object):
def get_current_token_for_room(self, room_id): def get_current_token_for_room(self, room_id):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -89,5 +92,6 @@ class EventSources(object):
), ),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key,
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -158,6 +158,7 @@ class StreamToken(
"account_data_key", "account_data_key",
"push_rules_key", "push_rules_key",
"to_device_key", "to_device_key",
"device_list_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -195,6 +196,7 @@ class StreamToken(
or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View file

@ -192,8 +192,11 @@ class Linearizer(object):
logger.info( logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key "Waiting to acquire linearizer lock %r for key %r", self.name, key
) )
try:
with PreserveLoggingContext(): with PreserveLoggingContext():
yield current_defer yield current_defer
except:
logger.exception("Unexpected exception in Linearizer")
logger.info("Acquired linearizer lock %r for key %r", self.name, key) logger.info("Acquired linearizer lock %r for key %r", self.name, key)

View file

@ -40,8 +40,8 @@ def register_cache(name, cache):
) )
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR)) _string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache _stirng_cache_metrics = register_cache("string_cache", _string_cache)
KNOWN_KEYS = { KNOWN_KEYS = {
@ -69,7 +69,12 @@ KNOWN_KEYS = {
def intern_string(string): def intern_string(string):
"""Takes a (potentially) unicode string and interns using custom cache """Takes a (potentially) unicode string and interns using custom cache
""" """
return _string_cache.setdefault(string, string) new_str = _string_cache.setdefault(string, string)
if new_str is string:
_stirng_cache_metrics.inc_hits()
else:
_stirng_cache_metrics.inc_misses()
return new_str
def intern_dict(dictionary): def intern_dict(dictionary):

View file

@ -17,7 +17,7 @@ import logging
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.logcontext import ( from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
) )
@ -42,6 +42,25 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
]
def __init__(self, deferred, sequence, callbacks):
self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache(object): class Cache(object):
__slots__ = ( __slots__ = (
"cache", "cache",
@ -51,12 +70,16 @@ class Cache(object):
"sequence", "sequence",
"thread", "thread",
"metrics", "metrics",
"_pending_deferred_cache",
) )
def __init__(self, name, max_entries=1000, keylen=1, tree=False): def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache( self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d.result)) if iterable else None,
) )
self.name = name self.name = name
@ -76,7 +99,15 @@ class Cache(object):
) )
def get(self, key, default=_CacheSentinel, callback=None): def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback) callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks)
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel: if val is not _CacheSentinel:
self.metrics.inc_hits() self.metrics.inc_hits()
return val return val
@ -88,15 +119,39 @@ class Cache(object):
else: else:
return default return default
def update(self, sequence, key, value, callback=None): def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread() self.check_thread()
if self.sequence == sequence: entry = CacheEntry(
# Only update the cache if the caches sequence number matches the deferred=value,
# number that the cache had before the SELECT was started (SYN-369) sequence=self.sequence,
self.prefill(key, value, callback=callback) callbacks=callbacks,
)
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def shuffle(result):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
self.cache.set(key, entry.deferred, entry.callbacks)
else:
entry.invalidate()
else:
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def prefill(self, key, value, callback=None): def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback) callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
@ -108,6 +163,10 @@ class Cache(object):
# Increment the sequence number so that any SELECT statements that # Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369) # raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1 self.sequence += 1
entry = self._pending_deferred_cache.pop(key, None)
if entry:
entry.invalidate()
self.cache.pop(key, None) self.cache.pop(key, None)
def invalidate_many(self, key): def invalidate_many(self, key):
@ -119,6 +178,11 @@ class Cache(object):
self.sequence += 1 self.sequence += 1
self.cache.del_multi(key) self.cache.del_multi(key)
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
self.sequence += 1 self.sequence += 1
@ -155,7 +219,7 @@ class CacheDescriptor(object):
""" """
def __init__(self, orig, max_entries=1000, num_args=1, tree=False, def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False, cache_context=False): inlineCallbacks=False, cache_context=False, iterable=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR) max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig self.orig = orig
@ -169,6 +233,8 @@ class CacheDescriptor(object):
self.num_args = num_args self.num_args = num_args
self.tree = tree self.tree = tree
self.iterable = iterable
all_args = inspect.getargspec(orig) all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1] self.arg_names = all_args.args[1:num_args + 1]
@ -203,6 +269,7 @@ class CacheDescriptor(object):
max_entries=self.max_entries, max_entries=self.max_entries,
keylen=self.num_args, keylen=self.num_args,
tree=self.tree, tree=self.tree,
iterable=self.iterable,
) )
@functools.wraps(self.orig) @functools.wraps(self.orig)
@ -243,11 +310,6 @@ class CacheDescriptor(object):
return preserve_context_over_deferred(observer) return preserve_context_over_deferred(observer)
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = defer.maybeDeferred( ret = defer.maybeDeferred(
preserve_context_over_fn, preserve_context_over_fn,
self.function_to_call, self.function_to_call,
@ -261,7 +323,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True) ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret, callback=invalidate_callback) cache.set(cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe()) return preserve_context_over_deferred(ret.observe())
@ -359,7 +421,6 @@ class CacheListDescriptor(object):
missing.append(arg) missing.append(arg)
if missing: if missing:
sequence = cache.sequence
args_to_call = dict(arg_dict) args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing args_to_call[self.list_name] = missing
@ -382,8 +443,8 @@ class CacheListDescriptor(object):
key = list(keyargs) key = list(keyargs)
key[self.list_pos] = arg key[self.list_pos] = arg
cache.update( cache.set(
sequence, tuple(key), observer, tuple(key), observer,
callback=invalidate_callback callback=invalidate_callback
) )
@ -417,21 +478,29 @@ class CacheListDescriptor(object):
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))): class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
# which namedtuple does for us (i.e. two _CacheContext are the same if
# their caches and keys match). This is important in particular to
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
def invalidate(self): def invalidate(self):
self.cache.invalidate(self.key) self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False): def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
num_args=num_args, num_args=num_args,
tree=tree, tree=tree,
cache_context=cache_context, cache_context=cache_context,
iterable=iterable,
) )
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False): def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor( return lambda orig: CacheDescriptor(
orig, orig,
max_entries=max_entries, max_entries=max_entries,
@ -439,6 +508,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
tree=tree, tree=tree,
inlineCallbacks=True, inlineCallbacks=True,
cache_context=cache_context, cache_context=cache_context,
iterable=iterable,
) )

View file

@ -23,7 +23,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
def __len__(self):
return len(self.value)
class DictionaryCache(object): class DictionaryCache(object):
@ -32,7 +34,7 @@ class DictionaryCache(object):
""" """
def __init__(self, name, max_entries=1000): def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries) self.cache = LruCache(max_size=max_entries, size_callback=len)
self.name = name self.name = name
self.sequence = 0 self.sequence = 0

View file

@ -15,6 +15,7 @@
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
from collections import OrderedDict
import logging import logging
@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class ExpiringCache(object): class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False): reset_expiry_on_get=False, iterable=False):
""" """
Args: Args:
cache_name (str): Name of this cache, used for logging. cache_name (str): Name of this cache, used for logging.
@ -36,6 +37,8 @@ class ExpiringCache(object):
evicted based on time. evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for reset_expiry_on_get (bool): If true, will reset the expiry time for
an item on access. Defaults to False. an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
""" """
self._cache_name = cache_name self._cache_name = cache_name
@ -47,9 +50,13 @@ class ExpiringCache(object):
self._reset_expiry_on_get = reset_expiry_on_get self._reset_expiry_on_get = reset_expiry_on_get
self._cache = {} self._cache = OrderedDict()
self.metrics = register_cache(cache_name, self._cache) self.metrics = register_cache(cache_name, self)
self.iterable = iterable
self._size_estimate = 0
def start(self): def start(self):
if not self._expiry_ms: if not self._expiry_ms:
@ -65,15 +72,14 @@ class ExpiringCache(object):
now = self._clock.time_msec() now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value) self._cache[key] = _CacheEntry(now, value)
# Evict if there are now too many items if self.iterable:
if self._max_len and len(self._cache.keys()) > self._max_len: self._size_estimate += len(value)
sorted_entries = sorted(
self._cache.items(),
key=lambda item: item[1].time,
)
for k, _ in sorted_entries[self._max_len:]: # Evict if there are now too many items
self._cache.pop(k) while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self._size_estimate -= len(value.value)
def __getitem__(self, key): def __getitem__(self, key):
try: try:
@ -99,7 +105,7 @@ class ExpiringCache(object):
# zero expiry time means don't expire. This should never get called # zero expiry time means don't expire. This should never get called
# since we have this check in start too. # since we have this check in start too.
return return
begin_length = len(self._cache) begin_length = len(self)
now = self._clock.time_msec() now = self._clock.time_msec()
@ -110,14 +116,19 @@ class ExpiringCache(object):
keys_to_delete.add(key) keys_to_delete.add(key)
for k in keys_to_delete: for k in keys_to_delete:
self._cache.pop(k) value = self._cache.pop(k)
if self.iterable:
self._size_estimate -= len(value.value)
logger.debug( logger.debug(
"[%s] _prune_cache before: %d, after len: %d", "[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self._cache) self._cache_name, begin_length, len(self)
) )
def __len__(self): def __len__(self):
if self.iterable:
return self._size_estimate
else:
return len(self._cache) return len(self._cache)

View file

@ -49,7 +49,7 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted. when that key gets invalidated/evicted.
""" """
def __init__(self, max_size, keylen=1, cache_type=dict): def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
cache = cache_type() cache = cache_type()
self.cache = cache # Used for introspection. self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None) list_root = _Node(None, None, None, None)
@ -58,6 +58,12 @@ class LruCache(object):
lock = threading.Lock() lock = threading.Lock()
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
def synchronized(f): def synchronized(f):
@wraps(f) @wraps(f)
def inner(*args, **kwargs): def inner(*args, **kwargs):
@ -66,6 +72,16 @@ class LruCache(object):
return inner return inner
cached_cache_len = [0]
if size_callback is not None:
def cache_len():
return cached_cache_len[0]
else:
def cache_len():
return len(cache)
self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()): def add_node(key, value, callbacks=set()):
prev_node = list_root prev_node = list_root
next_node = prev_node.next_node next_node = prev_node.next_node
@ -74,6 +90,9 @@ class LruCache(object):
next_node.prev_node = node next_node.prev_node = node
cache[key] = node cache[key] = node
if size_callback:
cached_cache_len[0] += size_callback(node.value)
def move_node_to_front(node): def move_node_to_front(node):
prev_node = node.prev_node prev_node = node.prev_node
next_node = node.next_node next_node = node.next_node
@ -92,23 +111,25 @@ class LruCache(object):
prev_node.next_node = next_node prev_node.next_node = next_node
next_node.prev_node = prev_node next_node.prev_node = prev_node
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
for cb in node.callbacks: for cb in node.callbacks:
cb() cb()
node.callbacks.clear() node.callbacks.clear()
@synchronized @synchronized
def cache_get(key, default=None, callback=None): def cache_get(key, default=None, callbacks=[]):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
if callback: node.callbacks.update(callbacks)
node.callbacks.add(callback)
return node.value return node.value
else: else:
return default return default
@synchronized @synchronized
def cache_set(key, value, callback=None): def cache_set(key, value, callbacks=[]):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
if value != node.value: if value != node.value:
@ -116,21 +137,18 @@ class LruCache(object):
cb() cb()
node.callbacks.clear() node.callbacks.clear()
if callback: if size_callback:
node.callbacks.add(callback) cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
move_node_to_front(node) move_node_to_front(node)
node.value = value node.value = value
else: else:
if callback: add_node(key, value, set(callbacks))
callbacks = set([callback])
else: evict()
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
@synchronized @synchronized
def cache_set_default(key, value): def cache_set_default(key, value):
@ -139,10 +157,7 @@ class LruCache(object):
return node.value return node.value
else: else:
add_node(key, value) add_node(key, value)
if len(cache) > max_size: evict()
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
return value return value
@synchronized @synchronized
@ -174,10 +189,8 @@ class LruCache(object):
for cb in node.callbacks: for cb in node.callbacks:
cb() cb()
cache.clear() cache.clear()
if size_callback:
@synchronized cached_cache_len[0] = 0
def cache_len():
return len(cache)
@synchronized @synchronized
def cache_contains(key): def cache_contains(key):
@ -190,7 +203,7 @@ class LruCache(object):
self.pop = cache_pop self.pop = cache_pop
if cache_type is TreeCache: if cache_type is TreeCache:
self.del_multi = cache_del_multi self.del_multi = cache_del_multi
self.len = cache_len self.len = synchronized(cache_len)
self.contains = cache_contains self.contains = cache_contains
self.clear = cache_clear self.clear = cache_clear

View file

@ -65,12 +65,27 @@ class TreeCache(object):
return popped return popped
def values(self): def values(self):
return [e.value for e in self.root.values()] return list(iterate_tree_cache_entry(self.root))
def __len__(self): def __len__(self):
return self.size return self.size
def iterate_tree_cache_entry(d):
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
if isinstance(d, dict):
for value_d in d.itervalues():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
if isinstance(d, _Entry):
yield d.value
else:
yield d
class _Entry(object): class _Entry(object):
__slots__ = ["value"] __slots__ = ["value"]

View file

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from functools import wraps
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
def debug_deferreds():
"""Cause all deferreds to wait for a reactor tick before running their
callbacks. This increases the chance of getting a stack trace out of
a defer.inlineCallback since the code waiting on the deferred will get
a chance to add an errback before the deferred runs."""
# Helper method for retrieving and restoring the current logging context
# around a callback.
def with_logging_context(fn):
context = LoggingContext.current_context()
def restore_context_callback(x):
with PreserveLoggingContext(context):
return fn(x)
return restore_context_callback
# We are going to modify the __init__ method of defer.Deferred so we
# need to get a copy of the old method so we can still call it.
old__init__ = defer.Deferred.__init__
# We need to create a deferred to bounce the callbacks through the reactor
# but we don't want to add a callback when we create that deferred so we
# we create a new type of deferred that uses the old __init__ method.
# This is safe as long as the old __init__ method doesn't invoke an
# __init__ using super.
class Bouncer(defer.Deferred):
__init__ = old__init__
# We'll add this as a callback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the callbacks of the
# original deferred.
def bounce_callback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.callback), x)
return bouncer
# We'll add this as an errback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the errbacks of the
# original deferred.
def bounce_errback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.errback), x)
return bouncer
@wraps(old__init__)
def new__init__(self, *args, **kargs):
old__init__(self, *args, **kargs)
self.addCallbacks(bounce_callback, bounce_errback)
defer.Deferred.__init__ = new__init__

View file

@ -88,7 +88,7 @@ class RetryDestinationLimiter(object):
def __init__(self, destination, clock, store, retry_interval, def __init__(self, destination, clock, store, retry_interval,
min_retry_interval=10 * 60 * 1000, min_retry_interval=10 * 60 * 1000,
max_retry_interval=24 * 60 * 60 * 1000, max_retry_interval=24 * 60 * 60 * 1000,
multiplier_retry_interval=5,): multiplier_retry_interval=5, backoff_on_404=False):
"""Marks the destination as "down" if an exception is thrown in the """Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500. context, except for CodeMessageException with code < 500.
@ -107,6 +107,7 @@ class RetryDestinationLimiter(object):
a failed request, in milliseconds. a failed request, in milliseconds.
multiplier_retry_interval (int): The multiplier to use to increase multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request. the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
""" """
self.clock = clock self.clock = clock
self.store = store self.store = store
@ -116,6 +117,7 @@ class RetryDestinationLimiter(object):
self.min_retry_interval = min_retry_interval self.min_retry_interval = min_retry_interval
self.max_retry_interval = max_retry_interval self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
def __enter__(self): def __enter__(self):
pass pass
@ -123,7 +125,22 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException): if exc_type is not None and issubclass(exc_type, CodeMessageException):
valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500 # Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS
# has been decommissioned.
# If we get a 401, then we should probably back off since they
# won't accept our requests for at least a while.
# 429 is us being aggresively rate limited, so lets rate limit
# ourselves.
if exc_val.code == 404 and self.backoff_on_404:
valid_err_code = False
elif exc_val.code in (401, 429):
valid_err_code = False
elif exc_val.code < 500:
valid_err_code = True
else:
valid_err_code = False
if exc_type is None or valid_err_code: if exc_type is None or valid_err_code:
# We connected successfully. # We connected successfully.

View file

@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
user_localpart = "test_user" user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs): def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs) return FrozenEvent(kwargs)

View file

@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs): def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs) return FrozenEvent(kwargs)
@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
def test_minimal(self): def test_minimal(self):
self.run_test( self.run_test(
{'type': 'A'},
{ {
'type': 'A', 'type': 'A',
'event_id': '$test:domain',
},
{
'type': 'A',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
}, },
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'unsigned': {'other_key': 'here'}, 'unsigned': {'other_key': 'here'},
}, },
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'}, 'content': {'things': 'here'},
}, },
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'm.room.create', 'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain', 'other_field': 'here'}, 'content': {'creator': '@2:domain', 'other_field': 'here'},
}, },
{ {
'type': 'm.room.create', 'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain'}, 'content': {'creator': '@2:domain'},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
self.serialize( self.serialize(
MockEvent( MockEvent(
type="foo",
event_id="test",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={
"foo": "bar", "foo": "bar",
@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
[] []
), ),
{ {
"type": "foo",
"event_id": "test",
"room_id": "!foo:bar", "room_id": "!foo:bar",
"content": { "content": {
"foo": "bar", "foo": "bar",

View file

@ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver(handlers=None) self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
def test_token_is_a_macaroon(self): def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret" token = self.macaroon_generator.generate_access_token("some_user")
token = self.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons # Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks # The most basic of sanity checks
@ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect()) self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self): def test_macaroon_caveats(self):
self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000 self.hs.clock.now = 5000
token = self.auth_handler.generate_access_token("a_user") token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat): def verify_gen(caveat):
@ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase):
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.auth_handler.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
@ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
) )
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.auth_handler.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000 "a_user", 5000
) )
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)

View file

@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield utils.setup_test_homeserver(handlers=None) hs = yield utils.setup_test_homeserver()
self.handler = synapse.handlers.device.DeviceHandler(hs) self.handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self): def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered( res = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name"
) )
self.assertEqual(res, "fco") self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco") dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_is_preserved_if_exists(self): def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered( res1 = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name"
) )
self.assertEqual(res1, "fco") self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered( res2 = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="new display name" initial_device_display_name="new display name"
) )
self.assertEqual(res2, "fco") self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco") dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self): def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered( device_id = yield self.handler.check_device_registered(
user_id="theresa", user_id="@theresa:foo",
device_id=None, device_id=None,
initial_device_display_name="display" initial_device_display_name="display"
) )
dev = yield self.handler.store.get_device("theresa", device_id) dev = yield self.handler.store.get_device("@theresa:foo", device_id)
self.assertEqual(dev["display_name"], "display") self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
"register_edu_handler",
]) ])
self.query_handlers = {} self.query_handlers = {}

View file

@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
"register_edu_handler",
]) ])
self.query_handlers = {} self.query_handlers = {}

View file

@ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase):
handlers=None, handlers=None,
http_client=None, http_client=None,
expire_access_token=True) expire_access_token=True)
self.auth_handler = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret'))
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock() self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[
"generate_access_token",
])
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self): def test_user_is_created_and_logged_in_if_doesnt_exist(self):

View file

@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_devices_by_remote",
]), ]),
state_handler=self.state_handler, state_handler=self.state_handler,
handlers=None, handlers=None,
@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
defer.succeed(retry_timings_res) defer.succeed(retry_timings_res)
) )
self.datastore.get_devices_by_remote.return_value = (0, [])
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response self.datastore.get_received_txn_response = get_received_txn_response

View file

@ -58,53 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self): def tearDown(self):
[unpatch() for unpatch in self.unpatches] [unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Join the room.
join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
# Leave the room.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Add some other user to the room.
join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self): def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID) create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
@ -122,51 +75,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
"get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
) )
@defer.inlineCallbacks
def test_get_current_state(self):
# Create the room.
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
)
# Join the room.
join1 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join1]
)
# Add some other user to the room.
join2 = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[join2]
)
# Leave the room, then rejoin the room clobbering state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
join3 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[]
)
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join3]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_redactions(self): def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
@ -283,6 +191,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if depth is None: if depth is None:
depth = self.event_id depth = self.event_id
if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
room_id
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
event_dict = { event_dict = {
"sender": sender, "sender": sender,
"type": type, "type": type,
@ -309,12 +223,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = { state_ids = {
key: e.event_id for key, e in state.items() key: e.event_id for key, e in state.items()
} }
else:
state_ids = None
context = EventContext() context = EventContext()
context.current_state_ids = state_ids context.current_state_ids = state_ids
context.prev_state_ids = state_ids context.prev_state_ids = state_ids
elif not backfill:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
else:
context = EventContext()
context.push_actions = push_actions context.push_actions = push_actions
ordering = None ordering = None
@ -324,7 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
) )
else: else:
ordering, _ = yield self.master_store.persist_event( ordering, _ = yield self.master_store.persist_event(
event, context, current_state=reset_state event, context,
) )
if ordering: if ordering:

View file

@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other, # set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server # expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]: for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=403) yield self.join(room=room, user=usr, expect_code=404)
yield self.leave(room=room, user=usr, expect_code=403) yield self.leave(room=room, user=usr, expect_code=404)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_membership_private_room_perms(self): def test_membership_private_room_perms(self):
@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))

View file

@ -87,7 +87,10 @@ class RestTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", path, json.dumps(data) "PUT", path, json.dumps(data)
) )
self.assertEquals(expect_code, code, msg=str(response)) self.assertEquals(
expect_code, code,
msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response)
)
self.auth_user_id = temp_id self.auth_user_id = temp_id

View file

@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0] callcount2 = [0]
class A(object): class A(object):
@cached(max_entries=2) @cached(max_entries=20) # HACK: This makes it 2 due to cache factor
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2) self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2) self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3") yield a.func("foo3")
self.assertEquals(callcount[0], 3) self.assertEquals(callcount[0], 3)

View file

@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.as_token = "token1" self.as_token = "token1"
self.as_url = "some_url" self.as_url = "some_url"
@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
self.as_list = [ self.as_list = [
@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:

View file

@ -33,7 +33,11 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_key_without_device_name(self): def test_key_without_device_name(self):
now = 1470174257070 now = 1470174257070
json = '{ "key": "value" }' json = {"key": "value"}
yield self.store.store_device(
"user", "device", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user", "device", now, json) "user", "device", now, json)
@ -43,14 +47,14 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset({ self.assertDictContainsSubset({
"key_json": json, "keys": json,
"device_display_name": None, "device_display_name": None,
}, dev) }, dev)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_key_with_device_name(self): def test_get_key_with_device_name(self):
now = 1470174257070 now = 1470174257070
json = '{ "key": "value" }' json = {"key": "value"}
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user", "device", now, json) "user", "device", now, json)
@ -63,7 +67,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset({ self.assertDictContainsSubset({
"key_json": json, "keys": json,
"device_display_name": "display_name", "device_display_name": "display_name",
}, dev) }, dev)
@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self): def test_multiple_devices(self):
now = 1470174257070 now = 1470174257070
yield self.store.store_device(
"user1", "device1", None
)
yield self.store.store_device(
"user1", "device2", None
)
yield self.store.store_device(
"user2", "device1", None
)
yield self.store.store_device(
"user2", "device2", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user1", "device1", now, 'json11') "user1", "device1", now, 'json11')
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(

Some files were not shown because too many files have changed in this diff Show more