0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 10:23:53 +01:00

Merge branch 'release-v0.17.2' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-09-08 15:26:26 +01:00
commit 5834c6178c
60 changed files with 1842 additions and 679 deletions

View file

@ -1,3 +1,40 @@
Changes in synapse v0.17.2 (2016-09-08)
=======================================
This release contains security bug fixes. Please upgrade.
No changes since v0.17.2
Changes in synapse v0.17.2-rc1 (2016-09-05)
===========================================
Features:
* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050,
#1062, #1066)
Changes:
* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063,
#1068)
* Don't notify for online to online presence transitions. (PR #1054)
* Occasionally persist unpersisted presence updates (PR #1055)
* Allow application services to have an optional 'url' (PR #1056)
* Clean up old sent transactions from DB (PR #1059)
Bug fixes:
* Fix None check in backfill (PR #1043)
* Fix membership changes to be idempotent (PR #1067)
* Fix bug in get_pdu where it would sometimes return events with incorrect
signature
Changes in synapse v0.17.1 (2016-08-24) Changes in synapse v0.17.1 (2016-08-24)
======================================= =======================================

View file

@ -134,6 +134,12 @@ Installing prerequisites on Raspbian::
sudo pip install --upgrade ndg-httpsclient sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv sudo pip install --upgrade virtualenv
Installing prerequisites on openSUSE::
sudo zypper in -t pattern devel_basis
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
python-devel libffi-devel libopenssl-devel libjpeg62-devel
To install the synapse homeserver run:: To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
@ -199,6 +205,21 @@ run (e.g. ``~/.synapse``), and::
source ./bin/activate source ./bin/activate
synctl start synctl start
Security Note
=============
Matrix serves raw user generated data in some APIs - specifically the content
repository endpoints: http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
Whilst we have tried to mitigate against possible XSS attacks (e.g.
https://github.com/matrix-org/synapse/pull/1021) we recommend running
matrix homeservers on a dedicated domain name, to limit any malicious user generated
content served to web browsers a matrix API from being able to attack webapps hosted
on the same domain. This is particularly true of sharing a matrix webclient and
server on the same domain.
See https://github.com/vector-im/vector-web/issues/1977 and
https://developer.github.com/changes/2014-04-25-user-content-security for more details.
Using PostgreSQL Using PostgreSQL
================ ================
@ -215,9 +236,6 @@ The advantages of Postgres include:
pointing at the same DB master, as well as enabling DB replication in pointing at the same DB master, as well as enabling DB replication in
synapse itself. synapse itself.
The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
For information on how to install and use PostgreSQL, please see For information on how to install and use PostgreSQL, please see
`docs/postgres.rst <docs/postgres.rst>`_. `docs/postgres.rst <docs/postgres.rst>`_.

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.17.1" __version__ = "0.17.2"

View file

@ -52,7 +52,7 @@ class Auth(object):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at # Docs for these currently lives at
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to # In addition, we have type == delete_pusher which grants access only to
# delete pushers. # delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
@ -63,6 +63,17 @@ class Auth(object):
"user_id = ", "user_id = ",
]) ])
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
self.check(event, auth_events=auth_events, do_sig_check=False)
def check(self, event, auth_events, do_sig_check=True): def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -267,21 +278,17 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
for event in curr_state.values(): entry = yield self.state.resolve_state_groups(
if event.type == EventTypes.Member: room_id, latest_event_ids
try: )
if get_domain_from_id(event.state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
continue
if event.content["membership"] == Membership.JOIN: ret = yield self.store.is_host_joined(
defer.returnValue(True) room_id, host, entry.state_group, entry.state
)
defer.returnValue(False) defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
@ -847,7 +854,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
auth_ids = self.compute_auth_events(builder, context.current_state) auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids
@ -855,30 +862,32 @@ class Auth(object):
builder.auth_events = auth_events_entries builder.auth_events = auth_events_entries
def compute_auth_events(self, event, current_state): @defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
return [] defer.returnValue([])
auth_ids = [] auth_ids = []
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = current_state.get(key) power_level_event_id = current_state_ids.get(key)
if power_level_event: if power_level_event_id:
auth_ids.append(power_level_event.event_id) auth_ids.append(power_level_event_id)
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = current_state.get(key) join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = current_state.get(key) member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create_event = current_state.get(key) create_event_id = current_state_ids.get(key)
if create_event: if create_event_id:
auth_ids.append(create_event.event_id) auth_ids.append(create_event_id)
if join_rule_event: if join_rule_event_id:
join_rule_event = yield self.store.get_event(join_rule_event_id)
join_rule = join_rule_event.content.get("join_rule") join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False is_public = join_rule == JoinRules.PUBLIC if join_rule else False
else: else:
@ -887,15 +896,21 @@ class Auth(object):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
e_type = event.content["membership"] e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]: if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event: if join_rule_event_id:
auth_ids.append(join_rule_event.event_id) auth_ids.append(join_rule_event_id)
if e_type == Membership.JOIN: if e_type == Membership.JOIN:
if member_event and not is_public: if member_event_id and not is_public:
auth_ids.append(member_event.event_id) auth_ids.append(member_event_id)
else: else:
if member_event: if member_event_id:
auth_ids.append(member_event.event_id) auth_ids.append(member_event_id)
if for_verification:
key = (EventTypes.Member, event.state_key, )
existing_event_id = current_state_ids.get(key)
if existing_event_id:
auth_ids.append(existing_event_id)
if e_type == Membership.INVITE: if e_type == Membership.INVITE:
if "third_party_invite" in event.content: if "third_party_invite" in event.content:
@ -903,14 +918,15 @@ class Auth(object):
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
) )
third_party_invite = current_state.get(key) third_party_invite_id = current_state_ids.get(key)
if third_party_invite: if third_party_invite_id:
auth_ids.append(third_party_invite.event_id) auth_ids.append(third_party_invite_id)
elif member_event: elif member_event_id:
member_event = yield self.store.get_event(member_event_id)
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)
return auth_ids defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events): def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )

View file

@ -85,3 +85,8 @@ class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat" PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat" PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"

View file

@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View file

@ -36,6 +36,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.filtering import SlavedFilteringStore from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
@ -72,6 +73,7 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
@ -397,6 +399,9 @@ class SynchrotronServer(HomeServer):
notify_from_stream( notify_from_stream(
result, "typing", "typing_key", room="room_id" result, "typing", "typing_key", room="room_id"
) )
notify_from_stream(
result, "to_device", "to_device_key", user="user_id"
)
while True: while True:
try: try:

View file

@ -88,6 +88,8 @@ class ApplicationService(object):
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
# .protocols is a publicly visible field
if protocols: if protocols:
self.protocols = set(protocols) self.protocols = set(protocols)
else: else:

View file

@ -14,10 +14,11 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind from synapse.util.caches.response_cache import ResponseCache
import logging import logging
import urllib import urllib
@ -25,6 +26,12 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
@ -56,8 +63,12 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs) super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user(self, service, user_id): def query_user(self, service, user_id):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/users/%s" % urllib.quote(user_id)) uri = service.url + ("/users/%s" % urllib.quote(user_id))
response = None response = None
try: try:
@ -77,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_alias(self, service, alias): def query_alias(self, service, alias):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/rooms/%s" % urllib.quote(alias)) uri = service.url + ("/rooms/%s" % urllib.quote(alias))
response = None response = None
try: try:
@ -97,16 +110,22 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid" required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias" required_field = "alias"
else: else:
raise ValueError( raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind "Unrecognised 'kind' argument %r to query_3pe()", kind
) )
if service.url is None:
defer.returnValue([])
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.quote(protocol)
)
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
if not isinstance(response, list): if not isinstance(response, list):
@ -131,8 +150,34 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex) logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([]) defer.returnValue([])
def get_3pe_protocol(self, service, protocol):
if service.url is None:
defer.returnValue({})
@defer.inlineCallbacks
def _get():
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.quote(protocol)
)
try:
defer.returnValue((yield self.get_json(uri, {})))
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
defer.returnValue({})
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
self.protocol_meta_cache.set(key, _get())
)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
if service.url is None:
defer.returnValue(True)
events = self._serialize(events) events = self._serialize(events)
if txn_id is None: if txn_id is None:

View file

@ -86,7 +86,7 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename): def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [ required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart" "id", "as_token", "hs_token", "sender_localpart"
] ]
for field in required_string_fields: for field in required_string_fields:
if not isinstance(as_info.get(field), basestring): if not isinstance(as_info.get(field), basestring):
@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename):
field, config_filename, field, config_filename,
)) ))
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
if (not isinstance(as_info.get("url"), basestring) and
as_info.get("url", "") is not None):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"] localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart: if urllib.quote(localpart) != localpart:
raise ValueError( raise ValueError(
@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename):
for p in protocols: for p in protocols:
if not isinstance(p, str): if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item") raise KeyError("Bad value for 'protocols' item")
if as_info["url"] is None:
logger.info(
"(%s) Explicitly empty 'url' provided. This application service"
" will not receive events or queries.",
config_filename,
)
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
url=as_info["url"], url=as_info["url"],

View file

@ -99,7 +99,7 @@ class EventBase(object):
return d return d
def get(self, key, default): def get(self, key, default=None):
return self._event_dict.get(key, default) return self._event_dict.get(key, default)
def get_internal_metadata_dict(self): def get_internal_metadata_dict(self):

View file

@ -15,9 +15,9 @@
class EventContext(object): class EventContext(object):
def __init__(self):
def __init__(self, current_state=None): self.current_state_ids = None
self.current_state = current_state self.prev_state_ids = None
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = [] self.push_actions = []

View file

@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
import synapse.metrics import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@ -63,6 +64,7 @@ class FederationClient(FederationBase):
self._clock.looping_call( self._clock.looping_call(
self._clear_tried_cache, 60 * 1000, self._clear_tried_cache, 60 * 1000,
) )
self.state = hs.get_state_handler()
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""
@ -267,7 +269,7 @@ class FederationClient(FederationBase):
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
pdu = None signed_pdu = None
for destination in destinations: for destination in destinations:
now = self._clock.time_msec() now = self._clock.time_msec()
last_attempt = pdu_attempts.get(destination, 0) last_attempt = pdu_attempts.get(destination, 0)
@ -297,7 +299,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hashes([pdu])[0] signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
break break
@ -320,10 +322,10 @@ class FederationClient(FederationBase):
) )
continue continue
if self._get_pdu_cache is not None and pdu: if self._get_pdu_cache is not None and signed_pdu:
self._get_pdu_cache[event_id] = pdu self._get_pdu_cache[event_id] = signed_pdu
defer.returnValue(pdu) defer.returnValue(signed_pdu)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -811,7 +813,8 @@ class FederationClient(FederationBase):
if len(signed_events) >= limit: if len(signed_events) >= limit:
defer.returnValue(signed_events) defer.returnValue(signed_events)
servers = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
servers = set(servers) servers = set(servers)
servers.discard(self.server_name) servers.discard(self.server_name)

