0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 21:03:51 +01:00

Merge branch 'release-v0.19.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2017-02-04 08:28:56 +00:00
commit fad3a84335
102 changed files with 3614 additions and 1783 deletions

View file

@ -1,3 +1,78 @@
Changes in synapse v0.19.0 (2017-02-04)
=======================================
No changes since RC 4.
Changes in synapse v0.19.0-rc4 (2017-02-02)
===========================================
* Bump cache sizes for common membership queries (PR #1879)
Changes in synapse v0.19.0-rc3 (2017-02-02)
===========================================
* Fix email push in pusher worker (PR #1875)
* Make presence.get_new_events a bit faster (PR #1876)
* Make /keys/changes a bit more performant (PR #1877)
Changes in synapse v0.19.0-rc2 (2017-02-02)
===========================================
* Include newly joined users in /keys/changes API (PR #1872)
Changes in synapse v0.19.0-rc1 (2017-02-02)
===========================================
Features:
* Add support for specifying multiple bind addresses (PR #1709, #1712, #1795,
#1835). Thanks to @kyrias!
* Add /account/3pid/delete endpoint (PR #1714)
* Add config option to configure the Riot URL used in notification emails (PR
#1811). Thanks to @aperezdc!
* Add username and password config options for turn server (PR #1832). Thanks
to @xsteadfastx!
* Implement device lists updates over federation (PR #1857, #1861, #1864)
* Implement /keys/changes (PR #1869, #1872)
Changes:
* Improve IPv6 support (PR #1696). Thanks to @kyrias and @glyph!
* Log which files we saved attachments to in the media_repository (PR #1791)
* Linearize updates to membership via PUT /state/ to better handle multiple
joins (PR #1787)
* Limit number of entries to prefill from cache on startup (PR #1792)
* Remove full_twisted_stacktraces option (PR #1802)
* Measure size of some caches by sum of the size of cached values (PR #1815)
* Measure metrics of string_cache (PR #1821)
* Reduce logging verbosity (PR #1822, #1823, #1824)
* Don't clobber a displayname or avatar_url if provided by an m.room.member
event (PR #1852)
* Better handle 401/404 response for federation /send/ (PR #1866, #1871)
Fixes:
* Fix ability to change password to a non-ascii one (PR #1711)
* Fix push getting stuck due to looking at the wrong view of state (PR #1820)
* Fix email address comparison to be case insensitive (PR #1827)
* Fix occasional inconsistencies of room membership (PR #1836, #1840)
Performance:
* Don't block messages sending on bumping presence (PR #1789)
* Change device_inbox stream index to include user (PR #1793)
* Optimise state resolution (PR #1818)
* Use DB cache of joined users for presence (PR #1862)
* Add an index to make membership queries faster (PR #1867)
Changes in synapse v0.18.7 (2017-01-09)
=======================================
@ -30,6 +105,7 @@ Changes in synapse v0.18.6 (2017-01-06)
Bug fixes:
* Fix bug when checking if a guest user is allowed to join a room (PR #1772)
Thanks to Patrik Oldsberg for diagnosing and the fix!
Changes in synapse v0.18.6-rc3 (2017-01-05)

View file

@ -138,6 +138,7 @@ Installing prerequisites on openSUSE::
python-devel libffi-devel libopenssl-devel libjpeg62-devel
Installing prerequisites on OpenBSD::
doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
libxslt
@ -658,7 +659,7 @@ configuration might look like::
}
}
You will also want to set ``bind_address: 127.0.0.1`` and ``x_forwarded: true``
You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true``
for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are
recorded correctly.

View file

@ -2,15 +2,13 @@ Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
media.
The API is::
POST /_matrix/client/r0/admin/purge_media_cache
POST /_matrix/client/r0/admin/purge_media_cache?before_ts=<unix_timestamp_in_ms>&access_token=<access_token>
{
"before_ts": <unix_timestamp_in_ms>
}
{}
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.18.7"
__version__ = "0.19.0"

View file

@ -16,18 +16,14 @@
import logging
import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from unpaddedbase64 import decode_base64
import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@ -78,147 +74,7 @@ class Auth(object):
True if the auth checks pass.
"""
with Measure(self.clock, "auth.check"):
self.check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self.check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = self._get_user_power_level(event.user_id, auth_events)
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
self._can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
self._check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def check_size_limits(self, event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
event_auth.check(event, auth_events, do_sig_check=do_sig_check)
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
@ -290,7 +146,7 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from check_host_in_room")
logger.debug("calling resolve_state_groups from check_host_in_room")
entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids
)
@ -300,16 +156,6 @@ class Auth(object):
)
defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return self._check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
@ -321,267 +167,8 @@ class Auth(object):
return creation_event.content.get("m.federate", True) is True
@log_function
def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50)
logger.debug(
"is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _verify_third_party_invite(self, event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in self.get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(self, invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def _get_user_power_level(self, user_id, auth_events):
power_level_event = self._get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(self, auth_events, name, default):
power_level_event = self._get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
return event_auth.get_public_keys(invite_event)
@defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False, rights="access"):
@ -974,56 +561,6 @@ class Auth(object):
defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
@log_function
def _can_send_event(self, event, auth_events):
send_level = self._get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = self._get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(self, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
@ -1037,107 +574,7 @@ class Auth(object):
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = self._get_user_power_level(event.user_id, auth_events)
redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = self._get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
return event_auth.check_redaction(event, auth_events)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
@ -1167,10 +604,10 @@ class Auth(object):
if power_level_event:
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = self._get_send_level(
send_level = event_auth.get_send_level(
EventTypes.Aliases, "", auth_events
)
user_level = self._get_user_power_level(user_id, auth_events)
user_level = event_auth.get_user_power_level(user_id, auth_events)
if user_level < send_level:
raise AuthError(

View file

@ -76,7 +76,7 @@ class AppserviceServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -85,16 +85,19 @@ class AppserviceServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
@ -102,15 +105,18 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -90,7 +90,7 @@ class ClientReaderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -108,16 +108,19 @@ class ClientReaderServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse client reader now listening on port %d", port)
def start_listening(self, listeners):
@ -125,15 +128,18 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -86,7 +86,7 @@ class FederationReaderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -99,16 +99,19 @@ class FederationReaderServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
@ -116,15 +119,18 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep
@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore,
SlavedRegistrationStore, SlavedDeviceStore,
):
pass
@ -82,7 +83,7 @@ class FederationSenderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -91,16 +92,19 @@ class FederationSenderServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse federation_sender now listening on port %d", port)
def start_listening(self, listeners):
@ -108,15 +112,18 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -107,7 +107,7 @@ def build_resource_for_web_client(hs):
class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
@ -173,29 +173,32 @@ class SynapseHomeServer(HomeServer):
root_resource = Resource()
root_resource = create_resource_tree(resources, root_resource)
if tls:
reactor.listenSSL(
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_server_context_factory,
interface=bind_address
)
for address in bind_addresses:
reactor.listenSSL(
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_server_context_factory,
interface=address
)
else:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse now listening on port %d", port)
def start_listening(self):
@ -205,15 +208,18 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -87,7 +87,7 @@ class MediaRepositoryServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -105,16 +105,19 @@ class MediaRepositoryServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
@ -122,15 +125,18 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -121,7 +121,7 @@ class PusherServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -130,16 +130,19 @@ class PusherServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse pusher now listening on port %d", port)
def start_listening(self, listeners):
@ -147,15 +150,18 @@ class PusherServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View file

@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore,
SlavedPresenceStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
RoomStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
@ -289,7 +291,7 @@ class SynchrotronServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@ -310,16 +312,19 @@ class SynchrotronServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse synchrotron now listening on port %d", port)
def start_listening(self, listeners):
@ -327,15 +332,18 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@ -374,6 +382,27 @@ class SynchrotronServer(HomeServer):
stream_key, position, users=users, rooms=rooms
)
@defer.inlineCallbacks
def notify_device_list_update(result):
stream = result.get("device_lists")
if not stream:
return
position_index = stream["field_names"].index("position")
user_index = stream["field_names"].index("user_id")
for row in stream["rows"]:
position = row[position_index]
user_id = row[user_index]
rooms = yield store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
@defer.inlineCallbacks
def notify(result):
stream = result.get("events")
if stream:
@ -411,6 +440,7 @@ class SynchrotronServer(HomeServer):
notify_from_stream(
result, "to_device", "to_device_key", user="user_id"
)
yield notify_device_list_update(result)
while True:
try:
@ -421,7 +451,7 @@ class SynchrotronServer(HomeServer):
yield store.process_replication(result)
typing_handler.process_replication(result)
yield presence_handler.process_replication(result)
notify(result)
yield notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)

View file

@ -68,6 +68,9 @@ class EmailConfig(Config):
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
if "app_name" in email_config:
self.email_app_name = email_config["app_name"]
else:
@ -85,6 +88,9 @@ class EmailConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for notification events
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
#email:
# enable_notifs: false
# smtp_host: "localhost"
@ -95,4 +101,5 @@ class EmailConfig(Config):
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
# notif_for_new_users: True
# riot_base_url: "http://localhost/riot"
"""

View file

@ -22,7 +22,6 @@ import yaml
from string import Template
import os
import signal
from synapse.util.debug import debug_deferreds
DEFAULT_LOG_CONFIG = Template("""
@ -71,8 +70,6 @@ class LoggingConfig(Config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
if config.get("full_twisted_stacktraces"):
debug_deferreds()
def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
@ -88,11 +85,6 @@ class LoggingConfig(Config):
# A yaml python logging config file
log_config: "%(log_config)s"
# Stop twisted from discarding the stack traces of exceptions in
# deferreds by waiting a reactor tick before running a deferred's
# callbacks.
# full_twisted_stacktraces: true
""" % locals()
def read_arguments(self, args):

View file

@ -42,6 +42,15 @@ class ServerConfig(Config):
self.listeners = config.get("listeners", [])
for listener in self.listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
bind_port = config.get("bind_port")
@ -54,7 +63,7 @@ class ServerConfig(Config):
self.listeners.append({
"port": bind_port,
"bind_address": bind_host,
"bind_addresses": [bind_host],
"tls": True,
"type": "http",
"resources": [
@ -73,7 +82,7 @@ class ServerConfig(Config):
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_address": bind_host,
"bind_addresses": [bind_host],
"tls": False,
"type": "http",
"resources": [
@ -92,7 +101,7 @@ class ServerConfig(Config):
if manhole:
self.listeners.append({
"port": manhole,
"bind_address": "127.0.0.1",
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
})
@ -100,7 +109,7 @@ class ServerConfig(Config):
if metrics_port:
self.listeners.append({
"port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
"tls": False,
"type": "http",
"resources": [
@ -155,9 +164,14 @@ class ServerConfig(Config):
# The port to listen for HTTPS requests on.
port: %(bind_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_address: ''
# Local addresses to listen on.
# This will listen on all IPv4 addresses by default.
bind_addresses:
- '0.0.0.0'
# Uncomment to listen on all IPv6 interfaces
# N.B: On at least Linux this will also listen on all IPv4
# addresses, so you will need to comment out the line above.
# - '::'
# This is a 'http' listener, allows us to specify 'resources'.
type: http
@ -188,7 +202,7 @@ class ServerConfig(Config):
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
bind_address: ''
bind_addresses: ['0.0.0.0']
type: http
x_forwarded: false

View file

@ -19,7 +19,9 @@ class VoipConfig(Config):
def read_config(self, config):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config["turn_shared_secret"]
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, **kwargs):
@ -32,6 +34,11 @@ class VoipConfig(Config):
# The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and
# does not use a token
#turn_username: "TURNSERVER_USERNAME"
#turn_password: "TURNSERVER_PASSWORD"
# How long generated TURN credentials last
turn_user_lifetime: "1h"
"""

View file

@ -29,3 +29,13 @@ class WorkerConfig(Config):
self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
self.worker_replication_url = config.get("worker_replication_url")
if self.worker_listeners:
for listener in self.worker_listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')

678
synapse/event_auth.py Normal file
View file

@ -0,0 +1,678 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from unpaddedbase64 import decode_base64
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = _is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
_check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = get_user_power_level(event.user_id, auth_events)
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
_can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def _check_size_limits(event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
def _can_federate(event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
def _is_membership_change_allowed(event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = _get_named_level(auth_events, "ban", 50)
logger.debug(
"_is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not _verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return _check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
def get_send_level(etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
def _can_send_event(event, auth_events):
send_level = get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
True if the the sender is allowed to redact the target event if the
target event was created by them.
False if the sender is allowed to redact the target event with no
further checks.
Raises:
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = get_user_power_level(event.user_id, auth_events)
redact_level = _get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
def _get_power_level_event(auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def get_user_power_level(user_id, auth_events):
power_level_event = _get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(auth_events, name, default):
power_level_event = _get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
def _verify_third_party_invite(event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def auth_types_for_event(event):
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
Used to limit the number of events to fetch from the database to
actually auth the event.
"""
if event.type == EventTypes.Create:
return []
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.user_id, ))
auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, "", ))
auth_types.append((EventTypes.Member, event.state_key, ))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
auth_types.append(key)
return auth_types

View file

@ -79,7 +79,6 @@ class EventBase(object):
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts")
@ -88,8 +87,6 @@ class EventBase(object):
redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
else:
frozen_dict = event_dict
self.event_id = event_dict["event_id"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import EventBase, FrozenEvent
from . import EventBase, FrozenEvent, _event_dict_property
from synapse.types import EventID
@ -34,6 +34,10 @@ class EventBuilder(EventBase):
internal_metadata_dict=internal_metadata_dict,
)
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self):
return FrozenEvent.from_event(self)

View file

@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@ -126,6 +126,16 @@ class FederationClient(FederationBase):
destination, content, timeout
)
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.inc("user_devices")
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function
def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.
@ -499,8 +509,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
(destination, ev)
)
break
except CodeMessageException as e:

View file

@ -52,8 +52,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self._room_pdu_linearizer = Linearizer()
self._server_linearizer = Linearizer()
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._server_linearizer = Linearizer("fed_server")
# We cache responses to state queries, as they take a while and often
# come in waves.
@ -416,6 +416,9 @@ class FederationServer(FederationBase):
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)
def on_query_user_devices(self, origin, user_id):
return self.on_query_request("user_devices", user_id)
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):

View file

@ -100,6 +100,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {}
self.last_device_stream_id_by_dest = {}
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
@ -305,64 +306,77 @@ class TransactionQueue(object):
yield run_on_reactor()
while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
backoff_on_404=True, # If we get a 404 the other side has gone
)
device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination)
)
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
)
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
)
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
device_message_edus, device_stream_id = (
yield self._get_new_device_messages(destination)
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
)
return
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter,
)
if success:
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if device_message_edus:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
logger.info("Marking as sent %r %r", destination, dev_list_id)
yield self.store.mark_as_sent_devices_by_remote(
destination, dev_list_id
)
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
)
return
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus),
limiter=limiter,
)
if not success:
break
self.last_device_stream_id_by_dest[destination] = device_stream_id
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break
except NotRetryingDestination:
logger.info(
logger.debug(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
@ -387,13 +401,26 @@ class TransactionQueue(object):
)
for content in contents
]
defer.returnValue((edus, stream_id))
last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
now_stream_id, results = yield self.store.get_devices_by_remote(
destination, last_device_list
)
edus.extend(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.device_list_update",
content=content,
)
for content in results
)
defer.returnValue((edus, stream_id, now_stream_id))
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, device_stream_id,
should_delete_from_device_stream, limiter):
pending_failures, limiter):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
@ -477,7 +504,7 @@ class TransactionQueue(object):
code = e.code
response = e.response
if e.code == 429 or 500 <= e.code:
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
@ -504,13 +531,6 @@ class TransactionQueue(object):
"Failed to send event %s to %s", p.event_id, destination
)
success = False
else:
# Remove the acknowledged device messages from the database
if should_delete_from_device_stream:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.

View file

@ -346,6 +346,32 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server.
Response:
{
"stream_id": "...",
"devices": [ { ... } ]
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/devices/" + user_id
content = yield self.client.get_json(
destination=destination,
path=path,
timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content, timeout):

View file

@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
return self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
def on_GET(self, origin, content, query, user_id):
return self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@ -613,6 +620,7 @@ SERVLET_CLASSES = (
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,

View file

@ -88,9 +88,13 @@ class BaseHandler(object):
current_state = yield self.store.get_events(
context.current_state_ids.values()
)
current_state = current_state.values()
else:
current_state = yield self.store.get_current_state(event.room_id)
current_state = yield self.state_handler.get_current_state(
event.room_id
)
current_state = current_state.values()
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)

View file

@ -65,6 +65,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@ -529,37 +530,11 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token,
device_id)
defer.returnValue(access_token)
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth()
try:
@ -570,15 +545,6 @@ class AuthHandler(BaseHandler):
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
@ -607,7 +573,7 @@ class AuthHandler(BaseHandler):
# types (mediums) of threepid. For now, we still use the existing
# infrastructure, but this is the start of synapse gaining knowledge
# of specific types of threepid (and fixes the fact that checking
# for the presenc eof an email address during password reset was
# for the presence of an email address during password reset was
# case sensitive).
if medium == 'email':
address = address.lower()
@ -617,6 +583,17 @@ class AuthHandler(BaseHandler):
self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address):
# 'Canonicalise' email addresses as per above
if medium == 'email':
address = address.lower()
ret = yield self.store.user_delete_threepid(
user_id, medium, address,
)
defer.returnValue(ret)
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
@ -656,12 +633,54 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool).
"""
if stored_hash:
return bcrypt.hashpw(password + self.hs.config.password_pepper,
stored_hash.encode('utf-8')) == stored_hash
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf8')) == stored_hash
else:
return False
class MacaroonGeneartor(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.macaroon_secret_key = hs.config.macaroon_secret_key
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
macaroon.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.server_name,
identifier="key",
key=self.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.

View file

@ -14,7 +14,11 @@
# limitations under the License.
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id, RoomStreamToken
from twisted.internet import defer
from ._base import BaseHandler
@ -27,6 +31,21 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
self.federation.register_edu_handler(
"m.device_list_update", self._incoming_device_list_update,
)
self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
@ -45,29 +64,29 @@ class DeviceHandler(BaseHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
new_device = yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string(10).upper()
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
device_id = stringutils.random_string(10).upper()
new_device = yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@ -147,6 +166,8 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
@ -166,12 +187,168 @@ class DeviceHandler(BaseHandler):
device_id,
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
@measure_func("notify_device_update")
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
hosts = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
if hosts:
logger.info("Sending device list update notif to: %r", hosts)
for host in hosts:
self.federation_sender.send_device_messages(host)
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
from_token.device_list_key
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
possibly_changed = set(changed)
for room_id in rooms_changed:
# Fetch the current state at the time.
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key)
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
except:
prev_state_ids = {}
current_state_ids = yield self.state.get_current_state_ids(room_id)
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in current_state_ids.iteritems():
etype, state_key = key
if etype == EventTypes.Member:
prev_event_id = prev_state_ids.get(key, None)
if not prev_event_id or prev_event_id != event_id:
possibly_changed.add(state_key)
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
# Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
defer.returnValue(users_who_share_room & possibly_changed)
@measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
stream_id = edu_content["stream_id"]
prev_ids = edu_content.get("prev_id", [])
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
with (yield self._remote_edue_linearizer.queue(user_id)):
# If the prev id matches whats in our cache table, then we don't need
# to resync the users device list, otherwise we do.
resync = True
if len(prev_ids) == 1:
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
if str(extremity) == str(prev_ids[0]):
resync = False
if resync:
# Fetch all devices for the user.
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
content = dict(edu_content)
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
content.pop(key, None)
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})

View file

@ -73,10 +73,9 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_ids
remote_queries[user_id] = device_ids
# do the queries
# Firt get local devices.
failures = {}
results = {}
if local_query:
@ -85,9 +84,42 @@ class E2eKeysHandler(object):
if user_id in local_query:
results[user_id] = keys
# Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
for user_id, device_ids in remote_queries.iteritems():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
)
for user_id, devices in remote_results.iteritems():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.iteritems():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@defer.inlineCallbacks
def do_remote_query(destination):
destination_query = remote_queries[destination]
destination_query = remote_queries_not_in_cache[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
@ -119,7 +151,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries
for destination in remote_queries_not_in_cache
]))
defer.returnValue({
@ -162,7 +194,7 @@ class E2eKeysHandler(object):
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r = dict(device_info["keys"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
@ -255,10 +287,12 @@ class E2eKeysHandler(object):
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
changed = yield self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys,
)
if changed:
# Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:

View file

@ -591,12 +591,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.info("calling resolve_state_groups in _maybe_backfill")
logger.debug("calling resolve_state_groups in _maybe_backfill")
states = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids
]))
states = dict(zip(event_ids, [s[1] for s in states]))
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids],
@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
current_state=state,
)
defer.returnValue((event_stream_id, max_stream_id))
@ -1530,7 +1529,7 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
new_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()],
event
)

View file

@ -208,8 +208,10 @@ class MessageHandler(BaseHandler):
content = builder.content
try:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
@ -279,7 +281,9 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user)
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):

View file

@ -574,7 +574,7 @@ class PresenceHandler(object):
if not local_states:
continue
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.store.get_users_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
for host in hosts:
@ -766,7 +766,7 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
user_ids = yield self.state.get_current_user_in_room(room_id)
user_ids = yield self.store.get_users_in_room(room_id)
if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
@ -1011,7 +1011,7 @@ class PresenceEventSource(object):
@defer.inlineCallbacks
@log_function
def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
**kwargs):
explicit_room_id=None, **kwargs):
# The process for getting presence events are:
# 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms.
@ -1028,22 +1028,24 @@ class PresenceEventSource(object):
user_id = user.to_string()
if from_key is not None:
from_key = int(from_key)
room_ids = room_ids or []
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
if not room_ids:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(e.room_id for e in rooms)
else:
room_ids = set(room_ids)
max_token = self.store.get_current_presence_token()
plist = yield self.store.get_presence_list_accepted(user.localpart)
friends = set(row["observed_user_id"] for row in plist)
friends.add(user_id) # So that we receive our own presence
users_interested_in = set(row["observed_user_id"] for row in plist)
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(explicit_room_id)
users_interested_in.update(user_ids)
user_ids_changed = set()
changed = None
@ -1055,35 +1057,19 @@ class PresenceEventSource(object):
# work out if we share a room or they're in our presence list
get_updates_counter.inc("stream")
for other_user_id in changed:
if other_user_id in friends:
if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id)
continue
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
continue
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
get_updates_counter.inc("full")
user_ids_to_check = set()
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
user_ids_to_check.update(users)
user_ids_to_check.update(friends)
# Always include yourself. Only really matters for when the user is
# not in any rooms, but still.
user_ids_to_check.add(user_id)
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
user_ids_to_check, from_key,
users_interested_in, from_key,
)
else:
user_ids_changed = user_ids_to_check
user_ids_changed = users_interested_in
updates = yield presence.current_state_for_users(user_ids_changed)

View file

@ -40,6 +40,8 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
@ -143,7 +145,7 @@ class RegistrationHandler(BaseHandler):
token = None
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@ -167,7 +169,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
token = self.macaroon_gen.generate_access_token(user_id)
try:
yield self.store.register(
user_id=user_id,
@ -254,7 +256,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
token = self.auth_handler().generate_access_token(user_id)
token = self.macaroon_gen.generate_access_token(user_id)
try:
yield self.store.register(
user_id=user_id,
@ -399,7 +401,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self.auth_handler().generate_access_token(user_id)
token = self.macaroon_gen.generate_access_token(user_id)
if need_register:
yield self.store.register(

View file

@ -437,6 +437,7 @@ class RoomEventSource(object):
limit,
room_ids,
is_guest,
explicit_room_id=None,
):
# We just ignore the key for now.

View file

@ -62,17 +62,18 @@ class RoomListHandler(BaseHandler):
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
"""
if search_filter or (network_tuple and network_tuple.appservice_id is not None):
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
result = self.response_cache.get((limit, since_token))
key = (limit, since_token, network_tuple)
result = self.response_cache.get(key)
if not result:
result = self.response_cache.set(
(limit, since_token),
key,
self._get_public_room_list(
limit, since_token, network_tuple=network_tuple
)

View file

@ -45,7 +45,7 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.member_linearizer = Linearizer()
self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
@ -89,7 +89,7 @@ class RoomMemberHandler(BaseHandler):
duplicate = yield msg_handler.deduplicate_state_event(event, context)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return
defer.returnValue(duplicate)
yield msg_handler.handle_new_client_event(
requester,
@ -120,6 +120,8 @@ class RoomMemberHandler(BaseHandler):
if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
defer.returnValue(event)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
@ -187,6 +189,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=True,
content=None,
):
content_specified = bool(content)
if content is None:
content = {}
@ -229,6 +232,13 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE
)
if old_state:
same_content = content == old_state.content
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
@ -247,8 +257,9 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
@ -290,7 +301,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({})
yield self._local_membership_update(
res = yield self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@ -300,6 +311,7 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids=latest_event_ids,
content=content,
)
defer.returnValue(res)
@defer.inlineCallbacks
def send_membership_event(

View file

@ -16,7 +16,7 @@
from synapse.api.constants import Membership, EventTypes
from synapse.util.async import concurrently_execute
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
from synapse.util.metrics import Measure, measure_func
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client
@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
])):
__slots__ = []
@ -129,7 +130,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.invited or
self.archived or
self.account_data or
self.to_device
self.to_device or
self.device_lists
)
@ -544,6 +546,10 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder
)
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@ -551,9 +557,33 @@ class SyncHandler(object):
invited=sync_result_builder.invited,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
next_batch=sync_result_builder.now_token,
))
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
else:
defer.returnValue([])
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates

View file

@ -25,7 +25,7 @@ from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor, ssl, protocol, task
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError,
@ -386,26 +386,23 @@ class SpiderEndpointFactory(object):
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == "http":
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=TCP4ClientEndpoint,
endpoint_kw_args={
'timeout': 15
},
)
endpoint_factory = HostnameEndpoint
elif uri.scheme == "https":
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=SSL4ClientEndpoint,
endpoint_kw_args={
'sslContextFactory': tlsPolicy,
'timeout': 15
},
)
tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
def endpoint_factory(reactor, host, port, **kw):
return wrapClientTLS(
tlsCreator,
HostnameEndpoint(reactor, host, port, **kw))
else:
logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
return None
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
)
class SpiderHttpClient(SimpleHttpClient):

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
from twisted.names import client, dns
@ -58,11 +58,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
endpoint_kw_args.update(timeout=timeout)
if ssl_context_factory is None:
transport_endpoint = TCP4ClientEndpoint
transport_endpoint = HostnameEndpoint
default_port = 8008
else:
transport_endpoint = SSL4ClientEndpoint
endpoint_kw_args.update(sslContextFactory=ssl_context_factory)
def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS(
ssl_context_factory,
HostnameEndpoint(reactor, host, port, timeout=timeout))
default_port = 8448
if port is None:
@ -142,7 +144,7 @@ class SpiderEndpoint(object):
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, host, port, blacklist, whitelist,
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
endpoint=HostnameEndpoint, endpoint_kw_args={}):
self.reactor = reactor
self.host = host
self.port = port
@ -180,7 +182,7 @@ class SRVClientEndpoint(object):
"""
def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=TCP4ClientEndpoint,
default_port=None, endpoint=HostnameEndpoint,
endpoint_kw_args={}):
self.reactor = reactor
self.service_name = "_%s._%s.%s" % (service, protocol, domain)

View file

@ -378,6 +378,7 @@ class Notifier(object):
limit=limit,
is_guest=is_peeking,
room_ids=room_ids,
explicit_room_id=explicit_room_id,
)
if name == "room":

View file

@ -81,7 +81,7 @@ class Mailer(object):
def __init__(self, hs, app_name):
self.hs = hs
self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
@ -439,15 +439,23 @@ class Mailer(object):
})
def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s" % (room_id,)
if self.hs.config.email_riot_base_url:
base_url = self.hs.config.email_riot_base_url
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
else:
return "https://matrix.to/#/%s" % (room_id,)
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
def make_notif_link(self, notif):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
notif['room_id'], notif['event_id']
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % (
notif['room_id'], notif['event_id']
)
@ -458,7 +466,7 @@ class Mailer(object):
def make_unsubscribe_link(self, user_id, app_id, email_address):
params = {
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id,
"pushkey": email_address,
}

View file

@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
def get_context_for_event(store, state_handler, ev, user_id):
ctx = {}
room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or

View file

@ -24,7 +24,7 @@ REQUIREMENTS = {
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],

View file

@ -46,6 +46,7 @@ STREAM_NAMES = (
("to_device",),
("public_rooms",),
("federation",),
("device_lists",),
)
@ -140,6 +141,7 @@ class ReplicationResource(Resource):
caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token()
device_list_token = self.store.get_device_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@ -155,6 +157,7 @@ class ReplicationResource(Resource):
int(stream_token.to_device_key),
int(public_rooms_token),
int(federation_token),
int(device_list_token),
))
@request_handler()
@ -214,6 +217,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams)
yield self.device_lists(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams)
@ -295,9 +299,6 @@ class ReplicationResource(Resource):
"backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group"),
)
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",),
)
@defer.inlineCallbacks
def presence(self, writer, current_token, request_streams):
@ -495,6 +496,20 @@ class ReplicationResource(Resource):
"position", "type", "content",
), position=upto_token)
@defer.inlineCallbacks
def device_lists(self, writer, current_token, limit, request_streams):
current_position = current_token.device_lists
device_lists = request_streams.get("device_lists")
if device_lists is not None and device_lists != current_position:
changes = yield self.store.get_all_device_list_changes_for_remotes(
device_lists,
)
writer.write_header_and_rows("device_lists", changes, (
"position", "user_id", "destination",
), position=current_position)
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
@ -527,7 +542,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
"federation",
"federation", "device_lists",
))):
__slots__ = []

View file

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id",
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = DataStore.get_device_stream_token.__func__
get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
get_devices_by_remote = DataStore.get_devices_by_remote.__func__
_get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
_get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("device_lists")
if stream:
self._device_list_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
self._device_list_stream_cache.entity_has_changed(
user_id, stream_id
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id
)
return super(SlavedDeviceStore, self).process_replication(result)

View file

@ -73,12 +73,12 @@ class SlavedEventStore(BaseSlavedStore):
# to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_users_who_share_room_with_user = (
RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
)
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
@ -115,8 +115,6 @@ class SlavedEventStore(BaseSlavedStore):
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
@ -197,10 +195,6 @@ class SlavedEventStore(BaseSlavedStore):
return result
def process_replication(self, result):
state_resets = set(
r[0] for r in result.get("state_resets", {"rows": []})["rows"]
)
stream = result.get("events")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
@ -210,7 +204,7 @@ class SlavedEventStore(BaseSlavedStore):
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
row, backfilled=False,
)
stream = result.get("backfill")
@ -218,7 +212,7 @@ class SlavedEventStore(BaseSlavedStore):
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
row, backfilled=True,
)
stream = result.get("forward_ex_outliers")
@ -237,21 +231,15 @@ class SlavedEventStore(BaseSlavedStore):
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled, state_resets):
position = row[0]
def _process_replication_row(self, row, backfilled):
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self.invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
event, backfilled,
)
def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
def invalidate_caches_for_event(self, event, backfilled):
self._invalidate_get_event_cache(event.event_id)
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
@ -273,8 +261,6 @@ class SlavedEventStore(BaseSlavedStore):
self._invalidate_get_event_cache(event.redacts)
if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering
)
@ -289,7 +275,3 @@ class SlavedEventStore(BaseSlavedStore):
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))

View file

@ -86,7 +86,11 @@ class HttpTransactionCache(object):
pass # execute the function instead.
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
# We don't add an errback to the raw deferred, so we ask ObservableDeferred
# to swallow the error. This is fine as the error will still be reported
# to the observers.
observable = ObservableDeferred(deferred, consumeErrors=True)
self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe()

View file

@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if 'medium' in login_submission and 'address' in login_submission:
address = login_submission['address']
if login_submission['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], login_submission['address']
login_submission['medium'], address
)
if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@ -324,6 +330,7 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks
def on_GET(self, request):
@ -362,7 +369,9 @@ class CasTicketServlet(ClientV1RestServlet):
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)

View file

@ -152,23 +152,29 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if state_key is not None:
event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler
event, context = yield msg_handler.create_event(
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
)
if event_type == EventTypes.Member:
yield self.handlers.room_member_handler.send_membership_event(
membership = content.get("membership", None)
event = yield self.handlers.room_member_handler.update_membership(
requester,
event,
context,
target=UserID.from_string(state_key),
room_id=room_id,
action=membership,
content=content,
)
else:
msg_handler = self.handlers.message_handler
event, context = yield msg_handler.create_event(
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
)
yield msg_handler.send_nonmember_event(requester, event, context)
defer.returnValue((200, {"event_id": event.event_id}))
ret = {}
if event:
ret = {"event_id": event.event_id}
defer.returnValue((200, ret))
# TODO: Needs unit testing for generic events + feedback

View file

@ -32,19 +32,27 @@ class VoipRestServlet(ClientV1RestServlet):
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
turnUsername = self.hs.config.turn_username
turnPassword = self.hs.config.turn_password
userLifetime = self.hs.config.turn_user_lifetime
if not turnUris or not turnSecret or not userLifetime:
if turnUris and turnSecret and userLifetime:
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
password = turnPassword
else:
defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
defer.returnValue((200, {
'username': username,
'password': password,

View file

@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
threepid['medium'], threepid['address']
@ -241,7 +246,7 @@ class ThreepidRestServlet(RestServlet):
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
logger.warn("Couldn't add 3pid: invalid response from ID server")
raise SynapseError(500, "Invalid response from ID Server")
yield self.auth_handler.add_threepid(
@ -263,9 +268,43 @@ class ThreepidRestServlet(RestServlet):
defer.returnValue((200, {}))
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_object_from_request(request)
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
yield self.auth_handler.delete_threepid(
user_id, body['medium'], body['address']
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)

View file

@ -21,6 +21,8 @@ from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer
)
from synapse.http.servlet import parse_string
from synapse.types import StreamToken
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@ -149,6 +151,52 @@ class KeyQueryServlet(RestServlet):
defer.returnValue((200, result))
class KeyChangesServlet(RestServlet):
"""Returns the list of changes of keys between two stream tokens (may return
spurious extra results, since we currently ignore the `to` param).
GET /keys/changes?from=...&to=...
200 OK
{ "changed": ["@foo:example.com"] }
"""
PATTERNS = client_v2_patterns(
"/keys/changes$",
releases=()
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(KeyChangesServlet, self).__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from")
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
parse_string(request, "to")
from_token = StreamToken.from_string(from_token_string)
user_id = requester.user.to_string()
changed = yield self.device_handler.get_user_ids_changed(
user_id, from_token,
)
defer.returnValue((200, {
"changed": list(changed),
}))
class OneTimeKeyServlet(RestServlet):
"""
POST /keys/claim HTTP/1.1
@ -192,4 +240,5 @@ class OneTimeKeyServlet(RestServlet):
def register_servlets(hs, http_server):
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)

View file

@ -96,6 +96,7 @@ class RegisterRestServlet(RestServlet):
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks
def on_POST(self, request):
@ -436,7 +437,7 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name
)
access_token = self.auth_handler.generate_access_token(
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
defer.returnValue((200, {

View file

@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet):
)
archived = self.encode_archived(
sync_result.archived, time_now, requester.access_token_id, filter.event_fields
sync_result.archived, time_now, requester.access_token_id,
filter.event_fields,
)
response_content = {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
"changed": list(sync_result.device_lists),
},
"presence": self.encode_presence(
sync_result.presence, time_now
),

View file

@ -61,7 +61,7 @@ class MediaRepository(object):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer()
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
@ -98,6 +98,8 @@ class MediaRepository(object):
with open(fname, "wb") as f:
f.write(content)
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
@ -190,6 +192,8 @@ class MediaRepository(object):
else:
upload_name = None
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,

View file

@ -16,6 +16,10 @@
import PIL.Image as Image
from io import BytesIO
import logging
logger = logging.getLogger(__name__)
class Thumbnailer(object):
@ -86,4 +90,5 @@ class Thumbnailer(object):
output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file:
output_file.write(output_bytes)
logger.info("Stored thumbnail in file %r", output_path)
return len(output_bytes)

View file

@ -97,6 +97,8 @@ class UploadResource(Resource):
content_length, requester.user
)
logger.info("Uploaded content with URI %r", content_uri)
respond_with_json(
request, 200, {"content_uri": content_uri}, send_cors=True
)

View file

@ -37,7 +37,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
@ -131,6 +131,7 @@ class HomeServer(object):
'federation_transport_client',
'federation_sender',
'receipts_handler',
'macaroon_generator',
]
def __init__(self, hostname, **kwargs):
@ -213,6 +214,9 @@ class HomeServer(object):
def build_auth_handler(self):
return AuthHandler(self)
def build_macaroon_generator(self):
return MacaroonGeneartor(self)
def build_device_handler(self):
return DeviceHandler(self)

View file

@ -16,12 +16,12 @@
from twisted.internet import defer
from synapse import event_auth
from synapse.util.logutils import log_function
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
def _gen_state_id():
global _NEXT_STATE_ID
@ -77,6 +79,9 @@ class _StateCacheEntry(object):
else:
self.state_id = _gen_state_id()
def __len__(self):
return len(self.state)
class StateHandler(object):
""" Responsible for doing state conflict resolution.
@ -89,7 +94,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer()
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self):
logger.debug("start_caching")
@ -99,6 +104,7 @@ class StateHandler(object):
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
@ -123,7 +129,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_state")
logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
@ -148,7 +154,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_state_ids")
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
@ -162,7 +168,7 @@ class StateHandler(object):
def get_current_user_in_room(self, room_id, latest_event_ids=None):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.info("calling resolve_state_groups from get_current_user_in_room")
logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state
@ -226,7 +232,7 @@ class StateHandler(object):
context.prev_state_events = []
defer.returnValue(context)
logger.info("calling resolve_state_groups from compute_event_context")
logger.debug("calling resolve_state_groups from compute_event_context")
if event.is_state():
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
@ -327,20 +333,13 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
@ -388,152 +387,267 @@ class StateHandler(object):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
if event.is_state():
return self._resolve_events(
state_sets, event.type, event.state_key
)
else:
return self._resolve_events(state_sets)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
(new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
"""
with Measure(self.clock, "state._resolve_events"):
state = {}
for st in state_sets:
for e in st:
state.setdefault(
(e.type, e.state_key),
{}
)[e.event_id] = e
new_state = resolve_events(state_set_ids, state_map)
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1
}
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
conflicted_state = {
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
return new_state
if event_type:
prev_states_events = conflicted_state.get(
(event_type, state_key), []
)
prev_states = [s.event_id for s in prev_states_events]
else:
prev_states = []
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes
}
def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
try:
resolved_state = self._resolve_state_events(
conflicted_state, auth_events
)
except:
logger.exception("Failed to resolve state")
raise
return sorted(events, key=key_func)
new_state = unconflicted_state
new_state.update(resolved_state)
return new_state, prev_states
def resolve_events(state_sets, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
@log_function
def _resolve_state_events(self, conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
Returns
dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event.
"""
if len(state_sets) == 1:
return state_sets[0]
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state:
events = conflicted_state[power_key]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = self._resolve_auth_events(
events, auth_events)
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
auth_events.update(resolved_state)
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
state_map = state_map_factory
auth_events.update(resolved_state)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = self._resolve_normal_events(
events, auth_events
)
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
return resolved_state
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
def _resolve_auth_events(self, events, auth_events):
reverse = [i for i in reversed(self._ordered_events(events))]
return unconflicted_state, conflicted_state
auth_events = dict(auth_events)
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
prev_event = event
except AuthError:
return prev_event
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
state_map_factory):
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
return event
logger.info("Asking for %d conflicted events", len(needed_events))
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
return event
except AuthError:
pass
state_map = yield state_map_factory(needed_events)
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
return sorted(events, key=key_func)
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
key: state_map[ev_id]
for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
try:
resolved_state = _resolve_state_events(
conflicted_state, auth_events
)
except:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state
def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event

View file

@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id",
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@ -189,7 +192,8 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id
max_value=max_device_inbox_id,
limit=1000,
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache", min_device_inbox_id,
@ -202,12 +206,21 @@ class DataStore(RoomMemberStore, RoomStore,
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",

View file

@ -169,7 +169,7 @@ class SQLBaseStore(object):
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
self._event_fetch_lock = threading.Condition()
@ -387,6 +387,10 @@ class SQLBaseStore(object):
Args:
table : string giving the table name
values : dict of new column names and values for them
Returns:
bool: Whether the row was inserted or not. Only useful when
`or_ignore` is True
"""
try:
yield self.runInteraction(
@ -398,6 +402,8 @@ class SQLBaseStore(object):
# a cursor after we receive an error from the db.
if not or_ignore:
raise
defer.returnValue(False)
defer.returnValue(True)
@staticmethod
def _simple_insert_txn(txn, table, values):
@ -838,18 +844,19 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value):
max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" WHERE %(stream)s > ? - %(limit)s"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
"limit": limit,
}
sql = self.database_engine.convert_param_style(sql)

View file

@ -18,13 +18,29 @@ import ujson
from twisted.internet import defer
from ._base import SQLBaseStore
from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore):
class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, hs):
super(DeviceInboxStore, self).__init__(hs)
self.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
self.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID,
self._background_drop_index_device_inbox,
)
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore):
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS device_inbox_stream_id"
)
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
defer.returnValue(1)

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import ujson as json
from twisted.internet import defer
@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
def __init__(self, hs):
super(DeviceStore, self).__init__(hs)
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name,
ignore_if_known=True):
initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
Args:
user_id (str): id of user associated with the device
device_id (str): id of device
initial_device_display_name (str): initial displayname of the
device
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
device. Ignored if device exists.
Returns:
defer.Deferred
Raises:
StoreError: if ignore_if_known is False and the device was already
known
defer.Deferred: boolean whether the device was inserted or an
existing device existed with that ID.
"""
try:
yield self._simple_insert(
inserted = yield self._simple_insert(
"devices",
values={
"user_id": user_id,
@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
"display_name": initial_device_display_name
},
desc="store_device",
or_ignore=ignore_if_known,
or_ignore=True,
)
defer.returnValue(inserted)
except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s",
@ -139,3 +143,451 @@ class DeviceStore(SQLBaseStore):
)
defer.returnValue({d["device_id"]: d for d in devices})
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
return self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
desc="mark_remote_user_device_list_as_unsubscribed",
)
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
"""Updates a single user's device in the cache.
"""
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id,
)
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def update_remote_device_list_cache(self, user_id, devices, stream_id):
"""Replace the cache of the remote user's devices.
"""
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id, devices, stream_id,
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
stream_id):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
)
self._simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
"content": json.dumps(content),
}
for content in devices
]
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(now_stream_id, [ { updates }, .. ])
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows}
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in devices.iteritems():
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return (now_stream_id, results)
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
return self.runInteraction(
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
query_list,
)
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
content = self._simple_select_one_onecol_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
)
results.setdefault(user_id, {})[device_id] = json.loads(content)
else:
devices = self._simple_select_list_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
)
results[user_id] = {
device["device_id"]: json.loads(device["content"])
for device in devices
}
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.iteritems():
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# First we DELETE all rows such that only the latest row for each
# (destination, user_id is left. We do this by selecting first and
# deleting.
sql = """
SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
GROUP BY user_id
HAVING count(*) > 1
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows)
)
# Mark everything that is left as sent
sql = """
UPDATE device_lists_outbound_pokes SET sent = ?
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (True, destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE stream_id > ?
"""
return self._execute(
"get_users_and_hosts_device_list", None,
sql, from_key,
)
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host, stream_id,
)
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
for device_id in device_ids
]
)
self._simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
{
"destination": destination,
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
"sent": False,
"ts": now,
}
for destination in hosts
for device_id in device_ids
]
)
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
(destination, user_id) tuple to ensure that the prev_ids remain correct
if the server does come back.
"""
yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
def _prune_txn(txn):
select_sql = """
SELECT destination, user_id, max(stream_id) as stream_id
FROM device_lists_outbound_pokes
GROUP BY destination, user_id
HAVING min(ts) < ? AND count(*) > 1
"""
txn.execute(select_sql, (yesterday,))
rows = txn.fetchall()
if not rows:
return
delete_sql = """
DELETE FROM device_lists_outbound_pokes
WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
delete_sql,
(
(yesterday, row[0], row[1], row[2])
for row in rows
)
)
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction(
"_prune_old_outbound_device_pokes", _prune_txn
)