View file

@ -223,16 +223,14 @@ class FederationServer(FederationBase):
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
pdus = yield self.handler.get_state_for_pdu( state_ids = yield self.handler.get_state_ids_for_pdu(
room_id, event_id, room_id, event_id,
) )
auth_chain = yield self.store.get_auth_chain( auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
[pdu.event_id for pdu in pdus]
)
defer.returnValue((200, { defer.returnValue((200, {
"pdu_ids": [pdu.event_id for pdu in pdus], "pdu_ids": state_ids,
"auth_chain_ids": [pdu.event_id for pdu in auth_chain], "auth_chain_ids": auth_chain_ids,
})) }))
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -65,33 +65,21 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
for ((event_type, state_key), event) in current_state.items()
if event_type == EventTypes.Member
]
if len(room_members) == 0:
# Have we just created the room, and is this about to be the very
# first member event?
create_event = current_state.get(("m.room.create", ""))
if create_event:
return True
for (state_key, membership) in room_members:
if (
self.hs.is_mine_id(state_key)
and membership == Membership.JOIN
):
return True
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def maybe_kick_guest_users(self, event, current_state): def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it. # Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller. # Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess: if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden") guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join": if guest_access != "can_join":
if context:
current_state = yield self.store.get_events(
context.current_state_ids.values()
)
current_state = current_state.values()
else:
current_state = yield self.store.get_current_state(event.room_id)
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state) yield self.kick_guest_users(current_state)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -175,6 +175,16 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def get_3pe_protocols(self):
services = yield self.store.get_app_services()
protocols = {}
for s in services:
for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
defer.returnValue(protocols)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_event(self, event): def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.

View file

@ -19,7 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomAlias, UserID from synapse.types import RoomAlias, UserID, get_domain_from_id
import logging import logging
import string import string
@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions. # TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association. # TODO(erikj): Check if there is a current association.
if not servers: if not servers:
servers = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
if not servers: if not servers:
raise SynapseError(400, "Failed to get server list") raise SynapseError(400, "Failed to get server list")
@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND Codes.NOT_FOUND
) )
extra_servers = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
extra_servers = set(get_domain_from_id(u) for u in users)
servers = set(extra_servers) | set(servers) servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first. # If this server is in the list of servers, return it first.

View file

@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence. # Send down presence.
if event.state_key == auth_user_id: if event.state_key == auth_user_id:
# Send down presence for everyone in the room. # Send down presence for everyone in the room.
users = yield self.store.get_users_in_room(event.room_id) users = yield self.state.get_current_user_in_room(event.room_id)
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
users, users,
as_event=True, as_event=True,

View file

@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
from synapse.util.logcontext import ( from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
) )
from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@ -100,6 +101,9 @@ class FederationHandler(BaseHandler):
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
auth_chain and state are None if we already have the necessary state
and prev_events in the db
""" """
event = pdu event = pdu
@ -117,12 +121,21 @@ class FederationHandler(BaseHandler):
# FIXME (erikj): Awful hack to make the case where we are not currently # FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work # in the room work
# If state and auth_chain are None, then we don't need to do this check
# as we already know we have enough state in the DB to handle this
# event.
if state and auth_chain and not event.internal_metadata.is_outlier():
is_in_room = yield self.auth.check_host_in_room( is_in_room = yield self.auth.check_host_in_room(
event.room_id, event.room_id,
self.server_name self.server_name
) )
if not is_in_room and not event.internal_metadata.is_outlier(): else:
logger.debug("Got event for room we're not in.") is_in_room = True
if not is_in_room:
logger.info(
"Got event for room we're not in: %r %r",
event.room_id, event.event_id
)
try: try:
event_stream_id, max_stream_id = yield self._persist_auth_tree( event_stream_id, max_stream_id = yield self._persist_auth_tree(
@ -217,17 +230,28 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key))
if not prev_state or prev_state.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally # Only fire user_joined_room if the user has acutally
# joined the room. Don't bother if the user is just # joined the room. Don't bother if the user is just
# changing their profile info. # changing their profile info.
newly_joined = True
prev_state_id = context.prev_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
prev_state = yield self.store.get_event(
prev_state_id, allow_none=True,
)
if prev_state and prev_state.membership == Membership.JOIN:
newly_joined = False
if newly_joined:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
@measure_func("_filter_events_for_server")
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events( event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=( types=(
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
@ -235,6 +259,30 @@ class FederationHandler(BaseHandler):
) )
) )
# We only want to pull out member events that correspond to the
# server's domain.
def check_match(id):
try:
return server_name == get_domain_from_id(id)
except:
return False
event_map = yield self.store.get_events([
e_id for key_to_eid in event_to_state_ids.values()
for key, e_id in key_to_eid
if key[0] != EventTypes.Member or check_match(key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.items()
if inner_e_id in event_map
}
for e_id, key_to_eid in event_to_state_ids.items()
}
def redact_disallowed(event, state): def redact_disallowed(event, state):
if not state: if not state:
return event return event
@ -377,7 +425,9 @@ class FederationHandler(BaseHandler):
)).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a}) auth_events.update({a.event_id: a for a in results if a})
required_auth.update( required_auth.update(
a_id for event in results for a_id, _ in event.auth_events if event a_id
for event in results if event
for a_id, _ in event.auth_events
) )
missing_auth = required_auth - set(auth_events) missing_auth = required_auth - set(auth_events)
@ -560,6 +610,18 @@ class FederationHandler(BaseHandler):
])) ]))
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s[1] for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
for k, e_id in state_dict.items()
if e_id in state_map
} for key, state_dict in states.items()
}
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_domains = get_domains_from_state(states[e_id])
@ -722,7 +784,7 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request` # when we get the event back in `on_send_join_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False) yield self.auth.check_from_context(event, context, do_sig_check=False)
defer.returnValue(event) defer.returnValue(event)
@ -770,18 +832,11 @@ class FederationHandler(BaseHandler):
new_pdu = event new_pdu = event
destinations = set() message_handler = self.hs.get_handlers().message_handler
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
for k, s in context.current_state.items(): context
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
) )
destinations = set(destinations)
destinations.discard(origin) destinations.discard(origin)
logger.debug( logger.debug(
@ -792,13 +847,15 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations) self.replication_layer.send_pdu(new_pdu, destinations)
state_ids = [e.event_id for e in context.current_state.values()] state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set( auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids [event.event_id] + state_ids
)) ))
state = yield self.store.get_events(context.prev_state_ids.values())
defer.returnValue({ defer.returnValue({
"state": context.current_state.values(), "state": state.values(),
"auth_chain": auth_chain, "auth_chain": auth_chain,
}) })
@ -954,7 +1011,7 @@ class FederationHandler(BaseHandler):
try: try:
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request` # when we get the event back in `on_send_leave_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False) yield self.auth.check_from_context(event, context, do_sig_check=False)
except AuthError as e: except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e) logger.warn("Failed to create new leave %r because %s", event, e)
raise e raise e
@ -998,18 +1055,11 @@ class FederationHandler(BaseHandler):
new_pdu = event new_pdu = event
destinations = set() message_handler = self.hs.get_handlers().message_handler
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
for k, s in context.current_state.items(): context
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
) )
destinations = set(destinations)
destinations.discard(origin) destinations.discard(origin)
logger.debug( logger.debug(
@ -1024,6 +1074,8 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id): def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
yield run_on_reactor() yield run_on_reactor()
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
@ -1064,6 +1116,34 @@ class FederationHandler(BaseHandler):
else: else:
defer.returnValue([]) defer.returnValue([])
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
yield run_on_reactor()
state_groups = yield self.store.get_state_groups_ids(
room_id, [event_id]
)
if state_groups:
_, state = state_groups.items().pop()
results = state
event = yield self.store.get_event(event_id)
if event and event.is_state():
# Get previous state
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
results[(event.type, event.state_key)] = prev_id
else:
del results[(event.type, event.state_key)]
defer.returnValue(results.values())
else:
defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, origin, room_id, pdu_list, limit): def on_backfill_request(self, origin, room_id, pdu_list, limit):
@ -1294,7 +1374,13 @@ class FederationHandler(BaseHandler):
) )
if not auth_events: if not auth_events:
auth_events = context.current_state auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
@ -1320,8 +1406,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
if event.type == EventTypes.GuestAccess: if event.type == EventTypes.GuestAccess:
full_context = yield self.store.get_current_state(room_id=event.room_id) yield self.maybe_kick_guest_users(event)
yield self.maybe_kick_guest_users(event, full_context)
defer.returnValue(context) defer.returnValue(context)
@ -1389,6 +1474,11 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events) event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event.is_state():
event_key = (event.type, event.state_key)
else:
event_key = None
if event_auth_events - current_state: if event_auth_events - current_state:
have_events = yield self.store.have_events( have_events = yield self.store.have_events(
event_auth_events - current_state event_auth_events - current_state
@ -1492,8 +1582,14 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
context.current_state.update(auth_events) context.current_state_ids.update({
context.state_group = None k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth) logger.info("Different auth after resolution: %s", different_auth)
@ -1514,8 +1610,8 @@ class FederationHandler(BaseHandler):
if do_resolution: if do_resolution:
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events( auth_ids = yield self.auth.compute_auth_events(
event, context.current_state event, context.prev_state_ids
) )
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@ -1571,8 +1667,14 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
context.current_state.update(auth_events) context.current_state_ids.update({
context.state_group = None k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
try: try:
self.auth.check(event, auth_events=auth_events) self.auth.check(event, auth_events=auth_events)
@ -1758,12 +1860,12 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check(event, context.current_state) yield self.auth.check_from_context(event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e) logger.warn("Denying new third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, context)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
@ -1789,11 +1891,11 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check(event, auth_events=context.current_state) self.auth.check_from_context(event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e) logger.warn("Denying third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, context)
returned_invite = yield self.send_invite(origin, event) returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
@ -1807,7 +1909,12 @@ class FederationHandler(BaseHandler):
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
) )
original_invite = context.current_state.get(key) original_invite = None
original_invite_id = context.prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
)
if not original_invite: if not original_invite:
logger.info( logger.info(
"Could not find invite event for third_party_invite - " "Could not find invite event for third_party_invite - "
@ -1824,13 +1931,13 @@ class FederationHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_signature(self, event, auth_events): def _check_signature(self, event, context):
""" """
Checks that the signature in the event is consistent with its invite. Checks that the signature in the event is consistent with its invite.
Args: Args:
event (Event): The m.room.member event to check event (Event): The m.room.member event to check
auth_events (dict<(event type, state_key), event>): context (EventContext):
Raises: Raises:
AuthError: if signature didn't match any keys, or key has been AuthError: if signature didn't match any keys, or key has been
@ -1841,10 +1948,14 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"] signed = event.content["third_party_invite"]["signed"]
token = signed["token"] token = signed["token"]
invite_event = auth_events.get( invite_event_id = context.prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )
invite_event = None
if invite_event_id:
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
if not invite_event: if not invite_event:
raise AuthError(403, "Could not find invite") raise AuthError(403, "Could not find invite")

View file

@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state(): if event.is_state():
prev_state = self.deduplicate_state_event(event, context) prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None: if prev_state is not None:
defer.returnValue(prev_state) defer.returnValue(prev_state)
@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
presence = self.hs.get_presence_handler() presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user) yield presence.bump_presence_active_time(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context): def deduplicate_state_event(self, event, context):
""" """
Checks whether event is in the latest resolved state in context. Checks whether event is in the latest resolved state in context.
@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context. If so, returns the version of the event in context.
Otherwise, returns None. Otherwise, returns None.
""" """
prev_event = context.current_state.get((event.type, event.state_key)) prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id: if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content) prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content) next_content = encode_canonical_json(event.content)
if prev_content == next_content: if prev_content == next_content:
return prev_event defer.returnValue(prev_event)
return None return
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_nonmember_event( def create_and_send_nonmember_event(
@ -802,8 +808,8 @@ class MessageHandler(BaseHandler):
event = builder.build() event = builder.build()
logger.debug( logger.debug(
"Created event %s with current state: %s", "Created event %s with state: %s",
event.event_id, context.current_state, event.event_id, context.prev_state_ids,
) )
defer.returnValue( defer.returnValue(
@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
self.ratelimit(requester) self.ratelimit(requester)
try: try:
self.auth.check(event, auth_events=context.current_state) yield self.auth.check_from_context(event, context)
except AuthError as err: except AuthError as err:
logger.warn("Denying new event %r because %s", event, err) logger.warn("Denying new event %r because %s", event, err)
raise err raise err
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)
@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
e.sender == event.sender e.sender == event.sender
) )
state_to_include_ids = [
e_id
for k, e_id in context.current_state_ids.items()
if k[0] in self.hs.config.room_invite_state_types
or k[0] == EventTypes.Member and k[1] == event.sender
]
state_to_include = yield self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [ event.unsigned["invite_room_state"] = [
{ {
"type": e.type, "type": e.type,
@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
"content": e.content, "content": e.content,
"sender": e.sender, "sender": e.sender,
} }
for k, e in context.current_state.items() for e in state_to_include.values()
if e.type in self.hs.config.room_invite_state_types
or is_inviter_member_event(e)
] ]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)
@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
) )
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state): auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
if self.auth.check_redaction(event, auth_events=auth_events):
original_event = yield self.store.get_event( original_event = yield self.store.get_event(
event.redacts, event.redacts,
check_redacted=False, check_redacted=False,
@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
if event.type == EventTypes.Create and context.current_state: if event.type == EventTypes.Create and context.prev_state_ids:
raise AuthError( raise AuthError(
403, 403,
"Changing the room create event is forbidden", "Changing the room create event is forbidden",
@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id event_stream_id, max_stream_id
) )
destinations = set() destinations = yield self.get_joined_hosts_for_room_from_state(context)
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _notify(): def _notify():
@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
preserve_fn(federation_handler.handle_new_event)( preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations, event, destinations=destinations,
) )
def get_joined_hosts_for_room_from_state(self, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_hosts_for_room_from_state(
state_group, context.current_state_ids
)
@cachedInlineCallbacks(num_args=1, cache_context=True)
def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
cache_context):
# Don't bother getting state for people on the same HS
current_state = yield self.store.get_events([
e_id for key, e_id in current_state_ids.items()
if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
])
destinations = set()
for e in current_state.itervalues():
try:
if e.type == EventTypes.Member:
if e.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(e.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", e.event_id
)
defer.returnValue(destinations)

View file

@ -88,6 +88,8 @@ class PresenceHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.state = hs.get_state_handler()
self.federation.register_edu_handler( self.federation.register_edu_handler(
"m.presence", self.incoming_presence "m.presence", self.incoming_presence
) )
@ -189,6 +191,13 @@ class PresenceHandler(object):
5000, 5000,
) )
self.clock.call_later(
60,
self.clock.looping_call,
self._persist_unpersisted_changes,
60 * 1000,
)
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -214,6 +223,27 @@ class PresenceHandler(object):
]) ])
logger.info("Finished _on_shutdown") logger.info("Finished _on_shutdown")
@defer.inlineCallbacks
def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
logger.info(
"Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
len(self.unpersisted_users_changes)
)
unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set()
if unpersisted:
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in unpersisted
])
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_states(self, new_states): def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes """Updates presence of users. Sets the appropriate timeouts. Pokes
@ -532,7 +562,9 @@ class PresenceHandler(object):
if not local_states: if not local_states:
continue continue
hosts = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
for host in hosts: for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
@ -725,13 +757,13 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part # don't need to send to local clients here, as that is done as part
# of the event stream/sync. # of the event stream/sync.
# TODO: Only send to servers not already in the room. # TODO: Only send to servers not already in the room.
user_ids = yield self.state.get_current_user_in_room(room_id)
if self.is_mine(user): if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string()) state = yield self.current_state_for_user(user.to_string())
hosts = yield self.store.get_joined_hosts_for_room(room_id) hosts = set(get_domain_from_id(u) for u in user_ids)
self._push_to_remotes({host: (state,) for host in hosts}) self._push_to_remotes({host: (state,) for host in hosts})
else: else:
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.is_mine_id, user_ids) user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids) states = yield self.current_state_for_users(user_ids)
@ -919,6 +951,11 @@ def should_notify(old_state, new_state):
return True return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently acive
if not (old_state.currently_active and new_state.currently_active):
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped. # Always notify for a transition where last active gets bumped.
return True return True
@ -955,6 +992,7 @@ class PresenceEventSource(object):
self.get_presence_handler = hs.get_presence_handler self.get_presence_handler = hs.get_presence_handler
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -1017,7 +1055,7 @@ class PresenceEventSource(object):
user_ids_to_check = set() user_ids_to_check = set()
for room_id in room_ids: for room_id in room_ids:
users = yield self.store.get_users_in_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
user_ids_to_check.update(users) user_ids_to_check.update(users)
user_ids_to_check.update(friends) user_ids_to_check.update(friends)

View file

@ -18,6 +18,7 @@ from ._base import BaseHandler
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import get_domain_from_id
import logging import logging
@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler):
"m.receipt", self._received_remote_receipt "m.receipt", self._received_remote_receipt
) )
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id, def received_client_receipt(self, room_id, receipt_type, user_id,
@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"] event_ids = receipt["event_ids"]
data = receipt["data"] data = receipt["data"]
remotedomains = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
remotedomains = set(get_domain_from_id(u) for u in users)
remotedomains = remotedomains.copy() remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name) remotedomains.discard(self.server_name)

View file

@ -85,6 +85,12 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
) )
# Check if this event matches the previous membership event for the user.
duplicate = yield msg_handler.deduplicate_state_event(event, context)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return
yield msg_handler.handle_new_client_event( yield msg_handler.handle_new_client_event(
requester, requester,
event, event,
@ -93,19 +99,25 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event = context.current_state.get( prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, target.to_string()), (EventTypes.Member, target.to_string()),
None None
) )
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the # Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile # room. Don't bother if the user is just changing their profile
# info. # info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield user_joined_room(self.distributor, target, room_id) yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id) user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -195,16 +207,19 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts = [] remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state = yield self.state_handler.get_current_state( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
) )
old_state = current_state.get((EventTypes.Member, target.to_string())) old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban": if action == "unban" and old_membership != "ban":
raise SynapseError( raise SynapseError(
403, 403,
"Cannot unban user who was not banned (membership=%s)" % old_membership, "Cannot unban user who was not banned"
" (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE errcode=Codes.BAD_STATE
) )
if old_membership == "ban" and action != "unban": if old_membership == "ban" and action != "unban":
@ -214,10 +229,10 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE errcode=Codes.BAD_STATE
) )
is_host_in_room = self.is_host_in_room(current_state) is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN: if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state): if requester.is_guest and not self._can_guest_join(current_state_ids):
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
@ -326,12 +341,14 @@ class RoomMemberHandler(BaseHandler):
requester = synapse.types.create_requester(target_user) requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context) prev_event = yield message_handler.deduplicate_state_event(event, context)
if prev_event is not None: if prev_event is not None:
return return
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state): if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
@ -344,27 +361,39 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event = context.current_state.get( prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, target_user.to_string()), (EventTypes.Member, event.state_key),
None None
) )
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the # Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile # room. Don't bother if the user is just changing their profile
# info. # info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield user_joined_room(self.distributor, target_user, room_id) yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id) user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state): @defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
""" """
Returns whether a guest can join a room based on its current state. Returns whether a guest can join a room based on its current state.
""" """
guest_access = current_state.get((EventTypes.GuestAccess, ""), None) guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
return ( if not guest_access_id:
defer.returnValue(False)
guest_access = yield self.store.get_event(guest_access_id)
defer.returnValue(
guest_access guest_access
and guest_access.content and guest_access.content
and "guest_access" in guest_access.content and "guest_access" in guest_access.content
@ -683,3 +712,24 @@ class RoomMemberHandler(BaseHandler):
if membership: if membership:
yield self.store.forget(user_id, room_id) yield self.store.forget(user_id, room_id)
@defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
defer.returnValue(self.hs.is_mine_id(create_event_id))
for (etype, state_key), event_id in current_state_ids.items():
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
event = yield self.store.get_event(event_id, allow_none=True)
if not event:
continue
if event.membership == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)

View file

@ -35,6 +35,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [
"filter_collection", "filter_collection",
"is_guest", "is_guest",
"request_key", "request_key",
"device_id",
]) ])
@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"joined", # JoinedSyncResult for each joined room. "joined", # JoinedSyncResult for each joined room.
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
])): ])):
__slots__ = [] __slots__ = []
@ -126,7 +128,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.joined or self.joined or
self.invited or self.invited or
self.archived or self.archived or
self.account_data self.account_data or
self.to_device
) )
@ -139,6 +142,7 @@ class SyncHandler(object):
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs) self.response_cache = ResponseCache(hs)
self.state = hs.get_state_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False): full_state=False):
@ -355,11 +359,11 @@ class SyncHandler(object):
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state = yield self.store.get_state_for_event(event.event_id) state_ids = yield self.store.get_state_ids_for_event(event.event_id)
if event.is_state(): if event.is_state():
state = state.copy() state_ids = state_ids.copy()
state[(event.type, event.state_key)] = event state_ids[(event.type, event.state_key)] = event.event_id
defer.returnValue(state) defer.returnValue(state_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position): def get_state_at(self, room_id, stream_position):
@ -412,57 +416,61 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
if full_state: if full_state:
if batch: if batch:
current_state = yield self.store.get_state_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id
) )
state = yield self.store.get_state_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
else: else:
current_state = yield self.get_state_at( current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token room_id, stream_position=now_token
) )
state = current_state state_ids = current_state_ids
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
} }
state = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state, timeline_start=state_ids,
previous={}, previous={},
current=current_state, current=current_state_ids,
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token room_id, stream_position=since_token
) )
current_state = yield self.store.get_state_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id
) )
state_at_timeline_start = yield self.store.get_state_for_event( state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
} }
state = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
previous=state_at_previous_sync, previous=state_at_previous_sync,
current=current_state, current=current_state_ids,
) )
else: else:
state_ids = {}
state = {} state = {}
if state_ids:
state = yield self.store.get_events(state_ids.values())
defer.returnValue({ defer.returnValue({
(e.type, e.state_key): e (e.type, e.state_key): e
@ -527,15 +535,57 @@ class SyncHandler(object):
sync_result_builder, newly_joined_rooms, newly_joined_users sync_result_builder, newly_joined_rooms, newly_joined_users
) )
yield self._generate_sync_entry_for_to_device(sync_result_builder)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined, joined=sync_result_builder.joined,
invited=sync_result_builder.invited, invited=sync_result_builder.invited,
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates
`sync_result_builder` with the result.
Args:
sync_result_builder(SyncResultBuilder)
Returns:
Deferred(dict): A dictionary containing the per room account data.
"""
user_id = sync_result_builder.sync_config.user.to_string()
device_id = sync_result_builder.sync_config.device_id
now_token = sync_result_builder.now_token
since_stream_id = 0
if sync_result_builder.since_token is not None:
since_stream_id = int(sync_result_builder.since_token.to_device_key)
if since_stream_id != int(now_token.to_device_key):
# We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point.
logger.debug("Deleting messages up to %d", since_stream_id)
yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug("Getting messages up to %d", now_token.to_device_key)
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
logger.debug("Got messages up to %d: %r", stream_id, messages)
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
)
sync_result_builder.to_device = messages
else:
sync_result_builder.to_device = []
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_account_data(self, sync_result_builder): def _generate_sync_entry_for_account_data(self, sync_result_builder):
"""Generates the account data portion of the sync response. Populates """Generates the account data portion of the sync response. Populates
@ -626,7 +676,7 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_users) extra_users_ids = set(newly_joined_users)
for room_id in newly_joined_rooms: for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
extra_users_ids.update(users) extra_users_ids.update(users)
extra_users_ids.discard(user.to_string()) extra_users_ids.discard(user.to_string())
@ -766,8 +816,13 @@ class SyncHandler(object):
# the last sync (even if we have since left). This is to make sure # the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary # we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join: if room_id in joined_room_ids or has_join:
old_state = yield self.get_state_at(room_id, since_token) old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev = old_state.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
old_mem_ev = yield self.store.get_event(
old_mem_ev_id, allow_none=True
)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id) newly_joined_rooms.append(room_id)
@ -1059,27 +1114,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
Returns: Returns:
dict dict
""" """
event_id_to_state = { event_id_to_key = {
e.event_id: e e: key
for e in itertools.chain( for key, e in itertools.chain(
timeline_contains.values(), timeline_contains.items(),
previous.values(), previous.items(),
timeline_start.values(), timeline_start.items(),
current.values(), current.items(),
) )
} }
c_ids = set(e.event_id for e in current.values()) c_ids = set(e for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values()) tc_ids = set(e for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values()) p_ids = set(e for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values()) ts_ids = set(e for e in timeline_start.values())
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return { return {
(e.type, e.state_key): e event_id_to_key[e]: e for e in state_ids
for e in evs
} }
@ -1103,6 +1156,7 @@ class SyncResultBuilder(object):
self.joined = [] self.joined = []
self.invited = [] self.invited = []
self.archived = [] self.archived = []
self.device = []
class RoomSyncResultBuilder(object): class RoomSyncResultBuilder(object):