View file

@ -12,74 +12,111 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from twisted.internet import defer
import twisted.internet.defer
from canonicaljson import encode_canonical_json
import ujson as json
from ._base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
return self._simple_upsert(
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
new_key_json = encode_canonical_json(device_keys)
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def get_e2e_device_keys(self, query_list):
@defer.inlineCallbacks
def get_e2e_device_keys(self, query_list, include_all_devices=False):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
Returns:
Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name".
"""
if not query_list:
return {}
defer.returnValue({})
return self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices,
)
def _get_e2e_device_keys_txn(self, txn, query_list):
for user_id, device_keys in results.iteritems():
for device_id, device_info in device_keys.iteritems():
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
query_clauses = []
query_params = []
for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?"
query_clause = "user_id = ?"
query_params.append(user_id)
if device_id:
query_clause += " AND k.device_id = ?"
query_clause += " AND device_id = ?"
query_params.append(device_id)
query_clauses.append(query_clause)
sql = (
"SELECT k.user_id, k.device_id, "
"SELECT user_id, device_id, "
" d.display_name AS device_display_name, "
" k.key_json"
" FROM e2e_device_keys_json k"
" LEFT JOIN devices d ON d.user_id = k.user_id"
" AND d.device_id = k.device_id"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" WHERE %s"
) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses)
)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict)
result = {}
for row in rows:
result[row["user_id"]][row["device_id"]] = row
result.setdefault(row["user_id"], {})[row["device_id"]] = row
return result
@ -152,7 +189,7 @@ class EndToEndKeyStore(SQLBaseStore):
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
@twisted.internet.defer.inlineCallbacks
@defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete(
table="e2e_device_keys_json",

View file

@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
@cached()
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol(
table="event_forward_extremities",
@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore):
],
)
self._update_extremeties(txn, events)
self._update_backward_extremeties(txn, events)
def _update_extremeties(self, txn, events):
"""Updates the event_*_extremities tables based on the new/updated
def _update_backward_extremeties(self, txn, events):
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are are now being persisted as non-outliers.
are now being persisted as non-outliers.
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
for room_id, room_events in events_by_room.items():
prevs = [
e_id for ev in room_events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
]
if prevs:
txn.execute(
"DELETE FROM event_forward_extremities"
" WHERE room_id = ?"
" AND event_id in (%s)" % (
",".join(["?"] * len(prevs)),
),
[room_id] + prevs,
)
query = (
"INSERT INTO event_forward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_edges WHERE prev_event_id = ?"
" )"
)
txn.executemany(
query,
[
(ev.event_id, ev.room_id, ev.event_id) for ev in events
if not ev.internal_metadata.is_outlier()
]
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
max_stream_ord = max(
ev.internal_metadata.stream_ordering for ev in events
)
new_extrem = {}
for room_id in events_by_room:
event_ids = self._simple_select_onecol_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
new_extrem[room_id] = event_ids
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_ord,
}
for room_id, extrem_evs in new_extrem.items()
for event_id in extrem_evs
]
)
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore):
]
)
for room_id in events_by_room:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, _RollbackButIsFineException
from ._base import SQLBaseStore
from twisted.internet import defer, reactor
@ -27,6 +27,8 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.state import resolve_events
from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict
@ -71,22 +73,19 @@ class _EventPeristenceQueue(object):
"""
_EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
"events_and_contexts", "current_state", "backfilled", "deferred",
"events_and_contexts", "backfilled", "deferred",
))
def __init__(self):
self._event_persist_queues = {}
self._currently_persisting_rooms = set()
def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
end_item = queue[-1]
if end_item.current_state or current_state:
# We perist events with current_state set to True one at a time
pass
if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
@ -96,7 +95,6 @@ class _EventPeristenceQueue(object):
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
backfilled=backfilled,
current_state=current_state,
deferred=deferred,
))
@ -216,7 +214,6 @@ class EventsStore(SQLBaseStore):
d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs,
backfilled=backfilled,
current_state=None,
)
deferreds.append(d)
@ -229,11 +226,10 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, current_state=None, backfilled=False):
def persist_event(self, event, context, backfilled=False):
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)],
backfilled=backfilled,
current_state=current_state,
)
self._maybe_start_persisting(event.room_id)
@ -246,21 +242,10 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
if item.current_state:
for event, context in item.events_and_contexts:
# There should only ever be one item in
# events_and_contexts when current_state is
# not None
yield self._persist_event(
event, context,
current_state=item.current_state,
backfilled=item.backfilled,
)
else:
yield self._persist_events(
item.events_and_contexts,
backfilled=item.backfilled,
)
yield self._persist_events(
item.events_and_contexts,
backfilled=item.backfilled,
)
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@ -294,35 +279,183 @@ class EventsStore(SQLBaseStore):
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
# NB: Assumes that we are only persisting events for one room
# at a time.
new_forward_extremeties = {}
current_state_for_room = {}
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
events_by_room = {}
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
)
for room_id, ev_ctx_rm in events_by_room.items():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = yield self._calculate_new_extremeties(
room_id, [ev for ev, _ in ev_ctx_rm]
)
if new_latest_event_ids == set(latest_event_ids):
# No change in extremities, so no change in state
continue
new_forward_extremeties[room_id] = new_latest_event_ids
state = yield self._calculate_state_delta(
room_id, ev_ctx_rm, new_latest_event_ids
)
if state:
current_state_for_room[room_id] = state
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
current_state_for_room=current_state_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks
@log_function
def _persist_event(self, event, context, current_state=None, backfilled=False,
delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc()
except _RollbackButIsFineException:
pass
def _calculate_new_extremeties(self, room_id, events):
"""Calculates the new forward extremeties for a room given events to
persist.
Assumes that we are only persisting events for one room at a time.
"""
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = set(latest_event_ids)
# First, add all the new events to the list
new_latest_event_ids.update(
event.event_id for event in events
if not event.internal_metadata.is_outlier()
)
# Now remove all events that are referenced by the to-be-added events
new_latest_event_ids.difference_update(
e_id
for event in events
for e_id, _ in event.prev_events
if not event.internal_metadata.is_outlier()
)
# And finally remove any events that are referenced by previously added
# events.
rows = yield self._simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=list(new_latest_event_ids),
retcols=["prev_event_id"],
keyvalues={
"room_id": room_id,
"is_state": False,
},
desc="_calculate_new_extremeties",
)
new_latest_event_ids.difference_update(
row["prev_event_id"] for row in rows
)
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
"""Calculate the new state deltas for a room.
Assumes that we are only persisting events for one room at a time.
Returns:
2-tuple (to_delete, to_insert) where both are state dicts, i.e.
(type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries
to insert.
May return None if there are no changes to be applied.
"""
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
# First search in the list of new events we're adding,
# and then use the current state from that
for ev, ctx in events_context:
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
was_updated = True
missing_event_ids.append(event_id)
if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.values())
if not new_latest_event_ids:
current_state = {}
elif was_updated:
current_state = yield resolve_events(
state_sets,
state_map_factory=lambda ev_ids: self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
else:
return
existing_state_rows = yield self._simple_select_list(
table="current_state_events",
keyvalues={"room_id": room_id},
retcols=["event_id", "type", "state_key"],
desc="_calculate_state_delta",
)
existing_events = set(row["event_id"] for row in existing_state_rows)
new_events = set(ev_id for ev_id in current_state.itervalues())
changed_events = existing_events ^ new_events
if not changed_events:
return
to_delete = {
(row["type"], row["state_key"]): row["event_id"]
for row in existing_state_rows
if row["event_id"] in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
key: ev_id for key, ev_id in current_state.iteritems()
if ev_id in events_to_insert
}
defer.returnValue((to_delete, to_insert))
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
@ -380,53 +513,10 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = event.internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn(
txn,
table="current_state_events",
keyvalues={"room_id": event.room_id},
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
}
)
return self._persist_events_txn(
txn,
[(event, context)],
backfilled=backfilled,
delete_existing=delete_existing,
)
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False):
delete_existing=False, current_state_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table,
@ -436,6 +526,93 @@ class EventsStore(SQLBaseStore):
If delete_existing is True then existing events will be purged from the
database before insertion. This is useful when retrying due to IntegrityError.
"""
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
for room_id, current_state_tuple in current_state_for_room.iteritems():
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()],
)
self._simple_insert_many_txn(
txn,
table="current_state_events",
values=[
{
"event_id": ev_id,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
}
for key, ev_id in to_insert.iteritems()
],
)
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
state_key for ev_type, state_key in to_delete.iterkeys()
if ev_type == EventTypes.Member
)
members_changed.update(
state_key for ev_type, state_key in to_insert.iterkeys()
if ev_type == EventTypes.Member
)
for member in members_changed:
self._invalidate_cache_and_stream(
txn, self.get_rooms_for_user, (member,)
)
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
for room_id, new_extrem in new_forward_extremeties.items():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
)
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
self._simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
{
"event_id": ev_id,
"room_id": room_id,
}
for room_id, new_extrem in new_forward_extremeties.items()
for ev_id in new_extrem
],
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_order,
}
for room_id, new_extrem in new_forward_extremeties.items()
for event_id in new_extrem
]
)
# Ensure that we don't have the same event twice.
# Pick the earliest non-outlier if there is one, else the earliest one.
new_events_and_contexts = OrderedDict()
@ -550,7 +727,7 @@ class EventsStore(SQLBaseStore):
# Update the event_backward_extremities table now that this
# event isn't an outlier any more.
self._update_extremeties(txn, [event])
self._update_backward_extremeties(txn, [event])
events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove
@ -798,29 +975,6 @@ class EventsStore(SQLBaseStore):
# to update the current state table
return
for event, _ in state_events_and_contexts:
if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return
def _add_to_cache(self, txn, events_and_contexts):
@ -1084,10 +1238,10 @@ class EventsStore(SQLBaseStore):
self._do_fetch
)
logger.info("Loading %d events", len(events))
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
rows = yield events_d
logger.info("Loaded %d events (%d rows)", len(events), len(rows))
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
@ -1418,6 +1572,7 @@ class EventsStore(SQLBaseStore):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
@cached(num_args=5, max_entries=10)
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
"""Get all the new events that have arrived at the server either as
@ -1449,15 +1604,6 @@ class EventsStore(SQLBaseStore):
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering FROM current_state_resets"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_forward_id, upper_bound))
state_resets = txn.fetchall()
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
@ -1469,7 +1615,6 @@ class EventsStore(SQLBaseStore):
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
state_resets = []
forward_ex_outliers = []
sql = (
@ -1509,7 +1654,6 @@ class EventsStore(SQLBaseStore):
return AllNewEventsResult(
new_forward_events, new_backfill_events,
forward_ex_outliers, backward_ex_outliers,
state_resets,
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
@ -1735,5 +1879,4 @@ class EventsStore(SQLBaseStore):
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
"forward_ex_outliers", "backward_ex_outliers",
"state_resets"
])

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 39
SCHEMA_VERSION = 40
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -413,6 +413,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="user_delete_threepids",
)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
},
desc="user_delete_threepids",
)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""

View file

@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore):
)
for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cached(max_entries=5000)
@cached(max_entries=500000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
@ -220,7 +218,7 @@ class RoomMemberStore(SQLBaseStore):
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
" WHERE c.type = 'm.room.member' AND %s"
) % (where_clause,)
txn.execute(sql, args)
@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore):
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE %(where)s"
" WHERE c.type = 'm.room.member' AND %(where)s"
) % {
"where": where_clause,
}
@ -276,12 +274,29 @@ class RoomMemberStore(SQLBaseStore):
return rows
@cached(max_entries=5000)
@cached(max_entries=500000, iterable=True)
def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id`
"""
rooms = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room = set()
for room in rooms:
user_ids = yield self.get_users_in_room(
room.room_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room.update(user_ids)
defer.returnValue(user_who_share_room)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@ -390,7 +405,8 @@ class RoomMemberStore(SQLBaseStore):
room_id, state_group, state_ids,
)
@cachedInlineCallbacks(num_args=2, cache_context=True)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
max_entries=100000)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context, event=None):
# We don't use `state_group`, it's there so that we can cache based