View file

@ -20,7 +20,7 @@ from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
) )
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import UserID from synapse.types import UserID, get_domain_from_id
import logging import logging
@ -42,6 +42,7 @@ class TypingHandler(object):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -166,7 +167,8 @@ class TypingHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing): def _push_update(self, room_id, user_id, typing):
domains = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
deferreds = [] deferreds = []
for domain in domains: for domain in domains:
@ -199,7 +201,8 @@ class TypingHandler(object):
# Check that the string is a valid user id # Check that the string is a valid user id
UserID.from_string(user_id) UserID.from_string(user_id)
domains = yield self.store.get_joined_hosts_for_room(room_id) users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains: if self.server_name in domains:
self._push_update_local( self._push_update_local(

View file

@ -423,7 +423,8 @@ class Notifier(object):
def _is_world_readable(self, room_id): def _is_world_readable(self, room_id):
state = yield self.state_handler.get_current_state( state = yield self.state_handler.get_current_state(
room_id, room_id,
EventTypes.RoomHistoryVisibility EventTypes.RoomHistoryVisibility,
"",
) )
if state and "history_visibility" in state.content: if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable") defer.returnValue(state.content["history_visibility"] == "world_readable")

View file

@ -40,12 +40,12 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"): with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event( bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.state_group, context.current_state event, self.hs, self.store, context
) )
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state event, context
) )
context.push_actions = [ context.push_actions = [

View file

@ -19,8 +19,8 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients from synapse.visibility import filter_events_for_clients_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, state_group, current_state): def evaluator_for_event(event, hs, store, context):
rules_by_user = yield store.bulk_get_push_rules_for_room( rules_by_user = yield store.bulk_get_push_rules_for_room(
event.room_id, state_group, current_state event, context
) )
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, current_state): def action_for_event_by_user(self, event, context):
actions_by_user = {} actions_by_user = {}
# None of these users can be peeking since this list of users comes # None of these users can be peeking since this list of users comes
@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
(u, False) for u in self.rules_by_user.keys() (u, False) for u in self.rules_by_user.keys()
] ]
filtered_by_user = yield filter_events_for_clients( filtered_by_user = yield filter_events_for_clients_context(
self.store, user_tuples, [event], {event.event_id: current_state} self.store, user_tuples, [event], {event.event_id: context}
) )
room_members = set( room_members = yield self.store.get_joined_users_from_context(
e.state_key for e in current_state.values() event, context
if e.type == EventTypes.Member and e.membership == Membership.JOIN
) )
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
condition_cache = {} condition_cache = {}
display_names = {}
for ev in current_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
for uid, rules in self.rules_by_user.items(): for uid, rules in self.rules_by_user.items():
display_name = display_names.get(uid, None) display_name = None
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
if member_ev_id:
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
if member_ev:
display_name = member_ev.content.get("displayname", None)
filtered = filtered_by_user[uid] filtered = filtered_by_user[uid]
if len(filtered) == 0: if len(filtered) == 0:

View file

@ -245,7 +245,7 @@ class HttpPusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge): def _build_notification_dict(self, event, tweaks, badge):
ctx = yield push_tools.get_context_for_event( ctx = yield push_tools.get_context_for_event(
self.state_handler, event, self.user_id self.store, self.state_handler, event, self.user_id
) )
d = { d = {

View file

@ -22,7 +22,7 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event, descriptor_from_member_events calculate_room_name, name_from_member_event, descriptor_from_member_events
) )
from synapse.types import UserID from synapse.types import UserID
@ -139,7 +139,7 @@ class Mailer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _fetch_room_state(room_id): def _fetch_room_state(room_id):
room_state = yield self.state_handler.get_current_state(room_id) room_state = yield self.state_handler.get_current_state_ids(room_id)
state_by_room[room_id] = room_state state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email # Run at most 3 of these at once: sync does 10 at a time but email
@ -159,11 +159,12 @@ class Mailer(object):
) )
rooms.append(roomvars) rooms.append(roomvars)
reason['room_name'] = calculate_room_name( reason['room_name'] = yield calculate_room_name(
state_by_room[reason['room_id']], user_id, fallback_to_members=True self.store, state_by_room[reason['room_id']], user_id,
fallback_to_members=True
) )
summary_text = self.make_summary_text( summary_text = yield self.make_summary_text(
notifs_by_room, state_by_room, notif_events, user_id, reason notifs_by_room, state_by_room, notif_events, user_id, reason
) )
@ -203,12 +204,15 @@ class Mailer(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state): def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
my_member_event = room_state[("m.room.member", user_id)] my_member_event_id = room_state_ids[("m.room.member", user_id)]
my_member_event = yield self.store.get_event(my_member_event_id)
is_invite = my_member_event.content["membership"] == "invite" is_invite = my_member_event.content["membership"] == "invite"
room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
room_vars = { room_vars = {
"title": calculate_room_name(room_state, user_id), "title": room_name,
"hash": string_ordinal_total(room_id), # See sender avatar hash "hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [], "notifs": [],
"invite": is_invite, "invite": is_invite,
@ -218,7 +222,7 @@ class Mailer(object):
if not is_invite: if not is_invite:
for n in notifs: for n in notifs:
notifvars = yield self.get_notif_vars( notifvars = yield self.get_notif_vars(
n, user_id, notif_events[n['event_id']], room_state n, user_id, notif_events[n['event_id']], room_state_ids
) )
# merge overlapping notifs together. # merge overlapping notifs together.
@ -243,7 +247,7 @@ class Mailer(object):
defer.returnValue(room_vars) defer.returnValue(room_vars)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state): def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
results = yield self.store.get_events_around( results = yield self.store.get_events_around(
notif['room_id'], notif['event_id'], notif['room_id'], notif['event_id'],
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
@ -261,17 +265,19 @@ class Mailer(object):
the_events.append(notif_event) the_events.append(notif_event)
for event in the_events: for event in the_events:
messagevars = self.get_message_vars(notif, event, room_state) messagevars = yield self.get_message_vars(notif, event, room_state_ids)
if messagevars is not None: if messagevars is not None:
ret['messages'].append(messagevars) ret['messages'].append(messagevars)
defer.returnValue(ret) defer.returnValue(ret)
def get_message_vars(self, notif, event, room_state): @defer.inlineCallbacks
def get_message_vars(self, notif, event, room_state_ids):
if event.type != EventTypes.Message: if event.type != EventTypes.Message:
return None return
sender_state_event = room_state[("m.room.member", event.sender)] sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
sender_state_event = yield self.store.get_event(sender_state_event_id)
sender_name = name_from_member_event(sender_state_event) sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = sender_state_event.content.get("avatar_url") sender_avatar_url = sender_state_event.content.get("avatar_url")
@ -299,7 +305,7 @@ class Mailer(object):
if "body" in event.content: if "body" in event.content:
ret["body_text_plain"] = event.content["body"] ret["body_text_plain"] = event.content["body"]
return ret defer.returnValue(ret)
def add_text_message_vars(self, messagevars, event): def add_text_message_vars(self, messagevars, event):
msgformat = event.content.get("format") msgformat = event.content.get("format")
@ -321,6 +327,7 @@ class Mailer(object):
return messagevars return messagevars
@defer.inlineCallbacks
def make_summary_text(self, notifs_by_room, state_by_room, def make_summary_text(self, notifs_by_room, state_by_room,
notif_events, user_id, reason): notif_events, user_id, reason):
if len(notifs_by_room) == 1: if len(notifs_by_room) == 1:
@ -330,8 +337,8 @@ class Mailer(object):
# If the room has some kind of name, use it, but we don't # If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll # want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room" # end up with, "new message from Bob in the Bob room"
room_name = calculate_room_name( room_name = yield calculate_room_name(
state_by_room[room_id], user_id, fallback_to_members=False self.store, state_by_room[room_id], user_id, fallback_to_members=False
) )
my_member_event = state_by_room[room_id][("m.room.member", user_id)] my_member_event = state_by_room[room_id][("m.room.member", user_id)]
@ -342,16 +349,16 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event) inviter_name = name_from_member_event(inviter_member_event)
if room_name is None: if room_name is None:
return INVITE_FROM_PERSON % { defer.returnValue(INVITE_FROM_PERSON % {
"person": inviter_name, "person": inviter_name,
"app": self.app_name "app": self.app_name
} })
else: else:
return INVITE_FROM_PERSON_TO_ROOM % { defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
"person": inviter_name, "person": inviter_name,
"room": room_name, "room": room_name,
"app": self.app_name, "app": self.app_name,
} })
sender_name = None sender_name = None
if len(notifs_by_room[room_id]) == 1: if len(notifs_by_room[room_id]) == 1:
@ -362,24 +369,24 @@ class Mailer(object):
sender_name = name_from_member_event(state_event) sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None: if sender_name is not None and room_name is not None:
return MESSAGE_FROM_PERSON_IN_ROOM % { defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
"person": sender_name, "person": sender_name,
"room": room_name, "room": room_name,
"app": self.app_name, "app": self.app_name,
} })
elif sender_name is not None: elif sender_name is not None:
return MESSAGE_FROM_PERSON % { defer.returnValue(MESSAGE_FROM_PERSON % {
"person": sender_name, "person": sender_name,
"app": self.app_name, "app": self.app_name,
} })
else: else:
# There's more than one notification for this room, so just # There's more than one notification for this room, so just
# say there are several # say there are several
if room_name is not None: if room_name is not None:
return MESSAGES_IN_ROOM % { defer.returnValue(MESSAGES_IN_ROOM % {
"room": room_name, "room": room_name,
"app": self.app_name, "app": self.app_name,
} })
else: else:
# If the room doesn't have a name, say who the messages # If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room" # are from explicitly to avoid, "messages in the Bob room"
@ -388,22 +395,22 @@ class Mailer(object):
for n in notifs_by_room[room_id] for n in notifs_by_room[room_id]
])) ]))
return MESSAGES_FROM_PERSON % { defer.returnValue(MESSAGES_FROM_PERSON % {
"person": descriptor_from_member_events([ "person": descriptor_from_member_events([
state_by_room[room_id][("m.room.member", s)] state_by_room[room_id][("m.room.member", s)]
for s in sender_ids for s in sender_ids
]), ]),
"app": self.app_name, "app": self.app_name,
} })
else: else:
# Stuff's happened in multiple different rooms # Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail # ...but we still refer to the 'reason' room which triggered the mail
if reason['room_name'] is not None: if reason['room_name'] is not None:
return MESSAGES_IN_ROOM_AND_OTHERS % { defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
"room": reason['room_name'], "room": reason['room_name'],
"app": self.app_name, "app": self.app_name,
} })
else: else:
# If the reason room doesn't have a name, say who the messages # If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room" # are from explicitly to avoid, "messages in the Bob room"
@ -412,13 +419,13 @@ class Mailer(object):
for n in notifs_by_room[reason['room_id']] for n in notifs_by_room[reason['room_id']]
])) ]))
return MESSAGES_FROM_PERSON_AND_OTHERS % { defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events([ "person": descriptor_from_member_events([
state_by_room[reason['room_id']][("m.room.member", s)] state_by_room[reason['room_id']][("m.room.member", s)]
for s in sender_ids for s in sender_ids
]), ]),
"app": self.app_name, "app": self.app_name,
} })
def make_room_link(self, room_id): def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS # need /beta for Universal Links to work on iOS

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import re import re
import logging import logging
@ -25,7 +27,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room" ALL_ALONE = "Empty Room"
def calculate_room_name(room_state, user_id, fallback_to_members=True, @defer.inlineCallbacks
def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
fallback_to_single_member=True): fallback_to_single_member=True):
""" """
Works out a user-facing name for the given room as per Matrix Works out a user-facing name for the given room as per Matrix
@ -42,59 +45,78 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
(string or None) A human readable name for the room. (string or None) A human readable name for the room.
""" """
# does it have a name? # does it have a name?
if ("m.room.name", "") in room_state: if ("m.room.name", "") in room_state_ids:
m_room_name = room_state[("m.room.name", "")] m_room_name = yield store.get_event(
if m_room_name.content and m_room_name.content["name"]: room_state_ids[("m.room.name", "")], allow_none=True
return m_room_name.content["name"] )
if m_room_name and m_room_name.content and m_room_name.content["name"]:
defer.returnValue(m_room_name.content["name"])
# does it have a canonical alias? # does it have a canonical alias?
if ("m.room.canonical_alias", "") in room_state: if ("m.room.canonical_alias", "") in room_state_ids:
canon_alias = room_state[("m.room.canonical_alias", "")] canon_alias = yield store.get_event(
room_state_ids[("m.room.canonical_alias", "")], allow_none=True
)
if ( if (
canon_alias.content and canon_alias.content["alias"] and canon_alias and canon_alias.content and canon_alias.content["alias"] and
_looks_like_an_alias(canon_alias.content["alias"]) _looks_like_an_alias(canon_alias.content["alias"])
): ):
return canon_alias.content["alias"] defer.returnValue(canon_alias.content["alias"])
# at this point we're going to need to search the state by all state keys # at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure # for an event type, so rearrange the data structure
room_state_bytype = _state_as_two_level_dict(room_state) room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
# right then, any aliases at all? # right then, any aliases at all?
if "m.room.aliases" in room_state_bytype: if "m.room.aliases" in room_state_bytype_ids:
m_room_aliases = room_state_bytype["m.room.aliases"] m_room_aliases = room_state_bytype_ids["m.room.aliases"]
if len(m_room_aliases.values()) > 0: for alias_id in m_room_aliases.values():
first_alias_event = m_room_aliases.values()[0] alias_event = yield store.get_event(
if first_alias_event.content and first_alias_event.content["aliases"]: alias_id, allow_none=True
the_aliases = first_alias_event.content["aliases"] )
if alias_event and alias_event.content.get("aliases"):
the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
return the_aliases[0] defer.returnValue(the_aliases[0])
if not fallback_to_members: if not fallback_to_members:
return None defer.returnValue(None)
my_member_event = None my_member_event = None
if ("m.room.member", user_id) in room_state: if ("m.room.member", user_id) in room_state_ids:
my_member_event = room_state[("m.room.member", user_id)] my_member_event = yield store.get_event(
room_state_ids[("m.room.member", user_id)], allow_none=True
)
if ( if (
my_member_event is not None and my_member_event is not None and
my_member_event.content['membership'] == "invite" my_member_event.content['membership'] == "invite"
): ):
if ("m.room.member", my_member_event.sender) in room_state: if ("m.room.member", my_member_event.sender) in room_state_ids:
inviter_member_event = room_state[("m.room.member", my_member_event.sender)] inviter_member_event = yield store.get_event(
room_state_ids[("m.room.member", my_member_event.sender)],
allow_none=True,
)
if inviter_member_event:
if fallback_to_single_member: if fallback_to_single_member:
return "Invite from %s" % (name_from_member_event(inviter_member_event),) defer.returnValue(
"Invite from %s" % (
name_from_member_event(inviter_member_event),
)
)
else: else:
return None return
else: else:
return "Room Invite" defer.returnValue("Room Invite")
# we're going to have to generate a name based on who's in the room, # we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user. # so find out who is in the room that isn't the user.
if "m.room.member" in room_state_bytype: if "m.room.member" in room_state_bytype_ids:
member_events = yield store.get_events(
room_state_bytype_ids["m.room.member"].values()
)
all_members = [ all_members = [
ev for ev in room_state_bytype["m.room.member"].values() ev for ev in member_events.values()
if ev.content['membership'] == "join" or ev.content['membership'] == "invite" if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
] ]
# Sort the member events oldest-first so the we name people in the # Sort the member events oldest-first so the we name people in the
@ -111,9 +133,9 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
# self-chat, peeked room with 1 participant, # self-chat, peeked room with 1 participant,
# or inbound invite, or outbound 3PID invite. # or inbound invite, or outbound 3PID invite.
if all_members[0].sender == user_id: if all_members[0].sender == user_id:
if "m.room.third_party_invite" in room_state_bytype: if "m.room.third_party_invite" in room_state_bytype_ids:
third_party_invites = ( third_party_invites = (
room_state_bytype["m.room.third_party_invite"].values() room_state_bytype_ids["m.room.third_party_invite"].values()
) )
if len(third_party_invites) > 0: if len(third_party_invites) > 0:
@ -126,17 +148,17 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
# return "Inviting %s" % ( # return "Inviting %s" % (
# descriptor_from_member_events(third_party_invites) # descriptor_from_member_events(third_party_invites)
# ) # )
return "Inviting email address" defer.returnValue("Inviting email address")
else: else:
return ALL_ALONE defer.returnValue(ALL_ALONE)
else: else:
return name_from_member_event(all_members[0]) defer.returnValue(name_from_member_event(all_members[0]))
else: else:
return ALL_ALONE defer.returnValue(ALL_ALONE)
elif len(other_members) == 1 and not fallback_to_single_member: elif len(other_members) == 1 and not fallback_to_single_member:
return None return
else: else:
return descriptor_from_member_events(other_members) defer.returnValue(descriptor_from_member_events(other_members))
def descriptor_from_member_events(member_events): def descriptor_from_member_events(member_events):

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.util.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event calculate_room_name, name_from_member_event
) )
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -49,21 +49,22 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_context_for_event(state_handler, ev, user_id): def get_context_for_event(store, state_handler, ev, user_id):
ctx = {} ctx = {}
room_state = yield state_handler.get_current_state(ev.room_id) room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
# we no longer bother setting room_alias, and make room_name the # we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or # human-readable name instead, be that m.room.name, an alias or
# a list of people in the room # a list of people in the room
name = calculate_room_name( name = yield calculate_room_name(
room_state, user_id, fallback_to_single_member=False store, room_state_ids, user_id, fallback_to_single_member=False
) )
if name: if name:
ctx['name'] = name ctx['name'] = name
sender_state_event = room_state[("m.room.member", ev.sender)] sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield store.get_event(sender_state_event_id)
ctx['sender_display_name'] = name_from_member_event(sender_state_event) ctx['sender_display_name'] = name_from_member_event(sender_state_event)
defer.returnValue(ctx) defer.returnValue(ctx)