View file

@ -0,0 +1,17 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('current_state_members_idx', '{}');

View file

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- turn the pre-fill startup query into a index-only scan on postgresql.
INSERT into background_updates (update_name, progress_json)
VALUES ('device_inbox_stream_index', '{}');
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');

View file

@ -0,0 +1,59 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Cache of remote devices.
CREATE TABLE device_lists_remote_cache (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
-- The last update we got for a user. Empty if we're not receiving updates for
-- that user.
CREATE TABLE device_lists_remote_extremeties (
user_id TEXT NOT NULL,
stream_id TEXT NOT NULL
);
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
-- Stream of device lists updates. Includes both local and remotes
CREATE TABLE device_lists_stream (
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL
);
CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
-- The stream of updates to send to other servers. We keep at least one row
-- per user that was sent so that the prev_id for any new updates can be
-- calculated
CREATE TABLE device_lists_outbound_pokes (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
sent BOOLEAN NOT NULL,
ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers
);
CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);

View file

@ -49,6 +49,7 @@ class StateStore(SQLBaseStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, hs):
super(StateStore, self).__init__(hs)
@ -60,6 +61,13 @@ class StateStore(SQLBaseStore):
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
@ -232,59 +240,7 @@ class StateStore(SQLBaseStore):
return count
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? "
)
if event_type and state_key is not None:
sql += " AND type = ? AND state_key = ? "
args = (room_id, event_type, state_key)
elif event_type:
sql += " AND type = ?"
args = (room_id, event_type)
else:
args = (room_id, )
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
def _get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)
args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, max_entries=1000)
@cached(num_args=2, max_entries=100000, iterable=True)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()