View file

@ -40,8 +40,8 @@ STREAM_NAMES = (
("backfill",), ("backfill",),
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",),
("caches",), ("caches",),
("to_device",),
) )
@ -130,7 +130,6 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token() backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
@ -142,8 +141,9 @@ class ReplicationResource(Resource):
backfill_token, backfill_token,
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token, 0, # State stream is no longer a thing
caches_token, caches_token,
int(stream_token.to_device_key),
)) ))
@request_handler() @request_handler()
@ -191,8 +191,8 @@ class ReplicationResource(Resource):
yield self.receipts(writer, current_token, limit, request_streams) yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams) yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.info("Replicated %d rows", writer.total)
@ -365,25 +365,6 @@ class ReplicationResource(Resource):
"position", "user_id", "app_id", "pushkey" "position", "user_id", "app_id", "pushkey"
)) ))
@defer.inlineCallbacks
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
state = request_streams.get("state")
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
@defer.inlineCallbacks @defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams): def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches current_position = current_token.caches
@ -398,6 +379,20 @@ class ReplicationResource(Resource):
"position", "cache_func", "keys", "invalidation_ts" "position", "cache_func", "keys", "invalidation_ts"
)) ))
@defer.inlineCallbacks
def to_device(self, writer, current_token, limit, request_streams):
current_position = current_token.to_device
to_device = request_streams.get("to_device")
if to_device is not None:
to_device_rows = yield self.store.get_all_new_device_messages(
to_device, current_position, limit
)
writer.write_header_and_rows("to_device", to_device_rows, (
"position", "user_id", "device_id", "message_json"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -426,7 +421,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "push_rules", "pushers", "state", "caches", "to_device",
))): ))):
__slots__ = [] __slots__ = []

View file

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
class SlavedDeviceInboxStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id",
)
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("to_device")
if stream:
self._device_inbox_id_gen.advance(int(stream["position"]))
return super(SlavedDeviceInboxStore, self).process_replication(result)

View file

@ -120,10 +120,21 @@ class SlavedEventStore(BaseSlavedStore):
get_state_for_event = DataStore.get_state_for_event.__func__ get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__ get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__ get_state_groups = DataStore.get_state_groups.__func__
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = (
RoomMemberStore.__dict__["_get_joined_users_from_context"]
)
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = ( get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__ DataStore.get_room_events_stream_for_rooms.__func__
) )
is_host_joined = DataStore.is_host_joined.__func__
_is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after) _set_before_and_after = staticmethod(DataStore._set_before_and_after)
@ -211,7 +222,6 @@ class SlavedEventStore(BaseSlavedStore):
self._get_current_state_for_key.invalidate_all() self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all() self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id) self._invalidate_get_event_cache(event.event_id)
@ -235,7 +245,6 @@ class SlavedEventStore(BaseSlavedStore):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,)) self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed( self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering event.state_key, event.internal_metadata.stream_ordering

View file

@ -49,6 +49,7 @@ from synapse.rest.client.v2_alpha import (
notifications, notifications,
devices, devices,
thirdparty, thirdparty,
sendtodevice,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -96,3 +97,4 @@ class ClientRestResource(JsonResource):
notifications.register_servlets(hs, client_resource) notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)

View file

@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http import servlet
from synapse.rest.client.v1.transactions import HttpTransactionStore
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
releases=[], v2_alpha=False
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.txns = HttpTransactionStore()
@defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
# TODO: Prod the notifier to wake up sync streams.
# TODO: Implement replication for the messages.
# TODO: Send the messages to remote servers if needed.
local_messages = {}
for user_id, by_device in content["messages"].items():
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": requester.user.to_string(),
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
response = (200, {})
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
def register_servlets(hs, http_server):
SendToDeviceRestServlet(hs).register(http_server)

View file

@ -97,6 +97,7 @@ class SyncRestServlet(RestServlet):
request, allow_guest=True request, allow_guest=True
) )
user = requester.user user = requester.user
device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0) timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since") since = parse_string(request, "since")
@ -109,12 +110,12 @@ class SyncRestServlet(RestServlet):
logger.info( logger.info(
"/sync: user=%r, timeout=%r, since=%r," "/sync: user=%r, timeout=%r, since=%r,"
" set_presence=%r, filter_id=%r" % ( " set_presence=%r, filter_id=%r, device_id=%r" % (
user, timeout, since, set_presence, filter_id user, timeout, since, set_presence, filter_id, device_id
) )
) )
request_key = (user, timeout, since, filter_id, full_state) request_key = (user, timeout, since, filter_id, full_state, device_id)
if filter_id: if filter_id:
if filter_id.startswith('{'): if filter_id.startswith('{'):
@ -136,6 +137,7 @@ class SyncRestServlet(RestServlet):
filter_collection=filter, filter_collection=filter,
is_guest=requester.is_guest, is_guest=requester.is_guest,
request_key=request_key, request_key=request_key,
device_id=device_id,
) )
if since is not None: if since is not None:
@ -173,6 +175,7 @@ class SyncRestServlet(RestServlet):
response_content = { response_content = {
"account_data": {"events": sync_result.account_data}, "account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, time_now sync_result.presence, time_now
), ),

View file

@ -18,15 +18,32 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request):
yield self.auth.get_user_by_req(request)
protocols = yield self.appservice_handler.get_3pe_protocols()
defer.returnValue((200, protocols))
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server)

View file

@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
from collections import namedtuple from collections import namedtuple
@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
def _gen_state_id():
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
return s
class _StateCacheEntry(object): class _StateCacheEntry(object):
def __init__(self, state, state_group, ts): __slots__ = ["state", "state_group", "state_id"]
def __init__(self, state, state_group):
self.state = state self.state = state
self.state_group = state_group self.state_group = state_group
# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
# group each time we resolve state, so we generate a separate one that
# isn't persisted and is used solely for caches.
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a stateg_group
if state_group:
self.state_id = state_group
else:
self.state_id = _gen_state_id()
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
@ -60,6 +85,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry. # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None self._state_cache = None
self.resolve_linearizer = Linearizer()
def start_caching(self): def start_caching(self):
logger.debug("start_caching") logger.debug("start_caching")
@ -93,8 +119,32 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
res = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = res[1] state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
defer.returnValue(event)
return
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
defer.returnValue(state)
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, event_type=None, state_key="",
latest_event_ids=None):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type: if event_type:
defer.returnValue(state.get((event_type, state_key))) defer.returnValue(state.get((event_type, state_key)))
@ -102,6 +152,15 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks
def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state
)
defer.returnValue(joined_users)
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The """ Fills out the context with the `current state` of the graph. The
@ -123,54 +182,75 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
if old_state: if old_state:
context.current_state = { context.prev_state_ids = {
(s.type, s.state_key): s for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
if event.is_state():
context.current_state_events = dict(context.prev_state_ids)
key = (event.type, event.state_key)
context.current_state_events[key] = event.event_id
else: else:
context.current_state = {} context.current_state_events = context.prev_state_ids
else:
context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = None context.state_group = self.store.get_next_state_group()
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
context.current_state = { context.prev_state_ids = {
(s.type, s.state_key): s for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = None context.state_group = self.store.get_next_state_group()
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.prev_state_ids:
replaces = context.current_state[key] replaces = context.prev_state_ids[key]
if replaces.event_id != event.event_id: # Paranoia check if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
event_type=event.type, event_type=event.type,
state_key=event.state_key, state_key=event.state_key,
) )
else: else:
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
group, curr_state, prev_state = ret curr_state = entry.state
context.current_state = curr_state context.prev_state_ids = curr_state
context.state_group = group if not event.is_state() else None if event.is_state():
context.state_group = self.store.get_next_state_group()
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.prev_state_ids:
replaces = context.current_state[key] replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = prev_state context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -187,72 +267,88 @@ class StateHandler(object):
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = yield self.store.get_state_groups( state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids room_id, event_ids
) )
logger.debug( logger.debug(
"resolve_state_groups state_groups %s", "resolve_state_groups state_groups %s",
state_groups.keys() state_groups_ids.keys()
) )
group_names = frozenset(state_groups.keys()) group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups.items().pop() name, state_list = state_groups_ids.items().pop()
state = {
(e.type, e.state_key): e
for e in state_list
}
prev_state = state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue((name, state, prev_states)) defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
))
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
if cache: if cache:
cache.ts = self.clock.time_msec() defer.returnValue(cache)
event_dict = yield self.store.get_events(cache.state.values()) logger.info(
state = {(e.type, e.state_key): e for e in event_dict.values()} "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
prev_state = state.get((event_type, state_key), None) state = {}
if prev_state: for st in state_groups_ids.values():
prev_state = prev_state.event_id for key, e_id in st.items():
prev_states = [prev_state] state.setdefault(key, set()).add(e_id)
conflicted_state = {
k: list(v)
for k, v in state.items()
if len(v) > 1
}
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
else: else:
prev_states = [] new_state = {
defer.returnValue( key: e_ids.pop() for key, e_ids in state.items()
(cache.state_group, state, prev_states) }
)
logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
new_state, prev_states = self._resolve_events(
state_groups.values(), event_type, state_key
)
state_group = None state_group = None
new_state_event_ids = frozenset(e.event_id for e in new_state.values()) new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups.items(): for sg, events in state_groups_ids.items():
if new_state_event_ids == frozenset(e.event_id for e in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
if state_group is None:
# Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache
# hits.
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
if self._state_cache is not None:
cache = _StateCacheEntry( cache = _StateCacheEntry(
state={key: event.event_id for key, event in new_state.items()}, state=new_state,
state_group=state_group, state_group=state_group,
ts=self.clock.time_msec()
) )
if self._state_cache is not None:
self._state_cache[group_names] = cache self._state_cache[group_names] = cache
defer.returnValue((state_group, new_state, prev_states)) defer.returnValue(cache)
def resolve_events(self, state_sets, event): def resolve_events(self, state_sets, event):
logger.info( logger.info(

View file

@ -36,6 +36,7 @@ from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore from .rejections import RejectionsStore
from .event_push_actions import EventPushActionsStore from .event_push_actions import EventPushActionsStore
from .deviceinbox import DeviceInboxStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
OpenIdStore, OpenIdStore,
ClientIpStore, ClientIpStore,
DeviceStore, DeviceStore,
DeviceInboxStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -108,9 +110,12 @@ class DataStore(RoomMemberStore, RoomStore,
self._presence_id_gen = StreamIdGenerator( self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id" db_conn, "presence_stream", "stream_id"
) )
self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_inbox", "stream_id"
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")

View file

@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import ujson
from twisted.internet import defer
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore):
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, messages_by_user_then_device):
"""
Args:
messages_by_user_and_device(dict):
Dictionary of user_id to device_id to message.
Returns:
A deferred stream_id that resolves when the messages have been
inserted.
"""
def select_devices_txn(txn, user_id, devices):
if not devices:
return []
sql = (
"SELECT user_id, device_id FROM devices"
" WHERE user_id = ? AND device_id IN ("
+ ",".join("?" * len(devices))
+ ")"
)
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
args = [user_id] + devices
txn.execute(sql, args)
return [tuple(row) for row in txn.fetchall()]
def add_messages_to_device_inbox_txn(txn, stream_id):
local_users_and_devices = set()
for user_id, messages_by_device in messages_by_user_then_device.items():
local_users_and_devices.update(
select_devices_txn(txn, user_id, messages_by_device.keys())
)
sql = (
"INSERT INTO device_inbox"
" (user_id, device_id, stream_id, message_json)"
" VALUES (?,?,?,?)"
)
rows = []
for user_id, messages_by_device in messages_by_user_then_device.items():
for device_id, message in messages_by_device.items():
message_json = ujson.dumps(message)
# Only insert into the local inbox if the device exists on
# this server
if (user_id, device_id) in local_users_and_devices:
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_messages_to_device_inbox",
add_messages_to_device_inbox_txn,
stream_id
)
defer.returnValue(self._device_inbox_id_gen.get_current_token())
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn.fetchall():
stream_pos = row[0]
messages.append(ujson.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
sql = (
"SELECT stream_id FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY stream_id"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (last_pos, current_pos, limit))
stream_ids = txn.fetchall()
if not stream_ids:
return []
max_stream_id_in_limit = stream_ids[-1]
sql = (
"SELECT stream_id, user_id, device_id, message_json"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
)
txn.execute(sql, (last_pos, max_stream_id_in_limit))
return txn.fetchall()
return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()

View file

@ -271,22 +271,11 @@ class EventsStore(SQLBaseStore):
len(events_and_contexts) len(events_and_contexts)
) )
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
with state_group_id_manager as state_group_ids: for (event, context), stream, in zip(
for (event, context), stream, state_group_id in zip( events_and_contexts, stream_orderings
events_and_contexts, stream_orderings, state_group_ids
): ):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
# Assign a state group_id in case a new id is needed for
# this context. In theory we only need to assign this
# for contexts that have current_state and aren't outliers
# but that make the code more complicated. Assigning an ID
# per event only causes the state_group_ids to grow as fast
# as the stream_ordering so in practise shouldn't be a problem.
context.new_state_group_id = state_group_id
chunks = [ chunks = [
events_and_contexts[x:x + 100] events_and_contexts[x:x + 100]
@ -312,9 +301,7 @@ class EventsStore(SQLBaseStore):
delete_existing=False): delete_existing=False):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
event.internal_metadata.stream_ordering = stream_ordering event.internal_metadata.stream_ordering = stream_ordering
context.new_state_group_id = state_group_id
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
@ -393,7 +380,6 @@ class EventsStore(SQLBaseStore):
txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point # Add an entry to the current_state_resets table to record the point
# where we clobbered the current state # where we clobbered the current state
@ -529,7 +515,7 @@ class EventsStore(SQLBaseStore):
# Add an entry to the ex_outlier_stream table to replicate the # Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers. # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id state_group_id = context.state_group
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="ex_outlier_stream", table="ex_outlier_stream",

View file

@ -16,7 +16,6 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes, Membership
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -124,7 +123,8 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -132,11 +132,13 @@ class PushRuleStore(SQLBaseStore):
# To do this we set the state_group to a new object as object() != object() # To do this we set the state_group to a new object as object() != object()
state_group = object() state_group = object()
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) return self._bulk_get_push_rules_for_room(
event.room_id, state_group, context.current_state_ids, event=event
)
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
cache_context): cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
@ -147,12 +149,15 @@ class PushRuleStore(SQLBaseStore):
# their unread countss are correct in the event stream, but to avoid # their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've # generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room. # sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values() users_in_room = yield self._get_joined_users_from_context(
if e.type == EventTypes.Member and e.membership == Membership.JOIN room_id, state_group, current_state_ids,
and self.hs.is_mine_id(e.state_key) on_invalidate=cache_context.invalidate,
event=event,
) )
local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u))
# users in the room who have pushers need to get push rules run because # users in the room who have pushers need to get push rules run because
# that's how their pushers work # that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers( if_users_with_pushers = yield self.get_if_users_have_pushers(

View file

@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) @cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.

View file

@ -20,7 +20,7 @@ from collections import namedtuple
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import Membership from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
import logging import logging
@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
for event in events: for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after( txn.call_after(
self._membership_stream_cache.entity_has_changed, self._membership_stream_cache.entity_has_changed,
@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
return results return results
@cachedInlineCallbacks(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
user_ids = yield self.get_users_in_room(room_id)
defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?" where_clause = "c.room_id = ?"
where_values = [room_id] where_values = [room_id]
@ -325,7 +319,8 @@ class RoomMemberStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=3) @cachedInlineCallbacks(num_args=3)
def was_forgotten_at(self, user_id, room_id, event_id): def was_forgotten_at(self, user_id, room_id, event_id):
"""Returns whether user_id has elected to discard history for room_id at event_id. """Returns whether user_id has elected to discard history for room_id at
event_id.
event_id must be a membership event.""" event_id must be a membership event."""
def f(txn): def f(txn):
@ -358,3 +353,98 @@ class RoomMemberStore(SQLBaseStore):
}, },
desc="who_forgot" desc="who_forgot"
) )
def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
event.room_id, state_group, context.current_state_ids, event=event,
)
def get_joined_users_from_state(self, room_id, state_group, state_ids):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_ids,
)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
member_event_ids = [
e_id
for key, e_id in current_state_ids.iteritems()
if key[0] == EventTypes.Member
]
rows = yield self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
retcols=['user_id'],
keyvalues={
"membership": Membership.JOIN,
},
batch_size=1000,
desc="_get_joined_users_from_context",
)
users_in_room = set(row["user_id"] for row in rows)
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
users_in_room.add(event.state_key)
defer.returnValue(users_in_room)
def is_host_joined(self, room_id, host, state_group, state_ids):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._is_host_joined(
room_id, host, state_group, state_ids
)
@cachedInlineCallbacks(num_args=3)
def _is_host_joined(self, room_id, host, state_group, current_state_ids):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
for (etype, state_key), event_id in current_state_ids.items():
if etype == EventTypes.Member:
try:
if get_domain_from_id(state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", state_key)
continue
event = yield self.get_event(event_id, allow_none=True)
if event and event.content["membership"] == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)

View file

@ -0,0 +1,24 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE device_inbox (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
stream_id BIGINT NOT NULL,
message_json TEXT NOT NULL -- {"type":, "sender":, "content",}
);
CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id);
CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id);

View file

@ -0,0 +1,32 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
cur.execute("TRUNCATE sent_transactions")
else:
cur.execute("DELETE FROM sent_transactions")
cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View file

@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
The return value is a dict mapping group names to lists of events.
"""
if not event_ids: if not event_ids:
defer.returnValue({}) defer.returnValue({})
@ -59,36 +55,64 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups) group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
The return value is a dict mapping group names to lists of events.
"""
if not event_ids:
defer.returnValue({})
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.get_events(
[
ev_id for group_ids in group_to_ids.values()
for ev_id in group_ids.values()
],
get_prev_content=False
)
defer.returnValue({ defer.returnValue({
group: state_map.values() group: [
for group, state_map in group_to_state.items() state_event_map[v] for v in event_id_map.values() if v in state_event_map
]
for group, event_id_map in group_to_ids.items()
}) })
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts): def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {} state_groups = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
continue continue
if context.current_state is None: if context.current_state_ids is None:
continue continue
if context.state_group is not None:
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
logger.info("Already persisted state_group: %r", context.state_group)
continue continue
state_events = dict(context.current_state) state_event_ids = dict(context.current_state_ids)
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.new_state_group_id
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={
"id": state_group, "id": context.state_group,
"room_id": event.room_id, "room_id": event.room_id,
"event_id": event.event_id, "event_id": event.event_id,
}, },
@ -99,16 +123,15 @@ class StateStore(SQLBaseStore):
table="state_groups_state", table="state_groups_state",
values=[ values=[
{ {
"state_group": state_group, "state_group": context.state_group,
"room_id": state.room_id, "room_id": event.room_id,
"type": state.type, "type": key[0],
"state_key": state.state_key, "state_key": key[1],
"event_id": state.event_id, "event_id": state_id,
} }
for state in state_events.values() for key, state_id in state_event_ids.items()
], ],
) )
state_groups[event.event_id] = state_group
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -248,6 +271,31 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in group_to_state[group].items()
if v in state_event_map
}
for event_id, group in event_to_groups.items()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types):
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = { event_to_state = {
event_id: group_to_state[group] event_id: group_to_state[group]
for event_id, group in event_to_groups.items() for event_id, group in event_to_groups.items()
@ -272,6 +320,23 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, max_entries=10000) @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
@ -428,20 +493,13 @@ class StateStore(SQLBaseStore):
full=(types is None), full=(types is None),
) )
state_events = yield self._get_events(
[ev_id for sd in results.values() for ev_id in sd.values()],
get_prev_content=False
)
state_events = {e.event_id: e for e in state_events}
# Remove all the entries with None values. The None values were just # Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache. # used for bookkeeping in the cache.
for group, state_dict in results.items(): for group, state_dict in results.items():
results[group] = { results[group] = {
key: state_events[event_id] key: event_id
for key, event_id in state_dict.items() for key, event_id in state_dict.items()
if event_id and event_id in state_events if event_id
} }
defer.returnValue(results) defer.returnValue(results)
@ -473,5 +531,5 @@ class StateStore(SQLBaseStore):
"get_all_new_state_groups", get_all_new_state_groups_txn "get_all_new_state_groups", get_all_new_state_groups_txn
) )
def get_state_stream_token(self): def get_next_state_group(self):
return self._state_groups_id_gen.get_current_token() return self._state_groups_id_gen.get_next()

View file

@ -245,7 +245,7 @@ class TransactionStore(SQLBaseStore):
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
@cached() @cached(max_entries=10000)
def get_destination_retry_timings(self, destination): def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination. """Gets the current retry timings (if any) for a given destination.
@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions(self): def _cleanup_transactions(self):
now = self._clock.time_msec() now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000 month_ago = now - 30 * 24 * 60 * 60 * 1000
six_hours_ago = now - 6 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn): def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn) return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)