View file

@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results)
def get_rooms_that_changed(self, room_ids, from_key):
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids (list)
from_key (str): The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
return set(
room_id for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
)
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'):

View file

@ -44,6 +44,7 @@ class EventSources(object):
def get_current_token(self):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken(
room_key=(
@ -63,6 +64,7 @@ class EventSources(object):
),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
)
defer.returnValue(token)
@ -70,6 +72,7 @@ class EventSources(object):
def get_current_token_for_room(self, room_id):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken(
room_key=(
@ -89,5 +92,6 @@ class EventSources(object):
),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
)
defer.returnValue(token)

View file

@ -158,6 +158,7 @@ class StreamToken(
"account_data_key",
"push_rules_key",
"to_device_key",
"device_list_key",
))
):
_SEPARATOR = "_"
@ -195,6 +196,7 @@ class StreamToken(
or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
)
def copy_and_advance(self, key, new_value):

View file

@ -192,8 +192,11 @@ class Linearizer(object):
logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key
)
with PreserveLoggingContext():
yield current_defer
try:
with PreserveLoggingContext():
yield current_defer
except:
logger.exception("Unexpected exception in Linearizer")
logger.info("Acquired linearizer lock %r for key %r", self.name, key)

View file

@ -40,8 +40,8 @@ def register_cache(name, cache):
)
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache
_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
_stirng_cache_metrics = register_cache("string_cache", _string_cache)
KNOWN_KEYS = {
@ -69,7 +69,12 @@ KNOWN_KEYS = {
def intern_string(string):
"""Takes a (potentially) unicode string and interns using custom cache
"""
return _string_cache.setdefault(string, string)
new_str = _string_cache.setdefault(string, string)
if new_str is string:
_stirng_cache_metrics.inc_hits()
else:
_stirng_cache_metrics.inc_misses()
return new_str
def intern_dict(dictionary):

View file

@ -17,7 +17,7 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
@ -42,6 +42,25 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
]
def __init__(self, deferred, sequence, callbacks):
self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache(object):
__slots__ = (
"cache",
@ -51,12 +70,16 @@ class Cache(object):
"sequence",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(self, name, max_entries=1000, keylen=1, tree=False):
def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d.result)) if iterable else None,
)
self.name = name
@ -76,7 +99,15 @@ class Cache(object):
)
def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback)
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks)
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@ -88,15 +119,39 @@ class Cache(object):
else:
return default
def update(self, sequence, key, value, callback=None):
def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value, callback=callback)
entry = CacheEntry(
deferred=value,
sequence=self.sequence,
callbacks=callbacks,
)
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def shuffle(result):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
self.cache.set(key, entry.deferred, entry.callbacks)
else:
entry.invalidate()
else:
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback)
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
@ -108,6 +163,10 @@ class Cache(object):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
entry = self._pending_deferred_cache.pop(key, None)
if entry:
entry.invalidate()
self.cache.pop(key, None)
def invalidate_many(self, key):
@ -119,6 +178,11 @@ class Cache(object):
self.sequence += 1
self.cache.del_multi(key)
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.sequence += 1
@ -155,7 +219,7 @@ class CacheDescriptor(object):
"""
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False, cache_context=False):
inlineCallbacks=False, cache_context=False, iterable=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@ -169,6 +233,8 @@ class CacheDescriptor(object):
self.num_args = num_args
self.tree = tree
self.iterable = iterable
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
@ -203,6 +269,7 @@ class CacheDescriptor(object):
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
)
@functools.wraps(self.orig)
@ -243,11 +310,6 @@ class CacheDescriptor(object):
return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call,
@ -261,7 +323,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
cache.set(cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@ -359,7 +421,6 @@ class CacheListDescriptor(object):
missing.append(arg)
if missing:
sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@ -382,8 +443,8 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
cache.update(
sequence, tuple(key), observer,
cache.set(
tuple(key), observer,
callback=invalidate_callback
)
@ -417,21 +478,29 @@ class CacheListDescriptor(object):
class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
# We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
# which namedtuple does for us (i.e. two _CacheContext are the same if
# their caches and keys match). This is important in particular to
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
def invalidate(self):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
tree=tree,
cache_context=cache_context,
iterable=iterable,
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@ -439,6 +508,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
tree=tree,
inlineCallbacks=True,
cache_context=cache_context,
iterable=iterable,
)

View file

@ -23,7 +23,9 @@ import logging
logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
def __len__(self):
return len(self.value)
class DictionaryCache(object):
@ -32,7 +34,7 @@ class DictionaryCache(object):
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries)
self.cache = LruCache(max_size=max_entries, size_callback=len)
self.name = name
self.sequence = 0

View file

@ -15,6 +15,7 @@
from synapse.util.caches import register_cache
from collections import OrderedDict
import logging
@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False):
reset_expiry_on_get=False, iterable=False):
"""
Args:
cache_name (str): Name of this cache, used for logging.
@ -36,6 +37,8 @@ class ExpiringCache(object):
evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for
an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
@ -47,9 +50,13 @@ class ExpiringCache(object):
self._reset_expiry_on_get = reset_expiry_on_get
self._cache = {}
self._cache = OrderedDict()
self.metrics = register_cache(cache_name, self._cache)
self.metrics = register_cache(cache_name, self)
self.iterable = iterable
self._size_estimate = 0
def start(self):
if not self._expiry_ms:
@ -65,15 +72,14 @@ class ExpiringCache(object):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
# Evict if there are now too many items
if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted(
self._cache.items(),
key=lambda item: item[1].time,
)
if self.iterable:
self._size_estimate += len(value)
for k, _ in sorted_entries[self._max_len:]:
self._cache.pop(k)
# Evict if there are now too many items
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self._size_estimate -= len(value.value)
def __getitem__(self, key):
try:
@ -99,7 +105,7 @@ class ExpiringCache(object):
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
return
begin_length = len(self._cache)
begin_length = len(self)
now = self._clock.time_msec()
@ -110,15 +116,20 @@ class ExpiringCache(object):
keys_to_delete.add(key)
for k in keys_to_delete:
self._cache.pop(k)
value = self._cache.pop(k)
if self.iterable:
self._size_estimate -= len(value.value)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self._cache)
self._cache_name, begin_length, len(self)
)
def __len__(self):
return len(self._cache)
if self.iterable:
return self._size_estimate
else:
return len(self._cache)
class _CacheEntry(object):

View file

@ -49,7 +49,7 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@ -58,6 +58,12 @@ class LruCache(object):
lock = threading.Lock()
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
def synchronized(f):
@wraps(f)
def inner(*args, **kwargs):
@ -66,6 +72,16 @@ class LruCache(object):
return inner
cached_cache_len = [0]
if size_callback is not None:
def cache_len():
return cached_cache_len[0]
else:
def cache_len():
return len(cache)
self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
@ -74,6 +90,9 @@ class LruCache(object):
next_node.prev_node = node
cache[key] = node
if size_callback:
cached_cache_len[0] += size_callback(node.value)
def move_node_to_front(node):
prev_node = node.prev_node
next_node = node.next_node
@ -92,23 +111,25 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized
def cache_get(key, default=None, callback=None):
def cache_get(key, default=None, callbacks=[]):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if callback:
node.callbacks.add(callback)
node.callbacks.update(callbacks)
return node.value
else:
return default
@synchronized
def cache_set(key, value, callback=None):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
if value != node.value:
@ -116,21 +137,18 @@ class LruCache(object):
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
move_node_to_front(node)
node.value = value
else:
if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
add_node(key, value, set(callbacks))
evict()
@synchronized
def cache_set_default(key, value):
@ -139,10 +157,7 @@ class LruCache(object):
return node.value
else:
add_node(key, value)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
evict()
return value
@synchronized
@ -174,10 +189,8 @@ class LruCache(object):
for cb in node.callbacks:
cb()
cache.clear()
@synchronized
def cache_len():
return len(cache)
if size_callback:
cached_cache_len[0] = 0
@synchronized
def cache_contains(key):
@ -190,7 +203,7 @@ class LruCache(object):
self.pop = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = cache_len
self.len = synchronized(cache_len)
self.contains = cache_contains
self.clear = cache_clear