View file

@ -43,6 +43,7 @@ class EventSources(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_token(self, direction='f'): def get_current_token(self, direction='f'):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -61,5 +62,6 @@ class EventSources(object):
yield self.sources["account_data"].get_current_key() yield self.sources["account_data"].get_current_key()
), ),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key,
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -154,6 +154,7 @@ class StreamToken(
"receipt_key", "receipt_key",
"account_data_key", "account_data_key",
"push_rules_key", "push_rules_key",
"to_device_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -190,6 +191,7 @@ class StreamToken(
or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):
@ -269,10 +271,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'

View file

@ -180,6 +180,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
}) })
@defer.inlineCallbacks
def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
user_ids = set(u[0] for u in user_tuples)
event_id_to_state = {}
for event_id, context in event_id_to_context.items():
state = yield store.get_events([
e_id
for key, e_id in context.current_state_ids.iteritems()
if key == (EventTypes.RoomHistoryVisibility, "")
or (key[0] == EventTypes.Member and key[1] in user_ids)
])
event_id_to_state[event_id] = state
res = yield filter_events_for_clients(
store, user_tuples, events, event_id_to_state
)
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def filter_events_for_client(store, user_id, events, is_peeking=False): def filter_events_for_client(store, user_id, events, is_peeking=False):
""" """

View file

@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
), ),
], any_order=True) ], any_order=True)
def test_online_to_online_last_active_noop(self):
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
currently_active=True,
)
new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=now,
)
state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
)
self.assertFalse(persist_and_notify)
self.assertTrue(federation_ping)
self.assertTrue(state.currently_active)
self.assertEquals(new_state.state, state.state)
self.assertEquals(new_state.status_msg, state.status_msg)
self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([
call(
now=now,
obj=user_id,
then=new_state.last_active_ts + IDLE_TIMER
),
call(
now=now,
obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
),
call(
now=now,
obj=user_id,
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
),
], any_order=True)
def test_online_to_online_last_active(self): def test_online_to_online_last_active(self):
wheel_timer = Mock() wheel_timer = Mock()
user_id = "@foo:bar" user_id = "@foo:bar"

View file

@ -62,6 +62,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event = mock_notifier.on_new_event self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[]) self.auth = Mock(spec=[])
self.state_handler = Mock()
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"test", "test",
@ -75,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
]), ]),
state_handler=self.state_handler,
handlers=None, handlers=None,
notifier=mock_notifier, notifier=mock_notifier,
resource_for_client=Mock(), resource_for_client=Mock(),
@ -113,6 +115,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
return set(member.domain for member in self.room_members) return set(member.domain for member in self.room_members)
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
def get_current_user_in_room(room_id):
return set(str(u) for u in self.room_members)
self.state_handler.get_current_user_in_room = get_current_user_in_room
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
# Some local users to test with # Some local users to test with

View file

@ -305,7 +305,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.event_id += 1 self.event_id += 1
context = EventContext(current_state=state) if state is not None:
state_ids = {
key: e.event_id for key, e in state.items()
}
else:
state_ids = None
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context.push_actions = push_actions context.push_actions = push_actions
ordering = None ordering = None

View file

@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body, {}) self.assertEquals(body, {})
@defer.inlineCallbacks @defer.inlineCallbacks
def test_events_and_state(self): def test_events(self):
get = self.get(events="-1", state="-1", timeout="0") get = self.get(events="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room( yield self.hs.get_handlers().room_creation_handler.create_room(
synapse.types.create_requester(self.user), {} synapse.types.create_requester(self.user), {}
) )
@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body["events"]["field_names"], [ self.assertEquals(body["events"]["field_names"], [
"position", "internal", "json", "state_group" "position", "internal", "json", "state_group"
]) ])
self.assertEquals(body["state_groups"]["field_names"], [
"position", "room_id", "event_id"
])
self.assertEquals(body["state_group_state"]["field_names"], [
"position", "type", "state_key", "event_id"
])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_presence(self): def test_presence(self):

View file

@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0" token = "s0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))

View file

@ -78,44 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase):
) )
)] )]
) )
@defer.inlineCallbacks
def test_room_hosts(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
{"test"},
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should still have just one host after second join from it
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.assertEquals(
{"test"},
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should now have two hosts after join from other host
yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
self.assertEquals(
{"test", "elsewhere"},
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should still have both hosts
yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
self.assertEquals(
{"test", "elsewhere"},
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)
# Should have only one host after other leaves
yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
self.assertEquals(
{"test"},
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
)

View file

@ -67,9 +67,11 @@ class StateGroupStore(object):
self._event_to_state_group = {} self._event_to_state_group = {}
self._group_to_state = {} self._group_to_state = {}
self._event_id_to_event = {}
self._next_group = 1 self._next_group = 1
def get_state_groups(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
groups = {} groups = {}
for event_id in event_ids: for event_id in event_ids:
group = self._event_to_state_group.get(event_id) group = self._event_to_state_group.get(event_id)
@ -79,22 +81,23 @@ class StateGroupStore(object):
return defer.succeed(groups) return defer.succeed(groups)
def store_state_groups(self, event, context): def store_state_groups(self, event, context):
if context.current_state is None: if context.current_state_ids is None:
return return
state_events = context.current_state state_events = dict(context.current_state_ids)
if event.is_state(): self._group_to_state[context.state_group] = state_events
state_events[(event.type, event.state_key)] = event self._event_to_state_group[event.event_id] = context.state_group
state_group = context.state_group def get_events(self, event_ids, **kwargs):
if not state_group: return {
state_group = self._next_group e_id: self._event_id_to_event[e_id] for e_id in event_ids
self._next_group += 1 if e_id in self._event_id_to_event
}
self._group_to_state[state_group] = state_events.values() def register_events(self, events):
for e in events:
self._event_to_state_group[event.event_id] = state_group self._event_id_to_event[e.event_id] = e
class DictObj(dict): class DictObj(dict):
@ -136,8 +139,10 @@ class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = Mock( self.store = Mock(
spec_set=[ spec_set=[
"get_state_groups", "get_state_groups_ids",
"add_event_hashes", "add_event_hashes",
"get_events",
"get_next_state_group",
] ]
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
@ -148,6 +153,8 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
self.store.get_next_state_group.side_effect = Mock
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0
@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase):
) )
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
context_store = {} context_store = {}
@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context) store.store_state_groups(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state)) self.assertEqual(2, len(context_store["D"].prev_state_ids))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_basic_conflict(self): def test_branch_basic_conflict(self):
@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase):
) )
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {"START", "A", "C"},
{e.event_id for e in context_store["D"].current_state.values()} {e_id for e_id in context_store["D"].prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase):
) )
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {"START", "A", "B", "C"},
{e.event_id for e in context_store["E"].current_state.values()} {e for e in context_store["E"].prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase):
graph = Graph(nodes, edges) graph = Graph(nodes, edges)
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {"A1", "A2", "A3", "A5", "B"},
{e.event_id for e in context_store["D"].current_state.values()} {e for e in context_store["D"].prev_state_ids.values()}
) )
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state event, old_state=old_state
) )
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set(old_state), set(context.current_state.values()) set(e.event_id for e in old_state), set(context.current_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state event, old_state=old_state
) )
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set(old_state), set(e.event_id for e in old_state), set(context.prev_state_ids.values())
set(context.current_state.values())
) )
self.assertIsNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1" group_name = "group_name_1"
self.store.get_state_groups.return_value = { self.store.get_state_groups_ids.return_value = {
group_name: old_state, group_name: {(e.type, e.state_key): e.event_id for e in old_state},
} }
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in context.current_state.values()]) set(context.current_state_ids.values())
) )
self.assertEqual(group_name, context.state_group) self.assertEqual(group_name, context.state_group)
@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1" group_name = "group_name_1"
self.store.get_state_groups.return_value = { self.store.get_state_groups_ids.return_value = {
group_name: old_state, group_name: {(e.type, e.state_key): e.event_id for e in old_state},
} }
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in context.current_state.values()]) set(context.prev_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_standard_depth_conflict(self): def test_standard_depth_conflict(self):
@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=2), create_event(type="test1", state_key="1", depth=2),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) self.assertEqual(
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths # Reverse the depth to make sure we are actually using the depths
# during state resolution. # during state resolution.
@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=1), create_event(type="test1", state_key="1", depth=1),
] ]
store.register_events(old_state_1)
store.register_events(old_state_2)
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) self.assertEqual(
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
)
def _get_context(self, event, old_state_1, old_state_2): def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1" group_name_1 = "group_name_1"
group_name_2 = "group_name_2" group_name_2 = "group_name_2"
self.store.get_state_groups.return_value = { self.store.get_state_groups_ids.return_value = {
group_name_1: old_state_1, group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
group_name_2: old_state_2, group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
} }
return self.state.compute_event_context(event) return self.state.compute_event_context(event)