View file

@ -65,12 +65,27 @@ class TreeCache(object):
return popped
def values(self):
return [e.value for e in self.root.values()]
return list(iterate_tree_cache_entry(self.root))
def __len__(self):
return self.size
def iterate_tree_cache_entry(d):
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
if isinstance(d, dict):
for value_d in d.itervalues():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
if isinstance(d, _Entry):
yield d.value
else:
yield d
class _Entry(object):
__slots__ = ["value"]

View file

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from functools import wraps
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
def debug_deferreds():
"""Cause all deferreds to wait for a reactor tick before running their
callbacks. This increases the chance of getting a stack trace out of
a defer.inlineCallback since the code waiting on the deferred will get
a chance to add an errback before the deferred runs."""
# Helper method for retrieving and restoring the current logging context
# around a callback.
def with_logging_context(fn):
context = LoggingContext.current_context()
def restore_context_callback(x):
with PreserveLoggingContext(context):
return fn(x)
return restore_context_callback
# We are going to modify the __init__ method of defer.Deferred so we
# need to get a copy of the old method so we can still call it.
old__init__ = defer.Deferred.__init__
# We need to create a deferred to bounce the callbacks through the reactor
# but we don't want to add a callback when we create that deferred so we
# we create a new type of deferred that uses the old __init__ method.
# This is safe as long as the old __init__ method doesn't invoke an
# __init__ using super.
class Bouncer(defer.Deferred):
__init__ = old__init__
# We'll add this as a callback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the callbacks of the
# original deferred.
def bounce_callback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.callback), x)
return bouncer
# We'll add this as an errback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the errbacks of the
# original deferred.
def bounce_errback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.errback), x)
return bouncer
@wraps(old__init__)
def new__init__(self, *args, **kargs):
old__init__(self, *args, **kargs)
self.addCallbacks(bounce_callback, bounce_errback)
defer.Deferred.__init__ = new__init__

View file

@ -88,7 +88,7 @@ class RetryDestinationLimiter(object):
def __init__(self, destination, clock, store, retry_interval,
min_retry_interval=10 * 60 * 1000,
max_retry_interval=24 * 60 * 60 * 1000,
multiplier_retry_interval=5,):
multiplier_retry_interval=5, backoff_on_404=False):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
@ -107,6 +107,7 @@ class RetryDestinationLimiter(object):
a failed request, in milliseconds.
multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
"""
self.clock = clock
self.store = store
@ -116,6 +117,7 @@ class RetryDestinationLimiter(object):
self.min_retry_interval = min_retry_interval
self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
def __enter__(self):
pass
@ -123,7 +125,22 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException):
valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500
# Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS
# has been decommissioned.
# If we get a 401, then we should probably back off since they
# won't accept our requests for at least a while.
# 429 is us being aggresively rate limited, so lets rate limit
# ourselves.
if exc_val.code == 404 and self.backoff_on_404:
valid_err_code = False
elif exc_val.code in (401, 429):
valid_err_code = False
elif exc_val.code < 500:
valid_err_code = True
else:
valid_err_code = False
if exc_type is None or valid_err_code:
# We connected successfully.

View file

@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)

View file

@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)
@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
def test_minimal(self):
self.run_test(
{'type': 'A'},
{
'type': 'A',
'event_id': '$test:domain',
},
{
'type': 'A',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {'age_ts': 20},
@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'other_key': 'here'},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'},
},
{
'type': 'C',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain', 'other_field': 'here'},
},
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain'},
'signatures': {},
'unsigned': {},
@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
self.assertEquals(
self.serialize(
MockEvent(
type="foo",
event_id="test",
room_id="!foo:bar",
content={
"foo": "bar",
@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
[]
),
{
"type": "foo",
"event_id": "test",
"room_id": "!foo:bar",
"content": {
"foo": "bar",

View file

@ -34,11 +34,10 @@ class AuthTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret"
token = self.auth_handler.generate_access_token("some_user")
token = self.macaroon_generator.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks
@ -46,10 +45,9 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000
token = self.auth_handler.generate_access_token("a_user")
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat):
@ -74,7 +72,7 @@ class AuthTestCase(unittest.TestCase):
def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000
token = self.auth_handler.generate_short_term_login_token(
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000
)
@ -93,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
)
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.auth_handler.generate_short_term_login_token(
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)

View file

@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield utils.setup_test_homeserver(handlers=None)
self.handler = synapse.handlers.device.DeviceHandler(hs)
hs = yield utils.setup_test_homeserver()
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name"
)
self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
user_id="theresa",
user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display"
)
dev = yield self.handler.store.get_device("theresa", device_id)
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks

View file

@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.query_handlers = {}

View file

@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.query_handlers = {}

View file

@ -41,15 +41,12 @@ class RegistrationTestCase(unittest.TestCase):
handlers=None,
http_client=None,
expire_access_token=True)
self.auth_handler = Mock(
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret'))
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[
"generate_access_token",
])
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):

View file

@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_devices_by_remote",
]),
state_handler=self.state_handler,
handlers=None,
@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
defer.succeed(retry_timings_res)
)
self.datastore.get_devices_by_remote.return_value = (0, [])
def get_received_txn_response(*args):
return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response

View file

@ -58,53 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Join the room.
join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
# Leave the room.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Add some other user to the room.
join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
@defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
@ -122,51 +75,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
"get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
)
@defer.inlineCallbacks
def test_get_current_state(self):
# Create the room.
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
)
# Join the room.
join1 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join1]
)
# Add some other user to the room.
join2 = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[join2]
)
# Leave the room, then rejoin the room clobbering state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
join3 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[]
)
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join3]
)
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
@ -283,6 +191,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if depth is None:
depth = self.event_id
if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
room_id
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
event_dict = {
"sender": sender,
"type": type,
@ -309,12 +223,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = {
key: e.event_id for key, e in state.items()
}
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
elif not backfill:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
else:
state_ids = None
context = EventContext()
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context.push_actions = push_actions
ordering = None
@ -324,7 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
ordering, _ = yield self.master_store.persist_event(
event, context, current_state=reset_state
event, context,
)
if ordering:

View file

@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=403)
yield self.leave(room=room, user=usr, expect_code=403)
yield self.join(room=room, user=usr, expect_code=404)
yield self.leave(room=room, user=usr, expect_code=404)
@defer.inlineCallbacks
def test_membership_private_room_perms(self):
@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))

View file

@ -87,7 +87,10 @@ class RestTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger(
"PUT", path, json.dumps(data)
)
self.assertEquals(expect_code, code, msg=str(response))
self.assertEquals(
expect_code, code,
msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response)
)
self.auth_user_id = temp_id

View file

@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
@cached(max_entries=2)
@cached(max_entries=20) # HACK: This makes it 2 due to cache factor
def func(self, key):
callcount[0] += 1
return key
@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)

View file

@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.as_token = "token1"
self.as_url = "some_url"
@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.db_pool = hs.get_db_pool()
self.as_list = [
@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
ApplicationServiceStore(hs)
@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm:
@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm:

View file

@ -33,7 +33,11 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
json = {"key": "value"}
yield self.store.store_device(
"user", "device", None
)
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
@ -43,14 +47,14 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"keys": json,
"device_display_name": None,
}, dev)
@defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
json = {"key": "value"}
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
@ -63,7 +67,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"keys": json,
"device_display_name": "display_name",
}, dev)
@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self):
now = 1470174257070
yield self.store.store_device(
"user1", "device1", None
)
yield self.store.store_device(
"user1", "device2", None
)
yield self.store.store_device(
"user2", "device1", None
)
yield self.store.store_device(
"user2", "device2", None
)
yield self.store.set_e2e_device_keys(
"user1", "device1", now, 'json11')
yield self.store.set_e2e_device_keys(

Some files were not shown because too many files have changed in this diff Show more