forked from MirrorHub/synapse
Merge branch 'release-v0.17.2' of github.com:matrix-org/synapse
This commit is contained in:
commit
5834c6178c
60 changed files with 1842 additions and 679 deletions
37
CHANGES.rst
37
CHANGES.rst
|
@ -1,3 +1,40 @@
|
||||||
|
Changes in synapse v0.17.2 (2016-09-08)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
This release contains security bug fixes. Please upgrade.
|
||||||
|
|
||||||
|
|
||||||
|
No changes since v0.17.2
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.17.2-rc1 (2016-09-05)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
|
* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050,
|
||||||
|
#1062, #1066)
|
||||||
|
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
|
||||||
|
* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063,
|
||||||
|
#1068)
|
||||||
|
* Don't notify for online to online presence transitions. (PR #1054)
|
||||||
|
* Occasionally persist unpersisted presence updates (PR #1055)
|
||||||
|
* Allow application services to have an optional 'url' (PR #1056)
|
||||||
|
* Clean up old sent transactions from DB (PR #1059)
|
||||||
|
|
||||||
|
|
||||||
|
Bug fixes:
|
||||||
|
|
||||||
|
* Fix None check in backfill (PR #1043)
|
||||||
|
* Fix membership changes to be idempotent (PR #1067)
|
||||||
|
* Fix bug in get_pdu where it would sometimes return events with incorrect
|
||||||
|
signature
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.17.1 (2016-08-24)
|
Changes in synapse v0.17.1 (2016-08-24)
|
||||||
=======================================
|
=======================================
|
||||||
|
|
||||||
|
|
24
README.rst
24
README.rst
|
@ -134,6 +134,12 @@ Installing prerequisites on Raspbian::
|
||||||
sudo pip install --upgrade ndg-httpsclient
|
sudo pip install --upgrade ndg-httpsclient
|
||||||
sudo pip install --upgrade virtualenv
|
sudo pip install --upgrade virtualenv
|
||||||
|
|
||||||
|
Installing prerequisites on openSUSE::
|
||||||
|
|
||||||
|
sudo zypper in -t pattern devel_basis
|
||||||
|
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
|
||||||
|
python-devel libffi-devel libopenssl-devel libjpeg62-devel
|
||||||
|
|
||||||
To install the synapse homeserver run::
|
To install the synapse homeserver run::
|
||||||
|
|
||||||
virtualenv -p python2.7 ~/.synapse
|
virtualenv -p python2.7 ~/.synapse
|
||||||
|
@ -199,6 +205,21 @@ run (e.g. ``~/.synapse``), and::
|
||||||
source ./bin/activate
|
source ./bin/activate
|
||||||
synctl start
|
synctl start
|
||||||
|
|
||||||
|
Security Note
|
||||||
|
=============
|
||||||
|
|
||||||
|
Matrix serves raw user generated data in some APIs - specifically the content
|
||||||
|
repository endpoints: http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
|
||||||
|
Whilst we have tried to mitigate against possible XSS attacks (e.g.
|
||||||
|
https://github.com/matrix-org/synapse/pull/1021) we recommend running
|
||||||
|
matrix homeservers on a dedicated domain name, to limit any malicious user generated
|
||||||
|
content served to web browsers a matrix API from being able to attack webapps hosted
|
||||||
|
on the same domain. This is particularly true of sharing a matrix webclient and
|
||||||
|
server on the same domain.
|
||||||
|
|
||||||
|
See https://github.com/vector-im/vector-web/issues/1977 and
|
||||||
|
https://developer.github.com/changes/2014-04-25-user-content-security for more details.
|
||||||
|
|
||||||
Using PostgreSQL
|
Using PostgreSQL
|
||||||
================
|
================
|
||||||
|
|
||||||
|
@ -215,9 +236,6 @@ The advantages of Postgres include:
|
||||||
pointing at the same DB master, as well as enabling DB replication in
|
pointing at the same DB master, as well as enabling DB replication in
|
||||||
synapse itself.
|
synapse itself.
|
||||||
|
|
||||||
The only disadvantage is that the code is relatively new as of April 2015 and
|
|
||||||
may have a few regressions relative to SQLite.
|
|
||||||
|
|
||||||
For information on how to install and use PostgreSQL, please see
|
For information on how to install and use PostgreSQL, please see
|
||||||
`docs/postgres.rst <docs/postgres.rst>`_.
|
`docs/postgres.rst <docs/postgres.rst>`_.
|
||||||
|
|
||||||
|
|
|
@ -16,4 +16,4 @@
|
||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.17.1"
|
__version__ = "0.17.2"
|
||||||
|
|
|
@ -52,7 +52,7 @@ class Auth(object):
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||||
# Docs for these currently lives at
|
# Docs for these currently lives at
|
||||||
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
||||||
# In addition, we have type == delete_pusher which grants access only to
|
# In addition, we have type == delete_pusher which grants access only to
|
||||||
# delete pushers.
|
# delete pushers.
|
||||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||||
|
@ -63,6 +63,17 @@ class Auth(object):
|
||||||
"user_id = ",
|
"user_id = ",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_from_context(self, event, context, do_sig_check=True):
|
||||||
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
|
event, context.prev_state_ids, for_verification=True,
|
||||||
|
)
|
||||||
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
|
auth_events = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events.values()
|
||||||
|
}
|
||||||
|
self.check(event, auth_events=auth_events, do_sig_check=False)
|
||||||
|
|
||||||
def check(self, event, auth_events, do_sig_check=True):
|
def check(self, event, auth_events, do_sig_check=True):
|
||||||
""" Checks if this event is correctly authed.
|
""" Checks if this event is correctly authed.
|
||||||
|
|
||||||
|
@ -267,21 +278,17 @@ class Auth(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_host_in_room(self, room_id, host):
|
def check_host_in_room(self, room_id, host):
|
||||||
curr_state = yield self.state.get_current_state(room_id)
|
with Measure(self.clock, "check_host_in_room"):
|
||||||
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
for event in curr_state.values():
|
entry = yield self.state.resolve_state_groups(
|
||||||
if event.type == EventTypes.Member:
|
room_id, latest_event_ids
|
||||||
try:
|
)
|
||||||
if get_domain_from_id(event.state_key) != host:
|
|
||||||
continue
|
|
||||||
except:
|
|
||||||
logger.warn("state_key not user_id: %s", event.state_key)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if event.content["membership"] == Membership.JOIN:
|
ret = yield self.store.is_host_joined(
|
||||||
defer.returnValue(True)
|
room_id, host, entry.state_group, entry.state
|
||||||
|
)
|
||||||
defer.returnValue(False)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def check_event_sender_in_room(self, event, auth_events):
|
def check_event_sender_in_room(self, event, auth_events):
|
||||||
key = (EventTypes.Member, event.user_id, )
|
key = (EventTypes.Member, event.user_id, )
|
||||||
|
@ -847,7 +854,7 @@ class Auth(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_auth_events(self, builder, context):
|
def add_auth_events(self, builder, context):
|
||||||
auth_ids = self.compute_auth_events(builder, context.current_state)
|
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
|
||||||
|
|
||||||
auth_events_entries = yield self.store.add_event_hashes(
|
auth_events_entries = yield self.store.add_event_hashes(
|
||||||
auth_ids
|
auth_ids
|
||||||
|
@ -855,30 +862,32 @@ class Auth(object):
|
||||||
|
|
||||||
builder.auth_events = auth_events_entries
|
builder.auth_events = auth_events_entries
|
||||||
|
|
||||||
def compute_auth_events(self, event, current_state):
|
@defer.inlineCallbacks
|
||||||
|
def compute_auth_events(self, event, current_state_ids, for_verification=False):
|
||||||
if event.type == EventTypes.Create:
|
if event.type == EventTypes.Create:
|
||||||
return []
|
defer.returnValue([])
|
||||||
|
|
||||||
auth_ids = []
|
auth_ids = []
|
||||||
|
|
||||||
key = (EventTypes.PowerLevels, "", )
|
key = (EventTypes.PowerLevels, "", )
|
||||||
power_level_event = current_state.get(key)
|
power_level_event_id = current_state_ids.get(key)
|
||||||
|
|
||||||
if power_level_event:
|
if power_level_event_id:
|
||||||
auth_ids.append(power_level_event.event_id)
|
auth_ids.append(power_level_event_id)
|
||||||
|
|
||||||
key = (EventTypes.JoinRules, "", )
|
key = (EventTypes.JoinRules, "", )
|
||||||
join_rule_event = current_state.get(key)
|
join_rule_event_id = current_state_ids.get(key)
|
||||||
|
|
||||||
key = (EventTypes.Member, event.user_id, )
|
key = (EventTypes.Member, event.user_id, )
|
||||||
member_event = current_state.get(key)
|
member_event_id = current_state_ids.get(key)
|
||||||
|
|
||||||
key = (EventTypes.Create, "", )
|
key = (EventTypes.Create, "", )
|
||||||
create_event = current_state.get(key)
|
create_event_id = current_state_ids.get(key)
|
||||||
if create_event:
|
if create_event_id:
|
||||||
auth_ids.append(create_event.event_id)
|
auth_ids.append(create_event_id)
|
||||||
|
|
||||||
if join_rule_event:
|
if join_rule_event_id:
|
||||||
|
join_rule_event = yield self.store.get_event(join_rule_event_id)
|
||||||
join_rule = join_rule_event.content.get("join_rule")
|
join_rule = join_rule_event.content.get("join_rule")
|
||||||
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
||||||
else:
|
else:
|
||||||
|
@ -887,15 +896,21 @@ class Auth(object):
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
e_type = event.content["membership"]
|
e_type = event.content["membership"]
|
||||||
if e_type in [Membership.JOIN, Membership.INVITE]:
|
if e_type in [Membership.JOIN, Membership.INVITE]:
|
||||||
if join_rule_event:
|
if join_rule_event_id:
|
||||||
auth_ids.append(join_rule_event.event_id)
|
auth_ids.append(join_rule_event_id)
|
||||||
|
|
||||||
if e_type == Membership.JOIN:
|
if e_type == Membership.JOIN:
|
||||||
if member_event and not is_public:
|
if member_event_id and not is_public:
|
||||||
auth_ids.append(member_event.event_id)
|
auth_ids.append(member_event_id)
|
||||||
else:
|
else:
|
||||||
if member_event:
|
if member_event_id:
|
||||||
auth_ids.append(member_event.event_id)
|
auth_ids.append(member_event_id)
|
||||||
|
|
||||||
|
if for_verification:
|
||||||
|
key = (EventTypes.Member, event.state_key, )
|
||||||
|
existing_event_id = current_state_ids.get(key)
|
||||||
|
if existing_event_id:
|
||||||
|
auth_ids.append(existing_event_id)
|
||||||
|
|
||||||
if e_type == Membership.INVITE:
|
if e_type == Membership.INVITE:
|
||||||
if "third_party_invite" in event.content:
|
if "third_party_invite" in event.content:
|
||||||
|
@ -903,14 +918,15 @@ class Auth(object):
|
||||||
EventTypes.ThirdPartyInvite,
|
EventTypes.ThirdPartyInvite,
|
||||||
event.content["third_party_invite"]["signed"]["token"]
|
event.content["third_party_invite"]["signed"]["token"]
|
||||||
)
|
)
|
||||||
third_party_invite = current_state.get(key)
|
third_party_invite_id = current_state_ids.get(key)
|
||||||
if third_party_invite:
|
if third_party_invite_id:
|
||||||
auth_ids.append(third_party_invite.event_id)
|
auth_ids.append(third_party_invite_id)
|
||||||
elif member_event:
|
elif member_event_id:
|
||||||
|
member_event = yield self.store.get_event(member_event_id)
|
||||||
if member_event.content["membership"] == Membership.JOIN:
|
if member_event.content["membership"] == Membership.JOIN:
|
||||||
auth_ids.append(member_event.event_id)
|
auth_ids.append(member_event.event_id)
|
||||||
|
|
||||||
return auth_ids
|
defer.returnValue(auth_ids)
|
||||||
|
|
||||||
def _get_send_level(self, etype, state_key, auth_events):
|
def _get_send_level(self, etype, state_key, auth_events):
|
||||||
key = (EventTypes.PowerLevels, "", )
|
key = (EventTypes.PowerLevels, "", )
|
||||||
|
|
|
@ -85,3 +85,8 @@ class RoomCreationPreset(object):
|
||||||
PRIVATE_CHAT = "private_chat"
|
PRIVATE_CHAT = "private_chat"
|
||||||
PUBLIC_CHAT = "public_chat"
|
PUBLIC_CHAT = "public_chat"
|
||||||
TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
|
TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
|
||||||
|
|
||||||
|
|
||||||
|
class ThirdPartyEntityKind(object):
|
||||||
|
USER = "user"
|
||||||
|
LOCATION = "location"
|
||||||
|
|
|
@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
|
||||||
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
|
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
|
||||||
MEDIA_PREFIX = "/_matrix/media/r0"
|
MEDIA_PREFIX = "/_matrix/media/r0"
|
||||||
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
|
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
|
||||||
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
|
||||||
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
|
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
|
||||||
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||||
from synapse.replication.slave.storage.presence import SlavedPresenceStore
|
from synapse.replication.slave.storage.presence import SlavedPresenceStore
|
||||||
|
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.client_ips import ClientIpStore
|
from synapse.storage.client_ips import ClientIpStore
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
|
@ -72,6 +73,7 @@ class SynchrotronSlavedStore(
|
||||||
SlavedRegistrationStore,
|
SlavedRegistrationStore,
|
||||||
SlavedFilteringStore,
|
SlavedFilteringStore,
|
||||||
SlavedPresenceStore,
|
SlavedPresenceStore,
|
||||||
|
SlavedDeviceInboxStore,
|
||||||
BaseSlavedStore,
|
BaseSlavedStore,
|
||||||
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
ClientIpStore, # After BaseSlavedStore because the constructor is different
|
||||||
):
|
):
|
||||||
|
@ -397,6 +399,9 @@ class SynchrotronServer(HomeServer):
|
||||||
notify_from_stream(
|
notify_from_stream(
|
||||||
result, "typing", "typing_key", room="room_id"
|
result, "typing", "typing_key", room="room_id"
|
||||||
)
|
)
|
||||||
|
notify_from_stream(
|
||||||
|
result, "to_device", "to_device_key", user="user_id"
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -88,6 +88,8 @@ class ApplicationService(object):
|
||||||
self.sender = sender
|
self.sender = sender
|
||||||
self.namespaces = self._check_namespaces(namespaces)
|
self.namespaces = self._check_namespaces(namespaces)
|
||||||
self.id = id
|
self.id = id
|
||||||
|
|
||||||
|
# .protocols is a publicly visible field
|
||||||
if protocols:
|
if protocols:
|
||||||
self.protocols = set(protocols)
|
self.protocols = set(protocols)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -14,10 +14,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import ThirdPartyEntityKind
|
||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.types import ThirdPartyEntityKind
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
|
@ -25,6 +26,12 @@ import urllib
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
HOUR_IN_MS = 60 * 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
|
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_3pe_result(r, field):
|
def _is_valid_3pe_result(r, field):
|
||||||
if not isinstance(r, dict):
|
if not isinstance(r, dict):
|
||||||
return False
|
return False
|
||||||
|
@ -56,8 +63,12 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
super(ApplicationServiceApi, self).__init__(hs)
|
super(ApplicationServiceApi, self).__init__(hs)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_user(self, service, user_id):
|
def query_user(self, service, user_id):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue(False)
|
||||||
uri = service.url + ("/users/%s" % urllib.quote(user_id))
|
uri = service.url + ("/users/%s" % urllib.quote(user_id))
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
@ -77,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_alias(self, service, alias):
|
def query_alias(self, service, alias):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue(False)
|
||||||
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
|
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
@ -97,16 +110,22 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_3pe(self, service, kind, protocol, fields):
|
def query_3pe(self, service, kind, protocol, fields):
|
||||||
if kind == ThirdPartyEntityKind.USER:
|
if kind == ThirdPartyEntityKind.USER:
|
||||||
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
|
||||||
required_field = "userid"
|
required_field = "userid"
|
||||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||||
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
|
||||||
required_field = "alias"
|
required_field = "alias"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||||
)
|
)
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
|
uri = "%s%s/thirdparty/%s/%s" % (
|
||||||
|
service.url,
|
||||||
|
APP_SERVICE_PREFIX,
|
||||||
|
kind,
|
||||||
|
urllib.quote(protocol)
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response = yield self.get_json(uri, fields)
|
response = yield self.get_json(uri, fields)
|
||||||
if not isinstance(response, list):
|
if not isinstance(response, list):
|
||||||
|
@ -131,8 +150,34 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||||
defer.returnValue([])
|
defer.returnValue([])
|
||||||
|
|
||||||
|
def get_3pe_protocol(self, service, protocol):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get():
|
||||||
|
uri = "%s%s/thirdparty/protocol/%s" % (
|
||||||
|
service.url,
|
||||||
|
APP_SERVICE_PREFIX,
|
||||||
|
urllib.quote(protocol)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
defer.returnValue((yield self.get_json(uri, {})))
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning("query_3pe_protocol to %s threw exception %s",
|
||||||
|
uri, ex)
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
key = (service.id, protocol)
|
||||||
|
return self.protocol_meta_cache.get(key) or (
|
||||||
|
self.protocol_meta_cache.set(key, _get())
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def push_bulk(self, service, events, txn_id=None):
|
def push_bulk(self, service, events, txn_id=None):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
events = self._serialize(events)
|
events = self._serialize(events)
|
||||||
|
|
||||||
if txn_id is None:
|
if txn_id is None:
|
||||||
|
|
|
@ -86,7 +86,7 @@ def load_appservices(hostname, config_files):
|
||||||
|
|
||||||
def _load_appservice(hostname, as_info, config_filename):
|
def _load_appservice(hostname, as_info, config_filename):
|
||||||
required_string_fields = [
|
required_string_fields = [
|
||||||
"id", "url", "as_token", "hs_token", "sender_localpart"
|
"id", "as_token", "hs_token", "sender_localpart"
|
||||||
]
|
]
|
||||||
for field in required_string_fields:
|
for field in required_string_fields:
|
||||||
if not isinstance(as_info.get(field), basestring):
|
if not isinstance(as_info.get(field), basestring):
|
||||||
|
@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename):
|
||||||
field, config_filename,
|
field, config_filename,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# 'url' must either be a string or explicitly null, not missing
|
||||||
|
# to avoid accidentally turning off push for ASes.
|
||||||
|
if (not isinstance(as_info.get("url"), basestring) and
|
||||||
|
as_info.get("url", "") is not None):
|
||||||
|
raise KeyError(
|
||||||
|
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
|
||||||
|
)
|
||||||
|
|
||||||
localpart = as_info["sender_localpart"]
|
localpart = as_info["sender_localpart"]
|
||||||
if urllib.quote(localpart) != localpart:
|
if urllib.quote(localpart) != localpart:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename):
|
||||||
for p in protocols:
|
for p in protocols:
|
||||||
if not isinstance(p, str):
|
if not isinstance(p, str):
|
||||||
raise KeyError("Bad value for 'protocols' item")
|
raise KeyError("Bad value for 'protocols' item")
|
||||||
|
|
||||||
|
if as_info["url"] is None:
|
||||||
|
logger.info(
|
||||||
|
"(%s) Explicitly empty 'url' provided. This application service"
|
||||||
|
" will not receive events or queries.",
|
||||||
|
config_filename,
|
||||||
|
)
|
||||||
return ApplicationService(
|
return ApplicationService(
|
||||||
token=as_info["as_token"],
|
token=as_info["as_token"],
|
||||||
url=as_info["url"],
|
url=as_info["url"],
|
||||||
|
|
|
@ -99,7 +99,7 @@ class EventBase(object):
|
||||||
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def get(self, key, default):
|
def get(self, key, default=None):
|
||||||
return self._event_dict.get(key, default)
|
return self._event_dict.get(key, default)
|
||||||
|
|
||||||
def get_internal_metadata_dict(self):
|
def get_internal_metadata_dict(self):
|
||||||
|
|
|
@ -15,9 +15,9 @@
|
||||||
|
|
||||||
|
|
||||||
class EventContext(object):
|
class EventContext(object):
|
||||||
|
def __init__(self):
|
||||||
def __init__(self, current_state=None):
|
self.current_state_ids = None
|
||||||
self.current_state = current_state
|
self.prev_state_ids = None
|
||||||
self.state_group = None
|
self.state_group = None
|
||||||
self.rejected = False
|
self.rejected = False
|
||||||
self.push_actions = []
|
self.push_actions = []
|
||||||
|
|
|
@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||||
|
@ -63,6 +64,7 @@ class FederationClient(FederationBase):
|
||||||
self._clock.looping_call(
|
self._clock.looping_call(
|
||||||
self._clear_tried_cache, 60 * 1000,
|
self._clear_tried_cache, 60 * 1000,
|
||||||
)
|
)
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
def _clear_tried_cache(self):
|
def _clear_tried_cache(self):
|
||||||
"""Clear pdu_destination_tried cache"""
|
"""Clear pdu_destination_tried cache"""
|
||||||
|
@ -267,7 +269,7 @@ class FederationClient(FederationBase):
|
||||||
|
|
||||||
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
|
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
|
||||||
|
|
||||||
pdu = None
|
signed_pdu = None
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
last_attempt = pdu_attempts.get(destination, 0)
|
last_attempt = pdu_attempts.get(destination, 0)
|
||||||
|
@ -297,7 +299,7 @@ class FederationClient(FederationBase):
|
||||||
pdu = pdu_list[0]
|
pdu = pdu_list[0]
|
||||||
|
|
||||||
# Check signatures are correct.
|
# Check signatures are correct.
|
||||||
pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -320,10 +322,10 @@ class FederationClient(FederationBase):
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self._get_pdu_cache is not None and pdu:
|
if self._get_pdu_cache is not None and signed_pdu:
|
||||||
self._get_pdu_cache[event_id] = pdu
|
self._get_pdu_cache[event_id] = signed_pdu
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
defer.returnValue(signed_pdu)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -811,7 +813,8 @@ class FederationClient(FederationBase):
|
||||||
if len(signed_events) >= limit:
|
if len(signed_events) >= limit:
|
||||||
defer.returnValue(signed_events)
|
defer.returnValue(signed_events)
|
||||||
|
|
||||||
servers = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
servers = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
servers = set(servers)
|
servers = set(servers)
|
||||||
servers.discard(self.server_name)
|
servers.discard(self.server_name)
|
||||||
|
|
|
@ -223,16 +223,14 @@ class FederationServer(FederationBase):
|
||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
pdus = yield self.handler.get_state_for_pdu(
|
state_ids = yield self.handler.get_state_ids_for_pdu(
|
||||||
room_id, event_id,
|
room_id, event_id,
|
||||||
)
|
)
|
||||||
auth_chain = yield self.store.get_auth_chain(
|
auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
|
||||||
[pdu.event_id for pdu in pdus]
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"pdu_ids": [pdu.event_id for pdu in pdus],
|
"pdu_ids": state_ids,
|
||||||
"auth_chain_ids": [pdu.event_id for pdu in auth_chain],
|
"auth_chain_ids": auth_chain_ids,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -65,33 +65,21 @@ class BaseHandler(object):
|
||||||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_host_in_room(self, current_state):
|
|
||||||
room_members = [
|
|
||||||
(state_key, event.membership)
|
|
||||||
for ((event_type, state_key), event) in current_state.items()
|
|
||||||
if event_type == EventTypes.Member
|
|
||||||
]
|
|
||||||
if len(room_members) == 0:
|
|
||||||
# Have we just created the room, and is this about to be the very
|
|
||||||
# first member event?
|
|
||||||
create_event = current_state.get(("m.room.create", ""))
|
|
||||||
if create_event:
|
|
||||||
return True
|
|
||||||
for (state_key, membership) in room_members:
|
|
||||||
if (
|
|
||||||
self.hs.is_mine_id(state_key)
|
|
||||||
and membership == Membership.JOIN
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def maybe_kick_guest_users(self, event, current_state):
|
def maybe_kick_guest_users(self, event, context=None):
|
||||||
# Technically this function invalidates current_state by changing it.
|
# Technically this function invalidates current_state by changing it.
|
||||||
# Hopefully this isn't that important to the caller.
|
# Hopefully this isn't that important to the caller.
|
||||||
if event.type == EventTypes.GuestAccess:
|
if event.type == EventTypes.GuestAccess:
|
||||||
guest_access = event.content.get("guest_access", "forbidden")
|
guest_access = event.content.get("guest_access", "forbidden")
|
||||||
if guest_access != "can_join":
|
if guest_access != "can_join":
|
||||||
|
if context:
|
||||||
|
current_state = yield self.store.get_events(
|
||||||
|
context.current_state_ids.values()
|
||||||
|
)
|
||||||
|
current_state = current_state.values()
|
||||||
|
else:
|
||||||
|
current_state = yield self.store.get_current_state(event.room_id)
|
||||||
|
logger.info("maybe_kick_guest_users %r", current_state)
|
||||||
yield self.kick_guest_users(current_state)
|
yield self.kick_guest_users(current_state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -175,6 +175,16 @@ class ApplicationServicesHandler(object):
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_3pe_protocols(self):
|
||||||
|
services = yield self.store.get_app_services()
|
||||||
|
protocols = {}
|
||||||
|
for s in services:
|
||||||
|
for p in s.protocols:
|
||||||
|
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
|
||||||
|
|
||||||
|
defer.returnValue(protocols)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_services_for_event(self, event):
|
def _get_services_for_event(self, event):
|
||||||
"""Retrieve a list of application services interested in this event.
|
"""Retrieve a list of application services interested in this event.
|
||||||
|
|
|
@ -19,7 +19,7 @@ from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
|
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomAlias, UserID
|
from synapse.types import RoomAlias, UserID, get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import string
|
import string
|
||||||
|
@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler):
|
||||||
# TODO(erikj): Add transactions.
|
# TODO(erikj): Add transactions.
|
||||||
# TODO(erikj): Check if there is a current association.
|
# TODO(erikj): Check if there is a current association.
|
||||||
if not servers:
|
if not servers:
|
||||||
servers = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
servers = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
if not servers:
|
if not servers:
|
||||||
raise SynapseError(400, "Failed to get server list")
|
raise SynapseError(400, "Failed to get server list")
|
||||||
|
@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler):
|
||||||
Codes.NOT_FOUND
|
Codes.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
extra_servers = set(get_domain_from_id(u) for u in users)
|
||||||
servers = set(extra_servers) | set(servers)
|
servers = set(extra_servers) | set(servers)
|
||||||
|
|
||||||
# If this server is in the list of servers, return it first.
|
# If this server is in the list of servers, return it first.
|
||||||
|
|
|
@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler):
|
||||||
# Send down presence.
|
# Send down presence.
|
||||||
if event.state_key == auth_user_id:
|
if event.state_key == auth_user_id:
|
||||||
# Send down presence for everyone in the room.
|
# Send down presence for everyone in the room.
|
||||||
users = yield self.store.get_users_in_room(event.room_id)
|
users = yield self.state.get_current_user_in_room(event.room_id)
|
||||||
states = yield presence_handler.get_states(
|
states = yield presence_handler.get_states(
|
||||||
users,
|
users,
|
||||||
as_event=True,
|
as_event=True,
|
||||||
|
|
|
@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
|
||||||
from synapse.util.logcontext import (
|
from synapse.util.logcontext import (
|
||||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
|
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
|
||||||
)
|
)
|
||||||
|
from synapse.util.metrics import measure_func
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
|
@ -100,6 +101,9 @@ class FederationHandler(BaseHandler):
|
||||||
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
|
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
|
||||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
||||||
do auth checks and put it through the StateHandler.
|
do auth checks and put it through the StateHandler.
|
||||||
|
|
||||||
|
auth_chain and state are None if we already have the necessary state
|
||||||
|
and prev_events in the db
|
||||||
"""
|
"""
|
||||||
event = pdu
|
event = pdu
|
||||||
|
|
||||||
|
@ -117,12 +121,21 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# FIXME (erikj): Awful hack to make the case where we are not currently
|
# FIXME (erikj): Awful hack to make the case where we are not currently
|
||||||
# in the room work
|
# in the room work
|
||||||
|
# If state and auth_chain are None, then we don't need to do this check
|
||||||
|
# as we already know we have enough state in the DB to handle this
|
||||||
|
# event.
|
||||||
|
if state and auth_chain and not event.internal_metadata.is_outlier():
|
||||||
is_in_room = yield self.auth.check_host_in_room(
|
is_in_room = yield self.auth.check_host_in_room(
|
||||||
event.room_id,
|
event.room_id,
|
||||||
self.server_name
|
self.server_name
|
||||||
)
|
)
|
||||||
if not is_in_room and not event.internal_metadata.is_outlier():
|
else:
|
||||||
logger.debug("Got event for room we're not in.")
|
is_in_room = True
|
||||||
|
if not is_in_room:
|
||||||
|
logger.info(
|
||||||
|
"Got event for room we're not in: %r %r",
|
||||||
|
event.room_id, event.event_id
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||||
|
@ -217,17 +230,28 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
prev_state = context.current_state.get((event.type, event.state_key))
|
|
||||||
if not prev_state or prev_state.membership != Membership.JOIN:
|
|
||||||
# Only fire user_joined_room if the user has acutally
|
# Only fire user_joined_room if the user has acutally
|
||||||
# joined the room. Don't bother if the user is just
|
# joined the room. Don't bother if the user is just
|
||||||
# changing their profile info.
|
# changing their profile info.
|
||||||
|
newly_joined = True
|
||||||
|
prev_state_id = context.prev_state_ids.get(
|
||||||
|
(event.type, event.state_key)
|
||||||
|
)
|
||||||
|
if prev_state_id:
|
||||||
|
prev_state = yield self.store.get_event(
|
||||||
|
prev_state_id, allow_none=True,
|
||||||
|
)
|
||||||
|
if prev_state and prev_state.membership == Membership.JOIN:
|
||||||
|
newly_joined = False
|
||||||
|
|
||||||
|
if newly_joined:
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield user_joined_room(self.distributor, user, event.room_id)
|
||||||
|
|
||||||
|
@measure_func("_filter_events_for_server")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _filter_events_for_server(self, server_name, room_id, events):
|
def _filter_events_for_server(self, server_name, room_id, events):
|
||||||
event_to_state = yield self.store.get_state_for_events(
|
event_to_state_ids = yield self.store.get_state_ids_for_events(
|
||||||
frozenset(e.event_id for e in events),
|
frozenset(e.event_id for e in events),
|
||||||
types=(
|
types=(
|
||||||
(EventTypes.RoomHistoryVisibility, ""),
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
|
@ -235,6 +259,30 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We only want to pull out member events that correspond to the
|
||||||
|
# server's domain.
|
||||||
|
|
||||||
|
def check_match(id):
|
||||||
|
try:
|
||||||
|
return server_name == get_domain_from_id(id)
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
event_map = yield self.store.get_events([
|
||||||
|
e_id for key_to_eid in event_to_state_ids.values()
|
||||||
|
for key, e_id in key_to_eid
|
||||||
|
if key[0] != EventTypes.Member or check_match(key[1])
|
||||||
|
])
|
||||||
|
|
||||||
|
event_to_state = {
|
||||||
|
e_id: {
|
||||||
|
key: event_map[inner_e_id]
|
||||||
|
for key, inner_e_id in key_to_eid.items()
|
||||||
|
if inner_e_id in event_map
|
||||||
|
}
|
||||||
|
for e_id, key_to_eid in event_to_state_ids.items()
|
||||||
|
}
|
||||||
|
|
||||||
def redact_disallowed(event, state):
|
def redact_disallowed(event, state):
|
||||||
if not state:
|
if not state:
|
||||||
return event
|
return event
|
||||||
|
@ -377,7 +425,9 @@ class FederationHandler(BaseHandler):
|
||||||
)).addErrback(unwrapFirstError)
|
)).addErrback(unwrapFirstError)
|
||||||
auth_events.update({a.event_id: a for a in results if a})
|
auth_events.update({a.event_id: a for a in results if a})
|
||||||
required_auth.update(
|
required_auth.update(
|
||||||
a_id for event in results for a_id, _ in event.auth_events if event
|
a_id
|
||||||
|
for event in results if event
|
||||||
|
for a_id, _ in event.auth_events
|
||||||
)
|
)
|
||||||
missing_auth = required_auth - set(auth_events)
|
missing_auth = required_auth - set(auth_events)
|
||||||
|
|
||||||
|
@ -560,6 +610,18 @@ class FederationHandler(BaseHandler):
|
||||||
]))
|
]))
|
||||||
states = dict(zip(event_ids, [s[1] for s in states]))
|
states = dict(zip(event_ids, [s[1] for s in states]))
|
||||||
|
|
||||||
|
state_map = yield self.store.get_events(
|
||||||
|
[e_id for ids in states.values() for e_id in ids],
|
||||||
|
get_prev_content=False
|
||||||
|
)
|
||||||
|
states = {
|
||||||
|
key: {
|
||||||
|
k: state_map[e_id]
|
||||||
|
for k, e_id in state_dict.items()
|
||||||
|
if e_id in state_map
|
||||||
|
} for key, state_dict in states.items()
|
||||||
|
}
|
||||||
|
|
||||||
for e_id, _ in sorted_extremeties_tuple:
|
for e_id, _ in sorted_extremeties_tuple:
|
||||||
likely_domains = get_domains_from_state(states[e_id])
|
likely_domains = get_domains_from_state(states[e_id])
|
||||||
|
|
||||||
|
@ -722,7 +784,7 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_join_request`
|
# when we get the event back in `on_send_join_request`
|
||||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
|
@ -770,18 +832,11 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
new_pdu = event
|
new_pdu = event
|
||||||
|
|
||||||
destinations = set()
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
|
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||||
for k, s in context.current_state.items():
|
context
|
||||||
try:
|
|
||||||
if k[0] == EventTypes.Member:
|
|
||||||
if s.content["membership"] == Membership.JOIN:
|
|
||||||
destinations.add(get_domain_from_id(s.state_key))
|
|
||||||
except:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to get destination from event %s", s.event_id
|
|
||||||
)
|
)
|
||||||
|
destinations = set(destinations)
|
||||||
destinations.discard(origin)
|
destinations.discard(origin)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -792,13 +847,15 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||||
|
|
||||||
state_ids = [e.event_id for e in context.current_state.values()]
|
state_ids = context.prev_state_ids.values()
|
||||||
auth_chain = yield self.store.get_auth_chain(set(
|
auth_chain = yield self.store.get_auth_chain(set(
|
||||||
[event.event_id] + state_ids
|
[event.event_id] + state_ids
|
||||||
))
|
))
|
||||||
|
|
||||||
|
state = yield self.store.get_events(context.prev_state_ids.values())
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"state": context.current_state.values(),
|
"state": state.values(),
|
||||||
"auth_chain": auth_chain,
|
"auth_chain": auth_chain,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -954,7 +1011,7 @@ class FederationHandler(BaseHandler):
|
||||||
try:
|
try:
|
||||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||||
# when we get the event back in `on_send_leave_request`
|
# when we get the event back in `on_send_leave_request`
|
||||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Failed to create new leave %r because %s", event, e)
|
logger.warn("Failed to create new leave %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -998,18 +1055,11 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
new_pdu = event
|
new_pdu = event
|
||||||
|
|
||||||
destinations = set()
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
|
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||||
for k, s in context.current_state.items():
|
context
|
||||||
try:
|
|
||||||
if k[0] == EventTypes.Member:
|
|
||||||
if s.content["membership"] == Membership.LEAVE:
|
|
||||||
destinations.add(get_domain_from_id(s.state_key))
|
|
||||||
except:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to get destination from event %s", s.event_id
|
|
||||||
)
|
)
|
||||||
|
destinations = set(destinations)
|
||||||
destinations.discard(origin)
|
destinations.discard(origin)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -1024,6 +1074,8 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_for_pdu(self, room_id, event_id):
|
def get_state_for_pdu(self, room_id, event_id):
|
||||||
|
"""Returns the state at the event. i.e. not including said event.
|
||||||
|
"""
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups = yield self.store.get_state_groups(
|
||||||
|
@ -1064,6 +1116,34 @@ class FederationHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
defer.returnValue([])
|
defer.returnValue([])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_ids_for_pdu(self, room_id, event_id):
|
||||||
|
"""Returns the state at the event. i.e. not including said event.
|
||||||
|
"""
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
state_groups = yield self.store.get_state_groups_ids(
|
||||||
|
room_id, [event_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
if state_groups:
|
||||||
|
_, state = state_groups.items().pop()
|
||||||
|
results = state
|
||||||
|
|
||||||
|
event = yield self.store.get_event(event_id)
|
||||||
|
if event and event.is_state():
|
||||||
|
# Get previous state
|
||||||
|
if "replaces_state" in event.unsigned:
|
||||||
|
prev_id = event.unsigned["replaces_state"]
|
||||||
|
if prev_id != event.event_id:
|
||||||
|
results[(event.type, event.state_key)] = prev_id
|
||||||
|
else:
|
||||||
|
del results[(event.type, event.state_key)]
|
||||||
|
|
||||||
|
defer.returnValue(results.values())
|
||||||
|
else:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_backfill_request(self, origin, room_id, pdu_list, limit):
|
def on_backfill_request(self, origin, room_id, pdu_list, limit):
|
||||||
|
@ -1294,7 +1374,13 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
auth_events = context.current_state
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
|
event, context.prev_state_ids, for_verification=True,
|
||||||
|
)
|
||||||
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
|
auth_events = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events.values()
|
||||||
|
}
|
||||||
|
|
||||||
# This is a hack to fix some old rooms where the initial join event
|
# This is a hack to fix some old rooms where the initial join event
|
||||||
# didn't reference the create event in its auth events.
|
# didn't reference the create event in its auth events.
|
||||||
|
@ -1320,8 +1406,7 @@ class FederationHandler(BaseHandler):
|
||||||
context.rejected = RejectedReason.AUTH_ERROR
|
context.rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
if event.type == EventTypes.GuestAccess:
|
if event.type == EventTypes.GuestAccess:
|
||||||
full_context = yield self.store.get_current_state(room_id=event.room_id)
|
yield self.maybe_kick_guest_users(event)
|
||||||
yield self.maybe_kick_guest_users(event, full_context)
|
|
||||||
|
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
|
@ -1389,6 +1474,11 @@ class FederationHandler(BaseHandler):
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||||
|
|
||||||
|
if event.is_state():
|
||||||
|
event_key = (event.type, event.state_key)
|
||||||
|
else:
|
||||||
|
event_key = None
|
||||||
|
|
||||||
if event_auth_events - current_state:
|
if event_auth_events - current_state:
|
||||||
have_events = yield self.store.have_events(
|
have_events = yield self.store.have_events(
|
||||||
event_auth_events - current_state
|
event_auth_events - current_state
|
||||||
|
@ -1492,8 +1582,14 @@ class FederationHandler(BaseHandler):
|
||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
different_auth = event_auth_events - current_state
|
different_auth = event_auth_events - current_state
|
||||||
|
|
||||||
context.current_state.update(auth_events)
|
context.current_state_ids.update({
|
||||||
context.state_group = None
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
if k != event_key
|
||||||
|
})
|
||||||
|
context.prev_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
if different_auth and not event.internal_metadata.is_outlier():
|
||||||
logger.info("Different auth after resolution: %s", different_auth)
|
logger.info("Different auth after resolution: %s", different_auth)
|
||||||
|
@ -1514,8 +1610,8 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if do_resolution:
|
if do_resolution:
|
||||||
# 1. Get what we think is the auth chain.
|
# 1. Get what we think is the auth chain.
|
||||||
auth_ids = self.auth.compute_auth_events(
|
auth_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state
|
event, context.prev_state_ids
|
||||||
)
|
)
|
||||||
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
||||||
|
|
||||||
|
@ -1571,8 +1667,14 @@ class FederationHandler(BaseHandler):
|
||||||
# 4. Look at rejects and their proofs.
|
# 4. Look at rejects and their proofs.
|
||||||
# TODO.
|
# TODO.
|
||||||
|
|
||||||
context.current_state.update(auth_events)
|
context.current_state_ids.update({
|
||||||
context.state_group = None
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
if k != event_key
|
||||||
|
})
|
||||||
|
context.prev_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=auth_events)
|
self.auth.check(event, auth_events=auth_events)
|
||||||
|
@ -1758,12 +1860,12 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, context.current_state)
|
yield self.auth.check_from_context(event, context)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Denying new third party invite %r because %s", event, e)
|
logger.warn("Denying new third party invite %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
yield self._check_signature(event, auth_events=context.current_state)
|
yield self._check_signature(event, context)
|
||||||
member_handler = self.hs.get_handlers().room_member_handler
|
member_handler = self.hs.get_handlers().room_member_handler
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
|
@ -1789,11 +1891,11 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
self.auth.check_from_context(event, context)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warn("Denying third party invite %r because %s", event, e)
|
logger.warn("Denying third party invite %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
yield self._check_signature(event, auth_events=context.current_state)
|
yield self._check_signature(event, context)
|
||||||
|
|
||||||
returned_invite = yield self.send_invite(origin, event)
|
returned_invite = yield self.send_invite(origin, event)
|
||||||
# TODO: Make sure the signatures actually are correct.
|
# TODO: Make sure the signatures actually are correct.
|
||||||
|
@ -1807,7 +1909,12 @@ class FederationHandler(BaseHandler):
|
||||||
EventTypes.ThirdPartyInvite,
|
EventTypes.ThirdPartyInvite,
|
||||||
event.content["third_party_invite"]["signed"]["token"]
|
event.content["third_party_invite"]["signed"]["token"]
|
||||||
)
|
)
|
||||||
original_invite = context.current_state.get(key)
|
original_invite = None
|
||||||
|
original_invite_id = context.prev_state_ids.get(key)
|
||||||
|
if original_invite_id:
|
||||||
|
original_invite = yield self.store.get_event(
|
||||||
|
original_invite_id, allow_none=True
|
||||||
|
)
|
||||||
if not original_invite:
|
if not original_invite:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Could not find invite event for third_party_invite - "
|
"Could not find invite event for third_party_invite - "
|
||||||
|
@ -1824,13 +1931,13 @@ class FederationHandler(BaseHandler):
|
||||||
defer.returnValue((event, context))
|
defer.returnValue((event, context))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_signature(self, event, auth_events):
|
def _check_signature(self, event, context):
|
||||||
"""
|
"""
|
||||||
Checks that the signature in the event is consistent with its invite.
|
Checks that the signature in the event is consistent with its invite.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (Event): The m.room.member event to check
|
event (Event): The m.room.member event to check
|
||||||
auth_events (dict<(event type, state_key), event>):
|
context (EventContext):
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError: if signature didn't match any keys, or key has been
|
AuthError: if signature didn't match any keys, or key has been
|
||||||
|
@ -1841,10 +1948,14 @@ class FederationHandler(BaseHandler):
|
||||||
signed = event.content["third_party_invite"]["signed"]
|
signed = event.content["third_party_invite"]["signed"]
|
||||||
token = signed["token"]
|
token = signed["token"]
|
||||||
|
|
||||||
invite_event = auth_events.get(
|
invite_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.ThirdPartyInvite, token,)
|
(EventTypes.ThirdPartyInvite, token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
invite_event = None
|
||||||
|
if invite_event_id:
|
||||||
|
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
|
||||||
|
|
||||||
if not invite_event:
|
if not invite_event:
|
||||||
raise AuthError(403, "Could not find invite")
|
raise AuthError(403, "Could not find invite")
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
|
||||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
|
||||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
prev_state = self.deduplicate_state_event(event, context)
|
prev_state = yield self.deduplicate_state_event(event, context)
|
||||||
if prev_state is not None:
|
if prev_state is not None:
|
||||||
defer.returnValue(prev_state)
|
defer.returnValue(prev_state)
|
||||||
|
|
||||||
|
@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
|
||||||
presence = self.hs.get_presence_handler()
|
presence = self.hs.get_presence_handler()
|
||||||
yield presence.bump_presence_active_time(user)
|
yield presence.bump_presence_active_time(user)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def deduplicate_state_event(self, event, context):
|
def deduplicate_state_event(self, event, context):
|
||||||
"""
|
"""
|
||||||
Checks whether event is in the latest resolved state in context.
|
Checks whether event is in the latest resolved state in context.
|
||||||
|
@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
|
||||||
If so, returns the version of the event in context.
|
If so, returns the version of the event in context.
|
||||||
Otherwise, returns None.
|
Otherwise, returns None.
|
||||||
"""
|
"""
|
||||||
prev_event = context.current_state.get((event.type, event.state_key))
|
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
|
||||||
|
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
|
if not prev_event:
|
||||||
|
return
|
||||||
|
|
||||||
if prev_event and event.user_id == prev_event.user_id:
|
if prev_event and event.user_id == prev_event.user_id:
|
||||||
prev_content = encode_canonical_json(prev_event.content)
|
prev_content = encode_canonical_json(prev_event.content)
|
||||||
next_content = encode_canonical_json(event.content)
|
next_content = encode_canonical_json(event.content)
|
||||||
if prev_content == next_content:
|
if prev_content == next_content:
|
||||||
return prev_event
|
defer.returnValue(prev_event)
|
||||||
return None
|
return
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_and_send_nonmember_event(
|
def create_and_send_nonmember_event(
|
||||||
|
@ -802,8 +808,8 @@ class MessageHandler(BaseHandler):
|
||||||
event = builder.build()
|
event = builder.build()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created event %s with current state: %s",
|
"Created event %s with state: %s",
|
||||||
event.event_id, context.current_state,
|
event.event_id, context.prev_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
|
@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
|
||||||
self.ratelimit(requester)
|
self.ratelimit(requester)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
yield self.auth.check_from_context(event, context)
|
||||||
except AuthError as err:
|
except AuthError as err:
|
||||||
logger.warn("Denying new event %r because %s", event, err)
|
logger.warn("Denying new event %r because %s", event, err)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
yield self.maybe_kick_guest_users(event, context)
|
||||||
|
|
||||||
if event.type == EventTypes.CanonicalAlias:
|
if event.type == EventTypes.CanonicalAlias:
|
||||||
# Check the alias is acually valid (at this time at least)
|
# Check the alias is acually valid (at this time at least)
|
||||||
|
@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
|
||||||
e.sender == event.sender
|
e.sender == event.sender
|
||||||
)
|
)
|
||||||
|
|
||||||
|
state_to_include_ids = [
|
||||||
|
e_id
|
||||||
|
for k, e_id in context.current_state_ids.items()
|
||||||
|
if k[0] in self.hs.config.room_invite_state_types
|
||||||
|
or k[0] == EventTypes.Member and k[1] == event.sender
|
||||||
|
]
|
||||||
|
|
||||||
|
state_to_include = yield self.store.get_events(state_to_include_ids)
|
||||||
|
|
||||||
event.unsigned["invite_room_state"] = [
|
event.unsigned["invite_room_state"] = [
|
||||||
{
|
{
|
||||||
"type": e.type,
|
"type": e.type,
|
||||||
|
@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
|
||||||
"content": e.content,
|
"content": e.content,
|
||||||
"sender": e.sender,
|
"sender": e.sender,
|
||||||
}
|
}
|
||||||
for k, e in context.current_state.items()
|
for e in state_to_include.values()
|
||||||
if e.type in self.hs.config.room_invite_state_types
|
|
||||||
or is_inviter_member_event(e)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
invitee = UserID.from_string(event.state_key)
|
invitee = UserID.from_string(event.state_key)
|
||||||
|
@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
if self.auth.check_redaction(event, auth_events=context.current_state):
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
|
event, context.prev_state_ids, for_verification=True,
|
||||||
|
)
|
||||||
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
|
auth_events = {
|
||||||
|
(e.type, e.state_key): e for e in auth_events.values()
|
||||||
|
}
|
||||||
|
if self.auth.check_redaction(event, auth_events=auth_events):
|
||||||
original_event = yield self.store.get_event(
|
original_event = yield self.store.get_event(
|
||||||
event.redacts,
|
event.redacts,
|
||||||
check_redacted=False,
|
check_redacted=False,
|
||||||
|
@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
|
||||||
"You don't have permission to redact events"
|
"You don't have permission to redact events"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Create and context.current_state:
|
if event.type == EventTypes.Create and context.prev_state_ids:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"Changing the room create event is forbidden",
|
"Changing the room create event is forbidden",
|
||||||
|
@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
|
||||||
event_stream_id, max_stream_id
|
event_stream_id, max_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
destinations = set()
|
destinations = yield self.get_joined_hosts_for_room_from_state(context)
|
||||||
for k, s in context.current_state.items():
|
|
||||||
try:
|
|
||||||
if k[0] == EventTypes.Member:
|
|
||||||
if s.content["membership"] == Membership.JOIN:
|
|
||||||
destinations.add(get_domain_from_id(s.state_key))
|
|
||||||
except SynapseError:
|
|
||||||
logger.warn(
|
|
||||||
"Failed to get destination from event %s", s.event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _notify():
|
def _notify():
|
||||||
|
@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
|
||||||
preserve_fn(federation_handler.handle_new_event)(
|
preserve_fn(federation_handler.handle_new_event)(
|
||||||
event, destinations=destinations,
|
event, destinations=destinations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_joined_hosts_for_room_from_state(self, context):
|
||||||
|
state_group = context.state_group
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._get_joined_hosts_for_room_from_state(
|
||||||
|
state_group, context.current_state_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||||
|
def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
|
||||||
|
cache_context):
|
||||||
|
|
||||||
|
# Don't bother getting state for people on the same HS
|
||||||
|
current_state = yield self.store.get_events([
|
||||||
|
e_id for key, e_id in current_state_ids.items()
|
||||||
|
if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
|
||||||
|
])
|
||||||
|
|
||||||
|
destinations = set()
|
||||||
|
for e in current_state.itervalues():
|
||||||
|
try:
|
||||||
|
if e.type == EventTypes.Member:
|
||||||
|
if e.content["membership"] == Membership.JOIN:
|
||||||
|
destinations.add(get_domain_from_id(e.state_key))
|
||||||
|
except SynapseError:
|
||||||
|
logger.warn(
|
||||||
|
"Failed to get destination from event %s", e.event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(destinations)
|
||||||
|
|
|
@ -88,6 +88,8 @@ class PresenceHandler(object):
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.federation = hs.get_replication_layer()
|
self.federation = hs.get_replication_layer()
|
||||||
|
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
self.federation.register_edu_handler(
|
self.federation.register_edu_handler(
|
||||||
"m.presence", self.incoming_presence
|
"m.presence", self.incoming_presence
|
||||||
)
|
)
|
||||||
|
@ -189,6 +191,13 @@ class PresenceHandler(object):
|
||||||
5000,
|
5000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.clock.call_later(
|
||||||
|
60,
|
||||||
|
self.clock.looping_call,
|
||||||
|
self._persist_unpersisted_changes,
|
||||||
|
60 * 1000,
|
||||||
|
)
|
||||||
|
|
||||||
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
|
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -214,6 +223,27 @@ class PresenceHandler(object):
|
||||||
])
|
])
|
||||||
logger.info("Finished _on_shutdown")
|
logger.info("Finished _on_shutdown")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _persist_unpersisted_changes(self):
|
||||||
|
"""We periodically persist the unpersisted changes, as otherwise they
|
||||||
|
may stack up and slow down shutdown times.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
|
||||||
|
len(self.unpersisted_users_changes)
|
||||||
|
)
|
||||||
|
|
||||||
|
unpersisted = self.unpersisted_users_changes
|
||||||
|
self.unpersisted_users_changes = set()
|
||||||
|
|
||||||
|
if unpersisted:
|
||||||
|
yield self.store.update_presence([
|
||||||
|
self.user_to_current_state[user_id]
|
||||||
|
for user_id in unpersisted
|
||||||
|
])
|
||||||
|
|
||||||
|
logger.info("Finished _persist_unpersisted_changes")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_states(self, new_states):
|
def _update_states(self, new_states):
|
||||||
"""Updates presence of users. Sets the appropriate timeouts. Pokes
|
"""Updates presence of users. Sets the appropriate timeouts. Pokes
|
||||||
|
@ -532,7 +562,9 @@ class PresenceHandler(object):
|
||||||
if not local_states:
|
if not local_states:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
hosts = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||||
|
|
||||||
|
@ -725,13 +757,13 @@ class PresenceHandler(object):
|
||||||
# don't need to send to local clients here, as that is done as part
|
# don't need to send to local clients here, as that is done as part
|
||||||
# of the event stream/sync.
|
# of the event stream/sync.
|
||||||
# TODO: Only send to servers not already in the room.
|
# TODO: Only send to servers not already in the room.
|
||||||
|
user_ids = yield self.state.get_current_user_in_room(room_id)
|
||||||
if self.is_mine(user):
|
if self.is_mine(user):
|
||||||
state = yield self.current_state_for_user(user.to_string())
|
state = yield self.current_state_for_user(user.to_string())
|
||||||
|
|
||||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
hosts = set(get_domain_from_id(u) for u in user_ids)
|
||||||
self._push_to_remotes({host: (state,) for host in hosts})
|
self._push_to_remotes({host: (state,) for host in hosts})
|
||||||
else:
|
else:
|
||||||
user_ids = yield self.store.get_users_in_room(room_id)
|
|
||||||
user_ids = filter(self.is_mine_id, user_ids)
|
user_ids = filter(self.is_mine_id, user_ids)
|
||||||
|
|
||||||
states = yield self.current_state_for_users(user_ids)
|
states = yield self.current_state_for_users(user_ids)
|
||||||
|
@ -919,6 +951,11 @@ def should_notify(old_state, new_state):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
|
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
|
||||||
|
# Only notify about last active bumps if we're not currently acive
|
||||||
|
if not (old_state.currently_active and new_state.currently_active):
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
|
||||||
# Always notify for a transition where last active gets bumped.
|
# Always notify for a transition where last active gets bumped.
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -955,6 +992,7 @@ class PresenceEventSource(object):
|
||||||
self.get_presence_handler = hs.get_presence_handler
|
self.get_presence_handler = hs.get_presence_handler
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -1017,7 +1055,7 @@ class PresenceEventSource(object):
|
||||||
|
|
||||||
user_ids_to_check = set()
|
user_ids_to_check = set()
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
user_ids_to_check.update(users)
|
user_ids_to_check.update(users)
|
||||||
|
|
||||||
user_ids_to_check.update(friends)
|
user_ids_to_check.update(friends)
|
||||||
|
|
|
@ -18,6 +18,7 @@ from ._base import BaseHandler
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler):
|
||||||
"m.receipt", self._received_remote_receipt
|
"m.receipt", self._received_remote_receipt
|
||||||
)
|
)
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def received_client_receipt(self, room_id, receipt_type, user_id,
|
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||||
|
@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler):
|
||||||
event_ids = receipt["event_ids"]
|
event_ids = receipt["event_ids"]
|
||||||
data = receipt["data"]
|
data = receipt["data"]
|
||||||
|
|
||||||
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
remotedomains = set(get_domain_from_id(u) for u in users)
|
||||||
remotedomains = remotedomains.copy()
|
remotedomains = remotedomains.copy()
|
||||||
remotedomains.discard(self.server_name)
|
remotedomains.discard(self.server_name)
|
||||||
|
|
||||||
|
|
|
@ -85,6 +85,12 @@ class RoomMemberHandler(BaseHandler):
|
||||||
prev_event_ids=prev_event_ids,
|
prev_event_ids=prev_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if this event matches the previous membership event for the user.
|
||||||
|
duplicate = yield msg_handler.deduplicate_state_event(event, context)
|
||||||
|
if duplicate is not None:
|
||||||
|
# Discard the new event since this membership change is a no-op.
|
||||||
|
return
|
||||||
|
|
||||||
yield msg_handler.handle_new_client_event(
|
yield msg_handler.handle_new_client_event(
|
||||||
requester,
|
requester,
|
||||||
event,
|
event,
|
||||||
|
@ -93,19 +99,25 @@ class RoomMemberHandler(BaseHandler):
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event = context.current_state.get(
|
prev_member_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.Member, target.to_string()),
|
(EventTypes.Member, target.to_string()),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
|
||||||
# Only fire user_joined_room if the user has acutally joined the
|
# Only fire user_joined_room if the user has acutally joined the
|
||||||
# room. Don't bother if the user is just changing their profile
|
# room. Don't bother if the user is just changing their profile
|
||||||
# info.
|
# info.
|
||||||
|
newly_joined = True
|
||||||
|
if prev_member_event_id:
|
||||||
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
|
if newly_joined:
|
||||||
yield user_joined_room(self.distributor, target, room_id)
|
yield user_joined_room(self.distributor, target, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
if prev_member_event_id:
|
||||||
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
user_left_room(self.distributor, target, room_id)
|
user_left_room(self.distributor, target, room_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -195,16 +207,19 @@ class RoomMemberHandler(BaseHandler):
|
||||||
remote_room_hosts = []
|
remote_room_hosts = []
|
||||||
|
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
current_state = yield self.state_handler.get_current_state(
|
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||||
room_id, latest_event_ids=latest_event_ids,
|
room_id, latest_event_ids=latest_event_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
old_state = current_state.get((EventTypes.Member, target.to_string()))
|
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
||||||
|
if old_state_id:
|
||||||
|
old_state = yield self.store.get_event(old_state_id, allow_none=True)
|
||||||
old_membership = old_state.content.get("membership") if old_state else None
|
old_membership = old_state.content.get("membership") if old_state else None
|
||||||
if action == "unban" and old_membership != "ban":
|
if action == "unban" and old_membership != "ban":
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
"Cannot unban user who was not banned (membership=%s)" % old_membership,
|
"Cannot unban user who was not banned"
|
||||||
|
" (membership=%s)" % old_membership,
|
||||||
errcode=Codes.BAD_STATE
|
errcode=Codes.BAD_STATE
|
||||||
)
|
)
|
||||||
if old_membership == "ban" and action != "unban":
|
if old_membership == "ban" and action != "unban":
|
||||||
|
@ -214,10 +229,10 @@ class RoomMemberHandler(BaseHandler):
|
||||||
errcode=Codes.BAD_STATE
|
errcode=Codes.BAD_STATE
|
||||||
)
|
)
|
||||||
|
|
||||||
is_host_in_room = self.is_host_in_room(current_state)
|
is_host_in_room = yield self._is_host_in_room(current_state_ids)
|
||||||
|
|
||||||
if effective_membership_state == Membership.JOIN:
|
if effective_membership_state == Membership.JOIN:
|
||||||
if requester.is_guest and not self._can_guest_join(current_state):
|
if requester.is_guest and not self._can_guest_join(current_state_ids):
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
@ -326,12 +341,14 @@ class RoomMemberHandler(BaseHandler):
|
||||||
requester = synapse.types.create_requester(target_user)
|
requester = synapse.types.create_requester(target_user)
|
||||||
|
|
||||||
message_handler = self.hs.get_handlers().message_handler
|
message_handler = self.hs.get_handlers().message_handler
|
||||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
prev_event = yield message_handler.deduplicate_state_event(event, context)
|
||||||
if prev_event is not None:
|
if prev_event is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest and not self._can_guest_join(context.current_state):
|
if requester.is_guest:
|
||||||
|
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
|
||||||
|
if not guest_can_join:
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
@ -344,27 +361,39 @@ class RoomMemberHandler(BaseHandler):
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event = context.current_state.get(
|
prev_member_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.Member, target_user.to_string()),
|
(EventTypes.Member, event.state_key),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
|
||||||
# Only fire user_joined_room if the user has acutally joined the
|
# Only fire user_joined_room if the user has acutally joined the
|
||||||
# room. Don't bother if the user is just changing their profile
|
# room. Don't bother if the user is just changing their profile
|
||||||
# info.
|
# info.
|
||||||
|
newly_joined = True
|
||||||
|
if prev_member_event_id:
|
||||||
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
|
if newly_joined:
|
||||||
yield user_joined_room(self.distributor, target_user, room_id)
|
yield user_joined_room(self.distributor, target_user, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
if prev_member_event_id:
|
||||||
|
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||||
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
user_left_room(self.distributor, target_user, room_id)
|
user_left_room(self.distributor, target_user, room_id)
|
||||||
|
|
||||||
def _can_guest_join(self, current_state):
|
@defer.inlineCallbacks
|
||||||
|
def _can_guest_join(self, current_state_ids):
|
||||||
"""
|
"""
|
||||||
Returns whether a guest can join a room based on its current state.
|
Returns whether a guest can join a room based on its current state.
|
||||||
"""
|
"""
|
||||||
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
|
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
|
||||||
return (
|
if not guest_access_id:
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
guest_access = yield self.store.get_event(guest_access_id)
|
||||||
|
|
||||||
|
defer.returnValue(
|
||||||
guest_access
|
guest_access
|
||||||
and guest_access.content
|
and guest_access.content
|
||||||
and "guest_access" in guest_access.content
|
and "guest_access" in guest_access.content
|
||||||
|
@ -683,3 +712,24 @@ class RoomMemberHandler(BaseHandler):
|
||||||
|
|
||||||
if membership:
|
if membership:
|
||||||
yield self.store.forget(user_id, room_id)
|
yield self.store.forget(user_id, room_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _is_host_in_room(self, current_state_ids):
|
||||||
|
# Have we just created the room, and is this about to be the very
|
||||||
|
# first member event?
|
||||||
|
create_event_id = current_state_ids.get(("m.room.create", ""))
|
||||||
|
if len(current_state_ids) == 1 and create_event_id:
|
||||||
|
defer.returnValue(self.hs.is_mine_id(create_event_id))
|
||||||
|
|
||||||
|
for (etype, state_key), event_id in current_state_ids.items():
|
||||||
|
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
|
||||||
|
continue
|
||||||
|
|
||||||
|
event = yield self.store.get_event(event_id, allow_none=True)
|
||||||
|
if not event:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event.membership == Membership.JOIN:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
|
@ -35,6 +35,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [
|
||||||
"filter_collection",
|
"filter_collection",
|
||||||
"is_guest",
|
"is_guest",
|
||||||
"request_key",
|
"request_key",
|
||||||
|
"device_id",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
||||||
"joined", # JoinedSyncResult for each joined room.
|
"joined", # JoinedSyncResult for each joined room.
|
||||||
"invited", # InvitedSyncResult for each invited room.
|
"invited", # InvitedSyncResult for each invited room.
|
||||||
"archived", # ArchivedSyncResult for each archived room.
|
"archived", # ArchivedSyncResult for each archived room.
|
||||||
|
"to_device", # List of direct messages for the device.
|
||||||
])):
|
])):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
|
||||||
|
@ -126,7 +128,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
||||||
self.joined or
|
self.joined or
|
||||||
self.invited or
|
self.invited or
|
||||||
self.archived or
|
self.archived or
|
||||||
self.account_data
|
self.account_data or
|
||||||
|
self.to_device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -139,6 +142,7 @@ class SyncHandler(object):
|
||||||
self.event_sources = hs.get_event_sources()
|
self.event_sources = hs.get_event_sources()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.response_cache = ResponseCache(hs)
|
self.response_cache = ResponseCache(hs)
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
||||||
full_state=False):
|
full_state=False):
|
||||||
|
@ -355,11 +359,11 @@ class SyncHandler(object):
|
||||||
Returns:
|
Returns:
|
||||||
A Deferred map from ((type, state_key)->Event)
|
A Deferred map from ((type, state_key)->Event)
|
||||||
"""
|
"""
|
||||||
state = yield self.store.get_state_for_event(event.event_id)
|
state_ids = yield self.store.get_state_ids_for_event(event.event_id)
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
state = state.copy()
|
state_ids = state_ids.copy()
|
||||||
state[(event.type, event.state_key)] = event
|
state_ids[(event.type, event.state_key)] = event.event_id
|
||||||
defer.returnValue(state)
|
defer.returnValue(state_ids)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_at(self, room_id, stream_position):
|
def get_state_at(self, room_id, stream_position):
|
||||||
|
@ -412,57 +416,61 @@ class SyncHandler(object):
|
||||||
with Measure(self.clock, "compute_state_delta"):
|
with Measure(self.clock, "compute_state_delta"):
|
||||||
if full_state:
|
if full_state:
|
||||||
if batch:
|
if batch:
|
||||||
current_state = yield self.store.get_state_for_event(
|
current_state_ids = yield self.store.get_state_ids_for_event(
|
||||||
batch.events[-1].event_id
|
batch.events[-1].event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
state = yield self.store.get_state_for_event(
|
state_ids = yield self.store.get_state_ids_for_event(
|
||||||
batch.events[0].event_id
|
batch.events[0].event_id
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
current_state = yield self.get_state_at(
|
current_state_ids = yield self.get_state_at(
|
||||||
room_id, stream_position=now_token
|
room_id, stream_position=now_token
|
||||||
)
|
)
|
||||||
|
|
||||||
state = current_state
|
state_ids = current_state_ids
|
||||||
|
|
||||||
timeline_state = {
|
timeline_state = {
|
||||||
(event.type, event.state_key): event
|
(event.type, event.state_key): event.event_id
|
||||||
for event in batch.events if event.is_state()
|
for event in batch.events if event.is_state()
|
||||||
}
|
}
|
||||||
|
|
||||||
state = _calculate_state(
|
state_ids = _calculate_state(
|
||||||
timeline_contains=timeline_state,
|
timeline_contains=timeline_state,
|
||||||
timeline_start=state,
|
timeline_start=state_ids,
|
||||||
previous={},
|
previous={},
|
||||||
current=current_state,
|
current=current_state_ids,
|
||||||
)
|
)
|
||||||
elif batch.limited:
|
elif batch.limited:
|
||||||
state_at_previous_sync = yield self.get_state_at(
|
state_at_previous_sync = yield self.get_state_at(
|
||||||
room_id, stream_position=since_token
|
room_id, stream_position=since_token
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state = yield self.store.get_state_for_event(
|
current_state_ids = yield self.store.get_state_ids_for_event(
|
||||||
batch.events[-1].event_id
|
batch.events[-1].event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
state_at_timeline_start = yield self.store.get_state_for_event(
|
state_at_timeline_start = yield self.store.get_state_ids_for_event(
|
||||||
batch.events[0].event_id
|
batch.events[0].event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
timeline_state = {
|
timeline_state = {
|
||||||
(event.type, event.state_key): event
|
(event.type, event.state_key): event.event_id
|
||||||
for event in batch.events if event.is_state()
|
for event in batch.events if event.is_state()
|
||||||
}
|
}
|
||||||
|
|
||||||
state = _calculate_state(
|
state_ids = _calculate_state(
|
||||||
timeline_contains=timeline_state,
|
timeline_contains=timeline_state,
|
||||||
timeline_start=state_at_timeline_start,
|
timeline_start=state_at_timeline_start,
|
||||||
previous=state_at_previous_sync,
|
previous=state_at_previous_sync,
|
||||||
current=current_state,
|
current=current_state_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
state_ids = {}
|
||||||
|
|
||||||
state = {}
|
state = {}
|
||||||
|
if state_ids:
|
||||||
|
state = yield self.store.get_events(state_ids.values())
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
(e.type, e.state_key): e
|
(e.type, e.state_key): e
|
||||||
|
@ -527,15 +535,57 @@ class SyncHandler(object):
|
||||||
sync_result_builder, newly_joined_rooms, newly_joined_users
|
sync_result_builder, newly_joined_rooms, newly_joined_users
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield self._generate_sync_entry_for_to_device(sync_result_builder)
|
||||||
|
|
||||||
defer.returnValue(SyncResult(
|
defer.returnValue(SyncResult(
|
||||||
presence=sync_result_builder.presence,
|
presence=sync_result_builder.presence,
|
||||||
account_data=sync_result_builder.account_data,
|
account_data=sync_result_builder.account_data,
|
||||||
joined=sync_result_builder.joined,
|
joined=sync_result_builder.joined,
|
||||||
invited=sync_result_builder.invited,
|
invited=sync_result_builder.invited,
|
||||||
archived=sync_result_builder.archived,
|
archived=sync_result_builder.archived,
|
||||||
|
to_device=sync_result_builder.to_device,
|
||||||
next_batch=sync_result_builder.now_token,
|
next_batch=sync_result_builder.now_token,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _generate_sync_entry_for_to_device(self, sync_result_builder):
|
||||||
|
"""Generates the portion of the sync response. Populates
|
||||||
|
`sync_result_builder` with the result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sync_result_builder(SyncResultBuilder)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred(dict): A dictionary containing the per room account data.
|
||||||
|
"""
|
||||||
|
user_id = sync_result_builder.sync_config.user.to_string()
|
||||||
|
device_id = sync_result_builder.sync_config.device_id
|
||||||
|
now_token = sync_result_builder.now_token
|
||||||
|
since_stream_id = 0
|
||||||
|
if sync_result_builder.since_token is not None:
|
||||||
|
since_stream_id = int(sync_result_builder.since_token.to_device_key)
|
||||||
|
|
||||||
|
if since_stream_id != int(now_token.to_device_key):
|
||||||
|
# We only delete messages when a new message comes in, but that's
|
||||||
|
# fine so long as we delete them at some point.
|
||||||
|
|
||||||
|
logger.debug("Deleting messages up to %d", since_stream_id)
|
||||||
|
yield self.store.delete_messages_for_device(
|
||||||
|
user_id, device_id, since_stream_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Getting messages up to %d", now_token.to_device_key)
|
||||||
|
messages, stream_id = yield self.store.get_new_messages_for_device(
|
||||||
|
user_id, device_id, since_stream_id, now_token.to_device_key
|
||||||
|
)
|
||||||
|
logger.debug("Got messages up to %d: %r", stream_id, messages)
|
||||||
|
sync_result_builder.now_token = now_token.copy_and_replace(
|
||||||
|
"to_device_key", stream_id
|
||||||
|
)
|
||||||
|
sync_result_builder.to_device = messages
|
||||||
|
else:
|
||||||
|
sync_result_builder.to_device = []
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_sync_entry_for_account_data(self, sync_result_builder):
|
def _generate_sync_entry_for_account_data(self, sync_result_builder):
|
||||||
"""Generates the account data portion of the sync response. Populates
|
"""Generates the account data portion of the sync response. Populates
|
||||||
|
@ -626,7 +676,7 @@ class SyncHandler(object):
|
||||||
|
|
||||||
extra_users_ids = set(newly_joined_users)
|
extra_users_ids = set(newly_joined_users)
|
||||||
for room_id in newly_joined_rooms:
|
for room_id in newly_joined_rooms:
|
||||||
users = yield self.store.get_users_in_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
extra_users_ids.update(users)
|
extra_users_ids.update(users)
|
||||||
extra_users_ids.discard(user.to_string())
|
extra_users_ids.discard(user.to_string())
|
||||||
|
|
||||||
|
@ -766,8 +816,13 @@ class SyncHandler(object):
|
||||||
# the last sync (even if we have since left). This is to make sure
|
# the last sync (even if we have since left). This is to make sure
|
||||||
# we do send down the room, and with full state, where necessary
|
# we do send down the room, and with full state, where necessary
|
||||||
if room_id in joined_room_ids or has_join:
|
if room_id in joined_room_ids or has_join:
|
||||||
old_state = yield self.get_state_at(room_id, since_token)
|
old_state_ids = yield self.get_state_at(room_id, since_token)
|
||||||
old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
|
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
|
||||||
|
old_mem_ev = None
|
||||||
|
if old_mem_ev_id:
|
||||||
|
old_mem_ev = yield self.store.get_event(
|
||||||
|
old_mem_ev_id, allow_none=True
|
||||||
|
)
|
||||||
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
|
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
|
||||||
newly_joined_rooms.append(room_id)
|
newly_joined_rooms.append(room_id)
|
||||||
|
|
||||||
|
@ -1059,27 +1114,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
|
||||||
Returns:
|
Returns:
|
||||||
dict
|
dict
|
||||||
"""
|
"""
|
||||||
event_id_to_state = {
|
event_id_to_key = {
|
||||||
e.event_id: e
|
e: key
|
||||||
for e in itertools.chain(
|
for key, e in itertools.chain(
|
||||||
timeline_contains.values(),
|
timeline_contains.items(),
|
||||||
previous.values(),
|
previous.items(),
|
||||||
timeline_start.values(),
|
timeline_start.items(),
|
||||||
current.values(),
|
current.items(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
c_ids = set(e.event_id for e in current.values())
|
c_ids = set(e for e in current.values())
|
||||||
tc_ids = set(e.event_id for e in timeline_contains.values())
|
tc_ids = set(e for e in timeline_contains.values())
|
||||||
p_ids = set(e.event_id for e in previous.values())
|
p_ids = set(e for e in previous.values())
|
||||||
ts_ids = set(e.event_id for e in timeline_start.values())
|
ts_ids = set(e for e in timeline_start.values())
|
||||||
|
|
||||||
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
|
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
|
||||||
|
|
||||||
evs = (event_id_to_state[e] for e in state_ids)
|
|
||||||
return {
|
return {
|
||||||
(e.type, e.state_key): e
|
event_id_to_key[e]: e for e in state_ids
|
||||||
for e in evs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1103,6 +1156,7 @@ class SyncResultBuilder(object):
|
||||||
self.joined = []
|
self.joined = []
|
||||||
self.invited = []
|
self.invited = []
|
||||||
self.archived = []
|
self.archived = []
|
||||||
|
self.device = []
|
||||||
|
|
||||||
|
|
||||||
class RoomSyncResultBuilder(object):
|
class RoomSyncResultBuilder(object):
|
||||||
|
|
|
@ -20,7 +20,7 @@ from synapse.util.logcontext import (
|
||||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||||
)
|
)
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -42,6 +42,7 @@ class TypingHandler(object):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
|
@ -166,7 +167,8 @@ class TypingHandler(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _push_update(self, room_id, user_id, typing):
|
def _push_update(self, room_id, user_id, typing):
|
||||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
domains = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
deferreds = []
|
deferreds = []
|
||||||
for domain in domains:
|
for domain in domains:
|
||||||
|
@ -199,7 +201,8 @@ class TypingHandler(object):
|
||||||
# Check that the string is a valid user id
|
# Check that the string is a valid user id
|
||||||
UserID.from_string(user_id)
|
UserID.from_string(user_id)
|
||||||
|
|
||||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
|
domains = set(get_domain_from_id(u) for u in users)
|
||||||
|
|
||||||
if self.server_name in domains:
|
if self.server_name in domains:
|
||||||
self._push_update_local(
|
self._push_update_local(
|
||||||
|
|
|
@ -423,7 +423,8 @@ class Notifier(object):
|
||||||
def _is_world_readable(self, room_id):
|
def _is_world_readable(self, room_id):
|
||||||
state = yield self.state_handler.get_current_state(
|
state = yield self.state_handler.get_current_state(
|
||||||
room_id,
|
room_id,
|
||||||
EventTypes.RoomHistoryVisibility
|
EventTypes.RoomHistoryVisibility,
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
if state and "history_visibility" in state.content:
|
if state and "history_visibility" in state.content:
|
||||||
defer.returnValue(state.content["history_visibility"] == "world_readable")
|
defer.returnValue(state.content["history_visibility"] == "world_readable")
|
||||||
|
|
|
@ -40,12 +40,12 @@ class ActionGenerator:
|
||||||
def handle_push_actions_for_event(self, event, context):
|
def handle_push_actions_for_event(self, event, context):
|
||||||
with Measure(self.clock, "evaluator_for_event"):
|
with Measure(self.clock, "evaluator_for_event"):
|
||||||
bulk_evaluator = yield evaluator_for_event(
|
bulk_evaluator = yield evaluator_for_event(
|
||||||
event, self.hs, self.store, context.state_group, context.current_state
|
event, self.hs, self.store, context
|
||||||
)
|
)
|
||||||
|
|
||||||
with Measure(self.clock, "action_for_event_by_user"):
|
with Measure(self.clock, "action_for_event_by_user"):
|
||||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||||
event, context.current_state
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
context.push_actions = [
|
context.push_actions = [
|
||||||
|
|
|
@ -19,8 +19,8 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.visibility import filter_events_for_clients
|
from synapse.visibility import filter_events_for_clients_context
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def evaluator_for_event(event, hs, store, state_group, current_state):
|
def evaluator_for_event(event, hs, store, context):
|
||||||
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
||||||
event.room_id, state_group, current_state
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
# if this event is an invite event, we may need to run rules for the user
|
# if this event is an invite event, we may need to run rules for the user
|
||||||
|
@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def action_for_event_by_user(self, event, current_state):
|
def action_for_event_by_user(self, event, context):
|
||||||
actions_by_user = {}
|
actions_by_user = {}
|
||||||
|
|
||||||
# None of these users can be peeking since this list of users comes
|
# None of these users can be peeking since this list of users comes
|
||||||
|
@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
|
||||||
(u, False) for u in self.rules_by_user.keys()
|
(u, False) for u in self.rules_by_user.keys()
|
||||||
]
|
]
|
||||||
|
|
||||||
filtered_by_user = yield filter_events_for_clients(
|
filtered_by_user = yield filter_events_for_clients_context(
|
||||||
self.store, user_tuples, [event], {event.event_id: current_state}
|
self.store, user_tuples, [event], {event.event_id: context}
|
||||||
)
|
)
|
||||||
|
|
||||||
room_members = set(
|
room_members = yield self.store.get_joined_users_from_context(
|
||||||
e.state_key for e in current_state.values()
|
event, context
|
||||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||||
|
|
||||||
condition_cache = {}
|
condition_cache = {}
|
||||||
|
|
||||||
display_names = {}
|
|
||||||
for ev in current_state.values():
|
|
||||||
nm = ev.content.get("displayname", None)
|
|
||||||
if nm and ev.type == EventTypes.Member:
|
|
||||||
display_names[ev.state_key] = nm
|
|
||||||
|
|
||||||
for uid, rules in self.rules_by_user.items():
|
for uid, rules in self.rules_by_user.items():
|
||||||
display_name = display_names.get(uid, None)
|
display_name = None
|
||||||
|
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
|
||||||
|
if member_ev_id:
|
||||||
|
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
|
||||||
|
if member_ev:
|
||||||
|
display_name = member_ev.content.get("displayname", None)
|
||||||
|
|
||||||
filtered = filtered_by_user[uid]
|
filtered = filtered_by_user[uid]
|
||||||
if len(filtered) == 0:
|
if len(filtered) == 0:
|
||||||
|
|
|
@ -245,7 +245,7 @@ class HttpPusher(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _build_notification_dict(self, event, tweaks, badge):
|
def _build_notification_dict(self, event, tweaks, badge):
|
||||||
ctx = yield push_tools.get_context_for_event(
|
ctx = yield push_tools.get_context_for_event(
|
||||||
self.state_handler, event, self.user_id
|
self.store, self.state_handler, event, self.user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
d = {
|
d = {
|
||||||
|
|
|
@ -22,7 +22,7 @@ from email.mime.text import MIMEText
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
|
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
from synapse.util.presentable_names import (
|
from synapse.push.presentable_names import (
|
||||||
calculate_room_name, name_from_member_event, descriptor_from_member_events
|
calculate_room_name, name_from_member_event, descriptor_from_member_events
|
||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
@ -139,7 +139,7 @@ class Mailer(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _fetch_room_state(room_id):
|
def _fetch_room_state(room_id):
|
||||||
room_state = yield self.state_handler.get_current_state(room_id)
|
room_state = yield self.state_handler.get_current_state_ids(room_id)
|
||||||
state_by_room[room_id] = room_state
|
state_by_room[room_id] = room_state
|
||||||
|
|
||||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||||
|
@ -159,11 +159,12 @@ class Mailer(object):
|
||||||
)
|
)
|
||||||
rooms.append(roomvars)
|
rooms.append(roomvars)
|
||||||
|
|
||||||
reason['room_name'] = calculate_room_name(
|
reason['room_name'] = yield calculate_room_name(
|
||||||
state_by_room[reason['room_id']], user_id, fallback_to_members=True
|
self.store, state_by_room[reason['room_id']], user_id,
|
||||||
|
fallback_to_members=True
|
||||||
)
|
)
|
||||||
|
|
||||||
summary_text = self.make_summary_text(
|
summary_text = yield self.make_summary_text(
|
||||||
notifs_by_room, state_by_room, notif_events, user_id, reason
|
notifs_by_room, state_by_room, notif_events, user_id, reason
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -203,12 +204,15 @@ class Mailer(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
|
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
|
||||||
my_member_event = room_state[("m.room.member", user_id)]
|
my_member_event_id = room_state_ids[("m.room.member", user_id)]
|
||||||
|
my_member_event = yield self.store.get_event(my_member_event_id)
|
||||||
is_invite = my_member_event.content["membership"] == "invite"
|
is_invite = my_member_event.content["membership"] == "invite"
|
||||||
|
|
||||||
|
room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
|
||||||
|
|
||||||
room_vars = {
|
room_vars = {
|
||||||
"title": calculate_room_name(room_state, user_id),
|
"title": room_name,
|
||||||
"hash": string_ordinal_total(room_id), # See sender avatar hash
|
"hash": string_ordinal_total(room_id), # See sender avatar hash
|
||||||
"notifs": [],
|
"notifs": [],
|
||||||
"invite": is_invite,
|
"invite": is_invite,
|
||||||
|
@ -218,7 +222,7 @@ class Mailer(object):
|
||||||
if not is_invite:
|
if not is_invite:
|
||||||
for n in notifs:
|
for n in notifs:
|
||||||
notifvars = yield self.get_notif_vars(
|
notifvars = yield self.get_notif_vars(
|
||||||
n, user_id, notif_events[n['event_id']], room_state
|
n, user_id, notif_events[n['event_id']], room_state_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
# merge overlapping notifs together.
|
# merge overlapping notifs together.
|
||||||
|
@ -243,7 +247,7 @@ class Mailer(object):
|
||||||
defer.returnValue(room_vars)
|
defer.returnValue(room_vars)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_notif_vars(self, notif, user_id, notif_event, room_state):
|
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
|
||||||
results = yield self.store.get_events_around(
|
results = yield self.store.get_events_around(
|
||||||
notif['room_id'], notif['event_id'],
|
notif['room_id'], notif['event_id'],
|
||||||
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
|
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
|
||||||
|
@ -261,17 +265,19 @@ class Mailer(object):
|
||||||
the_events.append(notif_event)
|
the_events.append(notif_event)
|
||||||
|
|
||||||
for event in the_events:
|
for event in the_events:
|
||||||
messagevars = self.get_message_vars(notif, event, room_state)
|
messagevars = yield self.get_message_vars(notif, event, room_state_ids)
|
||||||
if messagevars is not None:
|
if messagevars is not None:
|
||||||
ret['messages'].append(messagevars)
|
ret['messages'].append(messagevars)
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def get_message_vars(self, notif, event, room_state):
|
@defer.inlineCallbacks
|
||||||
|
def get_message_vars(self, notif, event, room_state_ids):
|
||||||
if event.type != EventTypes.Message:
|
if event.type != EventTypes.Message:
|
||||||
return None
|
return
|
||||||
|
|
||||||
sender_state_event = room_state[("m.room.member", event.sender)]
|
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
|
||||||
|
sender_state_event = yield self.store.get_event(sender_state_event_id)
|
||||||
sender_name = name_from_member_event(sender_state_event)
|
sender_name = name_from_member_event(sender_state_event)
|
||||||
sender_avatar_url = sender_state_event.content.get("avatar_url")
|
sender_avatar_url = sender_state_event.content.get("avatar_url")
|
||||||
|
|
||||||
|
@ -299,7 +305,7 @@ class Mailer(object):
|
||||||
if "body" in event.content:
|
if "body" in event.content:
|
||||||
ret["body_text_plain"] = event.content["body"]
|
ret["body_text_plain"] = event.content["body"]
|
||||||
|
|
||||||
return ret
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def add_text_message_vars(self, messagevars, event):
|
def add_text_message_vars(self, messagevars, event):
|
||||||
msgformat = event.content.get("format")
|
msgformat = event.content.get("format")
|
||||||
|
@ -321,6 +327,7 @@ class Mailer(object):
|
||||||
|
|
||||||
return messagevars
|
return messagevars
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def make_summary_text(self, notifs_by_room, state_by_room,
|
def make_summary_text(self, notifs_by_room, state_by_room,
|
||||||
notif_events, user_id, reason):
|
notif_events, user_id, reason):
|
||||||
if len(notifs_by_room) == 1:
|
if len(notifs_by_room) == 1:
|
||||||
|
@ -330,8 +337,8 @@ class Mailer(object):
|
||||||
# If the room has some kind of name, use it, but we don't
|
# If the room has some kind of name, use it, but we don't
|
||||||
# want the generated-from-names one here otherwise we'll
|
# want the generated-from-names one here otherwise we'll
|
||||||
# end up with, "new message from Bob in the Bob room"
|
# end up with, "new message from Bob in the Bob room"
|
||||||
room_name = calculate_room_name(
|
room_name = yield calculate_room_name(
|
||||||
state_by_room[room_id], user_id, fallback_to_members=False
|
self.store, state_by_room[room_id], user_id, fallback_to_members=False
|
||||||
)
|
)
|
||||||
|
|
||||||
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
|
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
|
||||||
|
@ -342,16 +349,16 @@ class Mailer(object):
|
||||||
inviter_name = name_from_member_event(inviter_member_event)
|
inviter_name = name_from_member_event(inviter_member_event)
|
||||||
|
|
||||||
if room_name is None:
|
if room_name is None:
|
||||||
return INVITE_FROM_PERSON % {
|
defer.returnValue(INVITE_FROM_PERSON % {
|
||||||
"person": inviter_name,
|
"person": inviter_name,
|
||||||
"app": self.app_name
|
"app": self.app_name
|
||||||
}
|
})
|
||||||
else:
|
else:
|
||||||
return INVITE_FROM_PERSON_TO_ROOM % {
|
defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
|
||||||
"person": inviter_name,
|
"person": inviter_name,
|
||||||
"room": room_name,
|
"room": room_name,
|
||||||
"app": self.app_name,
|
"app": self.app_name,
|
||||||
}
|
})
|
||||||
|
|
||||||
sender_name = None
|
sender_name = None
|
||||||
if len(notifs_by_room[room_id]) == 1:
|
if len(notifs_by_room[room_id]) == 1:
|
||||||
|
@ -362,24 +369,24 @@ class Mailer(object):
|
||||||
sender_name = name_from_member_event(state_event)
|
sender_name = name_from_member_event(state_event)
|
||||||
|
|
||||||
if sender_name is not None and room_name is not None:
|
if sender_name is not None and room_name is not None:
|
||||||
return MESSAGE_FROM_PERSON_IN_ROOM % {
|
defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
|
||||||
"person": sender_name,
|
"person": sender_name,
|
||||||
"room": room_name,
|
"room": room_name,
|
||||||
"app": self.app_name,
|
"app": self.app_name,
|
||||||
}
|
})
|
||||||
elif sender_name is not None:
|
elif sender_name is not None:
|
||||||
return MESSAGE_FROM_PERSON % {
|
defer.returnValue(MESSAGE_FROM_PERSON % {
|
||||||
"person": sender_name,
|
"person": sender_name,
|
||||||
"app": self.app_name,
|
"app": self.app_name,
|
||||||
}
|
})
|
||||||
else:
|
else:
|
||||||
# There's more than one notification for this room, so just
|
# There's more than one notification for this room, so just
|
||||||
# say there are several
|
# say there are several
|
||||||
if room_name is not None:
|
if room_name is not None:
|
||||||
return MESSAGES_IN_ROOM % {
|
defer.returnValue(MESSAGES_IN_ROOM % {
|
||||||
"room": room_name,
|
"room": room_name,
|
||||||
"app": self.app_name,
|
"app": self.app_name,
|
||||||
}
|
})
|
||||||
else:
|
else:
|
||||||
# If the room doesn't have a name, say who the messages
|
# If the room doesn't have a name, say who the messages
|
||||||
# are from explicitly to avoid, "messages in the Bob room"
|
# are from explicitly to avoid, "messages in the Bob room"
|
||||||
|
@ -388,22 +395,22 @@ class Mailer(object):
|
||||||
for n in notifs_by_room[room_id]
|
for n in notifs_by_room[room_id]
|
||||||
]))
|
]))
|
||||||
|
|
||||||
return MESSAGES_FROM_PERSON % {
|
defer.returnValue(MESSAGES_FROM_PERSON % {
|
||||||
"person": descriptor_from_member_events([
|
"person": descriptor_from_member_events([
|
||||||
state_by_room[room_id][("m.room.member", s)]
|
state_by_room[room_id][("m.room.member", s)]
|
||||||
for s in sender_ids
|
for s in sender_ids
|
||||||
]),
|
]),
|
||||||
"app": self.app_name,
|
"app": self.app_name,
|
||||||
}
|
})
|
||||||
else:
|
else:
|
||||||
# Stuff's happened in multiple different rooms
|
# Stuff's happened in multiple different rooms
|
||||||
|
|
||||||
# ...but we still refer to the 'reason' room which triggered the mail
|
# ...but we still refer to the 'reason' room which triggered the mail
|
||||||
if reason['room_name'] is not None:
|
if reason['room_name'] is not None:
|
||||||
return MESSAGES_IN_ROOM_AND_OTHERS % {
|
defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
|
||||||
"room": reason['room_name'],
|
"room": reason['room_name'],
|
||||||
"app": self.app_name,
|
"app": self.app_name,
|
||||||
}
|
})
|
||||||
else:
|
else:
|
||||||
# If the reason room doesn't have a name, say who the messages
|
# If the reason room doesn't have a name, say who the messages
|
||||||
# are from explicitly to avoid, "messages in the Bob room"
|
# are from explicitly to avoid, "messages in the Bob room"
|
||||||
|
@ -412,13 +419,13 @@ class Mailer(object):
|
||||||
for n in notifs_by_room[reason['room_id']]
|
for n in notifs_by_room[reason['room_id']]
|
||||||
]))
|
]))
|
||||||
|
|
||||||
return MESSAGES_FROM_PERSON_AND_OTHERS % {
|
defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
|
||||||
"person": descriptor_from_member_events([
|
"person": descriptor_from_member_events([
|
||||||
state_by_room[reason['room_id']][("m.room.member", s)]
|
state_by_room[reason['room_id']][("m.room.member", s)]
|
||||||
for s in sender_ids
|
for s in sender_ids
|
||||||
]),
|
]),
|
||||||
"app": self.app_name,
|
"app": self.app_name,
|
||||||
}
|
})
|
||||||
|
|
||||||
def make_room_link(self, room_id):
|
def make_room_link(self, room_id):
|
||||||
# need /beta for Universal Links to work on iOS
|
# need /beta for Universal Links to work on iOS
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -25,7 +27,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
|
||||||
ALL_ALONE = "Empty Room"
|
ALL_ALONE = "Empty Room"
|
||||||
|
|
||||||
|
|
||||||
def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
@defer.inlineCallbacks
|
||||||
|
def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
|
||||||
fallback_to_single_member=True):
|
fallback_to_single_member=True):
|
||||||
"""
|
"""
|
||||||
Works out a user-facing name for the given room as per Matrix
|
Works out a user-facing name for the given room as per Matrix
|
||||||
|
@ -42,59 +45,78 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||||
(string or None) A human readable name for the room.
|
(string or None) A human readable name for the room.
|
||||||
"""
|
"""
|
||||||
# does it have a name?
|
# does it have a name?
|
||||||
if ("m.room.name", "") in room_state:
|
if ("m.room.name", "") in room_state_ids:
|
||||||
m_room_name = room_state[("m.room.name", "")]
|
m_room_name = yield store.get_event(
|
||||||
if m_room_name.content and m_room_name.content["name"]:
|
room_state_ids[("m.room.name", "")], allow_none=True
|
||||||
return m_room_name.content["name"]
|
)
|
||||||
|
if m_room_name and m_room_name.content and m_room_name.content["name"]:
|
||||||
|
defer.returnValue(m_room_name.content["name"])
|
||||||
|
|
||||||
# does it have a canonical alias?
|
# does it have a canonical alias?
|
||||||
if ("m.room.canonical_alias", "") in room_state:
|
if ("m.room.canonical_alias", "") in room_state_ids:
|
||||||
canon_alias = room_state[("m.room.canonical_alias", "")]
|
canon_alias = yield store.get_event(
|
||||||
|
room_state_ids[("m.room.canonical_alias", "")], allow_none=True
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
canon_alias.content and canon_alias.content["alias"] and
|
canon_alias and canon_alias.content and canon_alias.content["alias"] and
|
||||||
_looks_like_an_alias(canon_alias.content["alias"])
|
_looks_like_an_alias(canon_alias.content["alias"])
|
||||||
):
|
):
|
||||||
return canon_alias.content["alias"]
|
defer.returnValue(canon_alias.content["alias"])
|
||||||
|
|
||||||
# at this point we're going to need to search the state by all state keys
|
# at this point we're going to need to search the state by all state keys
|
||||||
# for an event type, so rearrange the data structure
|
# for an event type, so rearrange the data structure
|
||||||
room_state_bytype = _state_as_two_level_dict(room_state)
|
room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
|
||||||
|
|
||||||
# right then, any aliases at all?
|
# right then, any aliases at all?
|
||||||
if "m.room.aliases" in room_state_bytype:
|
if "m.room.aliases" in room_state_bytype_ids:
|
||||||
m_room_aliases = room_state_bytype["m.room.aliases"]
|
m_room_aliases = room_state_bytype_ids["m.room.aliases"]
|
||||||
if len(m_room_aliases.values()) > 0:
|
for alias_id in m_room_aliases.values():
|
||||||
first_alias_event = m_room_aliases.values()[0]
|
alias_event = yield store.get_event(
|
||||||
if first_alias_event.content and first_alias_event.content["aliases"]:
|
alias_id, allow_none=True
|
||||||
the_aliases = first_alias_event.content["aliases"]
|
)
|
||||||
|
if alias_event and alias_event.content.get("aliases"):
|
||||||
|
the_aliases = alias_event.content["aliases"]
|
||||||
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
|
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
|
||||||
return the_aliases[0]
|
defer.returnValue(the_aliases[0])
|
||||||
|
|
||||||
if not fallback_to_members:
|
if not fallback_to_members:
|
||||||
return None
|
defer.returnValue(None)
|
||||||
|
|
||||||
my_member_event = None
|
my_member_event = None
|
||||||
if ("m.room.member", user_id) in room_state:
|
if ("m.room.member", user_id) in room_state_ids:
|
||||||
my_member_event = room_state[("m.room.member", user_id)]
|
my_member_event = yield store.get_event(
|
||||||
|
room_state_ids[("m.room.member", user_id)], allow_none=True
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
my_member_event is not None and
|
my_member_event is not None and
|
||||||
my_member_event.content['membership'] == "invite"
|
my_member_event.content['membership'] == "invite"
|
||||||
):
|
):
|
||||||
if ("m.room.member", my_member_event.sender) in room_state:
|
if ("m.room.member", my_member_event.sender) in room_state_ids:
|
||||||
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
|
inviter_member_event = yield store.get_event(
|
||||||
|
room_state_ids[("m.room.member", my_member_event.sender)],
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if inviter_member_event:
|
||||||
if fallback_to_single_member:
|
if fallback_to_single_member:
|
||||||
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
|
defer.returnValue(
|
||||||
|
"Invite from %s" % (
|
||||||
|
name_from_member_event(inviter_member_event),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return
|
||||||
else:
|
else:
|
||||||
return "Room Invite"
|
defer.returnValue("Room Invite")
|
||||||
|
|
||||||
# we're going to have to generate a name based on who's in the room,
|
# we're going to have to generate a name based on who's in the room,
|
||||||
# so find out who is in the room that isn't the user.
|
# so find out who is in the room that isn't the user.
|
||||||
if "m.room.member" in room_state_bytype:
|
if "m.room.member" in room_state_bytype_ids:
|
||||||
|
member_events = yield store.get_events(
|
||||||
|
room_state_bytype_ids["m.room.member"].values()
|
||||||
|
)
|
||||||
all_members = [
|
all_members = [
|
||||||
ev for ev in room_state_bytype["m.room.member"].values()
|
ev for ev in member_events.values()
|
||||||
if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
|
if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
|
||||||
]
|
]
|
||||||
# Sort the member events oldest-first so the we name people in the
|
# Sort the member events oldest-first so the we name people in the
|
||||||
|
@ -111,9 +133,9 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||||
# self-chat, peeked room with 1 participant,
|
# self-chat, peeked room with 1 participant,
|
||||||
# or inbound invite, or outbound 3PID invite.
|
# or inbound invite, or outbound 3PID invite.
|
||||||
if all_members[0].sender == user_id:
|
if all_members[0].sender == user_id:
|
||||||
if "m.room.third_party_invite" in room_state_bytype:
|
if "m.room.third_party_invite" in room_state_bytype_ids:
|
||||||
third_party_invites = (
|
third_party_invites = (
|
||||||
room_state_bytype["m.room.third_party_invite"].values()
|
room_state_bytype_ids["m.room.third_party_invite"].values()
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(third_party_invites) > 0:
|
if len(third_party_invites) > 0:
|
||||||
|
@ -126,17 +148,17 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||||
# return "Inviting %s" % (
|
# return "Inviting %s" % (
|
||||||
# descriptor_from_member_events(third_party_invites)
|
# descriptor_from_member_events(third_party_invites)
|
||||||
# )
|
# )
|
||||||
return "Inviting email address"
|
defer.returnValue("Inviting email address")
|
||||||
else:
|
else:
|
||||||
return ALL_ALONE
|
defer.returnValue(ALL_ALONE)
|
||||||
else:
|
else:
|
||||||
return name_from_member_event(all_members[0])
|
defer.returnValue(name_from_member_event(all_members[0]))
|
||||||
else:
|
else:
|
||||||
return ALL_ALONE
|
defer.returnValue(ALL_ALONE)
|
||||||
elif len(other_members) == 1 and not fallback_to_single_member:
|
elif len(other_members) == 1 and not fallback_to_single_member:
|
||||||
return None
|
return
|
||||||
else:
|
else:
|
||||||
return descriptor_from_member_events(other_members)
|
defer.returnValue(descriptor_from_member_events(other_members))
|
||||||
|
|
||||||
|
|
||||||
def descriptor_from_member_events(member_events):
|
def descriptor_from_member_events(member_events):
|
|
@ -14,7 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from synapse.util.presentable_names import (
|
from synapse.push.presentable_names import (
|
||||||
calculate_room_name, name_from_member_event
|
calculate_room_name, name_from_member_event
|
||||||
)
|
)
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||||
|
@ -49,21 +49,22 @@ def get_badge_count(store, user_id):
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_context_for_event(state_handler, ev, user_id):
|
def get_context_for_event(store, state_handler, ev, user_id):
|
||||||
ctx = {}
|
ctx = {}
|
||||||
|
|
||||||
room_state = yield state_handler.get_current_state(ev.room_id)
|
room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
|
||||||
|
|
||||||
# we no longer bother setting room_alias, and make room_name the
|
# we no longer bother setting room_alias, and make room_name the
|
||||||
# human-readable name instead, be that m.room.name, an alias or
|
# human-readable name instead, be that m.room.name, an alias or
|
||||||
# a list of people in the room
|
# a list of people in the room
|
||||||
name = calculate_room_name(
|
name = yield calculate_room_name(
|
||||||
room_state, user_id, fallback_to_single_member=False
|
store, room_state_ids, user_id, fallback_to_single_member=False
|
||||||
)
|
)
|
||||||
if name:
|
if name:
|
||||||
ctx['name'] = name
|
ctx['name'] = name
|
||||||
|
|
||||||
sender_state_event = room_state[("m.room.member", ev.sender)]
|
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
|
||||||
|
sender_state_event = yield store.get_event(sender_state_event_id)
|
||||||
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
|
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
|
||||||
|
|
||||||
defer.returnValue(ctx)
|
defer.returnValue(ctx)
|
||||||
|
|
|
@ -40,8 +40,8 @@ STREAM_NAMES = (
|
||||||
("backfill",),
|
("backfill",),
|
||||||
("push_rules",),
|
("push_rules",),
|
||||||
("pushers",),
|
("pushers",),
|
||||||
("state",),
|
|
||||||
("caches",),
|
("caches",),
|
||||||
|
("to_device",),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,7 +130,6 @@ class ReplicationResource(Resource):
|
||||||
backfill_token = yield self.store.get_current_backfill_token()
|
backfill_token = yield self.store.get_current_backfill_token()
|
||||||
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
||||||
pushers_token = self.store.get_pushers_stream_token()
|
pushers_token = self.store.get_pushers_stream_token()
|
||||||
state_token = self.store.get_state_stream_token()
|
|
||||||
caches_token = self.store.get_cache_stream_token()
|
caches_token = self.store.get_cache_stream_token()
|
||||||
|
|
||||||
defer.returnValue(_ReplicationToken(
|
defer.returnValue(_ReplicationToken(
|
||||||
|
@ -142,8 +141,9 @@ class ReplicationResource(Resource):
|
||||||
backfill_token,
|
backfill_token,
|
||||||
push_rules_token,
|
push_rules_token,
|
||||||
pushers_token,
|
pushers_token,
|
||||||
state_token,
|
0, # State stream is no longer a thing
|
||||||
caches_token,
|
caches_token,
|
||||||
|
int(stream_token.to_device_key),
|
||||||
))
|
))
|
||||||
|
|
||||||
@request_handler()
|
@request_handler()
|
||||||
|
@ -191,8 +191,8 @@ class ReplicationResource(Resource):
|
||||||
yield self.receipts(writer, current_token, limit, request_streams)
|
yield self.receipts(writer, current_token, limit, request_streams)
|
||||||
yield self.push_rules(writer, current_token, limit, request_streams)
|
yield self.push_rules(writer, current_token, limit, request_streams)
|
||||||
yield self.pushers(writer, current_token, limit, request_streams)
|
yield self.pushers(writer, current_token, limit, request_streams)
|
||||||
yield self.state(writer, current_token, limit, request_streams)
|
|
||||||
yield self.caches(writer, current_token, limit, request_streams)
|
yield self.caches(writer, current_token, limit, request_streams)
|
||||||
|
yield self.to_device(writer, current_token, limit, request_streams)
|
||||||
self.streams(writer, current_token, request_streams)
|
self.streams(writer, current_token, request_streams)
|
||||||
|
|
||||||
logger.info("Replicated %d rows", writer.total)
|
logger.info("Replicated %d rows", writer.total)
|
||||||
|
@ -365,25 +365,6 @@ class ReplicationResource(Resource):
|
||||||
"position", "user_id", "app_id", "pushkey"
|
"position", "user_id", "app_id", "pushkey"
|
||||||
))
|
))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def state(self, writer, current_token, limit, request_streams):
|
|
||||||
current_position = current_token.state
|
|
||||||
|
|
||||||
state = request_streams.get("state")
|
|
||||||
|
|
||||||
if state is not None:
|
|
||||||
state_groups, state_group_state = (
|
|
||||||
yield self.store.get_all_new_state_groups(
|
|
||||||
state, current_position, limit
|
|
||||||
)
|
|
||||||
)
|
|
||||||
writer.write_header_and_rows("state_groups", state_groups, (
|
|
||||||
"position", "room_id", "event_id"
|
|
||||||
))
|
|
||||||
writer.write_header_and_rows("state_group_state", state_group_state, (
|
|
||||||
"position", "type", "state_key", "event_id"
|
|
||||||
))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def caches(self, writer, current_token, limit, request_streams):
|
def caches(self, writer, current_token, limit, request_streams):
|
||||||
current_position = current_token.caches
|
current_position = current_token.caches
|
||||||
|
@ -398,6 +379,20 @@ class ReplicationResource(Resource):
|
||||||
"position", "cache_func", "keys", "invalidation_ts"
|
"position", "cache_func", "keys", "invalidation_ts"
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def to_device(self, writer, current_token, limit, request_streams):
|
||||||
|
current_position = current_token.to_device
|
||||||
|
|
||||||
|
to_device = request_streams.get("to_device")
|
||||||
|
|
||||||
|
if to_device is not None:
|
||||||
|
to_device_rows = yield self.store.get_all_new_device_messages(
|
||||||
|
to_device, current_position, limit
|
||||||
|
)
|
||||||
|
writer.write_header_and_rows("to_device", to_device_rows, (
|
||||||
|
"position", "user_id", "device_id", "message_json"
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
class _Writer(object):
|
class _Writer(object):
|
||||||
"""Writes the streams as a JSON object as the response to the request"""
|
"""Writes the streams as a JSON object as the response to the request"""
|
||||||
|
@ -426,7 +421,7 @@ class _Writer(object):
|
||||||
|
|
||||||
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
|
||||||
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
"events", "presence", "typing", "receipts", "account_data", "backfill",
|
||||||
"push_rules", "pushers", "state", "caches",
|
"push_rules", "pushers", "state", "caches", "to_device",
|
||||||
))):
|
))):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
|
||||||
|
|
42
synapse/replication/slave/storage/deviceinbox.py
Normal file
42
synapse/replication/slave/storage/deviceinbox.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedDeviceInboxStore(BaseSlavedStore):
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
|
||||||
|
self._device_inbox_id_gen = SlavedIdTracker(
|
||||||
|
db_conn, "device_inbox", "stream_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
|
||||||
|
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
|
||||||
|
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
|
||||||
|
|
||||||
|
def stream_positions(self):
|
||||||
|
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
||||||
|
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_replication(self, result):
|
||||||
|
stream = result.get("to_device")
|
||||||
|
if stream:
|
||||||
|
self._device_inbox_id_gen.advance(int(stream["position"]))
|
||||||
|
|
||||||
|
return super(SlavedDeviceInboxStore, self).process_replication(result)
|
|
@ -120,10 +120,21 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
get_state_for_event = DataStore.get_state_for_event.__func__
|
get_state_for_event = DataStore.get_state_for_event.__func__
|
||||||
get_state_for_events = DataStore.get_state_for_events.__func__
|
get_state_for_events = DataStore.get_state_for_events.__func__
|
||||||
get_state_groups = DataStore.get_state_groups.__func__
|
get_state_groups = DataStore.get_state_groups.__func__
|
||||||
|
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
|
||||||
|
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
|
||||||
|
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
|
||||||
|
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
|
||||||
|
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
||||||
|
_get_joined_users_from_context = (
|
||||||
|
RoomMemberStore.__dict__["_get_joined_users_from_context"]
|
||||||
|
)
|
||||||
|
|
||||||
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
|
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
|
||||||
get_room_events_stream_for_rooms = (
|
get_room_events_stream_for_rooms = (
|
||||||
DataStore.get_room_events_stream_for_rooms.__func__
|
DataStore.get_room_events_stream_for_rooms.__func__
|
||||||
)
|
)
|
||||||
|
is_host_joined = DataStore.is_host_joined.__func__
|
||||||
|
_is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
|
||||||
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
||||||
|
|
||||||
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
||||||
|
@ -211,7 +222,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
self._get_current_state_for_key.invalidate_all()
|
self._get_current_state_for_key.invalidate_all()
|
||||||
self.get_rooms_for_user.invalidate_all()
|
self.get_rooms_for_user.invalidate_all()
|
||||||
self.get_users_in_room.invalidate((event.room_id,))
|
self.get_users_in_room.invalidate((event.room_id,))
|
||||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
|
||||||
|
|
||||||
self._invalidate_get_event_cache(event.event_id)
|
self._invalidate_get_event_cache(event.event_id)
|
||||||
|
|
||||||
|
@ -235,7 +245,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
self.get_rooms_for_user.invalidate((event.state_key,))
|
self.get_rooms_for_user.invalidate((event.state_key,))
|
||||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
|
||||||
self.get_users_in_room.invalidate((event.room_id,))
|
self.get_users_in_room.invalidate((event.room_id,))
|
||||||
self._membership_stream_cache.entity_has_changed(
|
self._membership_stream_cache.entity_has_changed(
|
||||||
event.state_key, event.internal_metadata.stream_ordering
|
event.state_key, event.internal_metadata.stream_ordering
|
||||||
|
|
|
@ -49,6 +49,7 @@ from synapse.rest.client.v2_alpha import (
|
||||||
notifications,
|
notifications,
|
||||||
devices,
|
devices,
|
||||||
thirdparty,
|
thirdparty,
|
||||||
|
sendtodevice,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
|
@ -96,3 +97,4 @@ class ClientRestResource(JsonResource):
|
||||||
notifications.register_servlets(hs, client_resource)
|
notifications.register_servlets(hs, client_resource)
|
||||||
devices.register_servlets(hs, client_resource)
|
devices.register_servlets(hs, client_resource)
|
||||||
thirdparty.register_servlets(hs, client_resource)
|
thirdparty.register_servlets(hs, client_resource)
|
||||||
|
sendtodevice.register_servlets(hs, client_resource)
|
||||||
|
|
90
synapse/rest/client/v2_alpha/sendtodevice.py
Normal file
90
synapse/rest/client/v2_alpha/sendtodevice.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
|
|
||||||
|
from synapse.http import servlet
|
||||||
|
from synapse.rest.client.v1.transactions import HttpTransactionStore
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SendToDeviceRestServlet(servlet.RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns(
|
||||||
|
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
|
||||||
|
releases=[], v2_alpha=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer): server
|
||||||
|
"""
|
||||||
|
super(SendToDeviceRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.notifier = hs.get_notifier()
|
||||||
|
self.is_mine_id = hs.is_mine_id
|
||||||
|
self.txns = HttpTransactionStore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, message_type, txn_id):
|
||||||
|
try:
|
||||||
|
defer.returnValue(
|
||||||
|
self.txns.get_client_transaction(request, txn_id)
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
# TODO: Prod the notifier to wake up sync streams.
|
||||||
|
# TODO: Implement replication for the messages.
|
||||||
|
# TODO: Send the messages to remote servers if needed.
|
||||||
|
|
||||||
|
local_messages = {}
|
||||||
|
for user_id, by_device in content["messages"].items():
|
||||||
|
if self.is_mine_id(user_id):
|
||||||
|
messages_by_device = {
|
||||||
|
device_id: {
|
||||||
|
"content": message_content,
|
||||||
|
"type": message_type,
|
||||||
|
"sender": requester.user.to_string(),
|
||||||
|
}
|
||||||
|
for device_id, message_content in by_device.items()
|
||||||
|
}
|
||||||
|
if messages_by_device:
|
||||||
|
local_messages[user_id] = messages_by_device
|
||||||
|
|
||||||
|
stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
|
||||||
|
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"to_device_key", stream_id, users=local_messages.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
response = (200, {})
|
||||||
|
self.txns.store_client_transaction(request, txn_id, response)
|
||||||
|
defer.returnValue(response)
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
SendToDeviceRestServlet(hs).register(http_server)
|
|
@ -97,6 +97,7 @@ class SyncRestServlet(RestServlet):
|
||||||
request, allow_guest=True
|
request, allow_guest=True
|
||||||
)
|
)
|
||||||
user = requester.user
|
user = requester.user
|
||||||
|
device_id = requester.device_id
|
||||||
|
|
||||||
timeout = parse_integer(request, "timeout", default=0)
|
timeout = parse_integer(request, "timeout", default=0)
|
||||||
since = parse_string(request, "since")
|
since = parse_string(request, "since")
|
||||||
|
@ -109,12 +110,12 @@ class SyncRestServlet(RestServlet):
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"/sync: user=%r, timeout=%r, since=%r,"
|
"/sync: user=%r, timeout=%r, since=%r,"
|
||||||
" set_presence=%r, filter_id=%r" % (
|
" set_presence=%r, filter_id=%r, device_id=%r" % (
|
||||||
user, timeout, since, set_presence, filter_id
|
user, timeout, since, set_presence, filter_id, device_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
request_key = (user, timeout, since, filter_id, full_state)
|
request_key = (user, timeout, since, filter_id, full_state, device_id)
|
||||||
|
|
||||||
if filter_id:
|
if filter_id:
|
||||||
if filter_id.startswith('{'):
|
if filter_id.startswith('{'):
|
||||||
|
@ -136,6 +137,7 @@ class SyncRestServlet(RestServlet):
|
||||||
filter_collection=filter,
|
filter_collection=filter,
|
||||||
is_guest=requester.is_guest,
|
is_guest=requester.is_guest,
|
||||||
request_key=request_key,
|
request_key=request_key,
|
||||||
|
device_id=device_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if since is not None:
|
if since is not None:
|
||||||
|
@ -173,6 +175,7 @@ class SyncRestServlet(RestServlet):
|
||||||
|
|
||||||
response_content = {
|
response_content = {
|
||||||
"account_data": {"events": sync_result.account_data},
|
"account_data": {"events": sync_result.account_data},
|
||||||
|
"to_device": {"events": sync_result.to_device},
|
||||||
"presence": self.encode_presence(
|
"presence": self.encode_presence(
|
||||||
sync_result.presence, time_now
|
sync_result.presence, time_now
|
||||||
),
|
),
|
||||||
|
|
|
@ -18,15 +18,32 @@ import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.constants import ThirdPartyEntityKind
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
from synapse.types import ThirdPartyEntityKind
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ThirdPartyProtocolsServlet(RestServlet):
|
||||||
|
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ThirdPartyProtocolsServlet, self).__init__()
|
||||||
|
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.appservice_handler = hs.get_application_service_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
protocols = yield self.appservice_handler.get_3pe_protocols()
|
||||||
|
defer.returnValue((200, protocols))
|
||||||
|
|
||||||
|
|
||||||
class ThirdPartyUserServlet(RestServlet):
|
class ThirdPartyUserServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
|
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
|
||||||
releases=())
|
releases=())
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
class ThirdPartyLocationServlet(RestServlet):
|
class ThirdPartyLocationServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
|
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
|
||||||
releases=())
|
releases=())
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet):
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
ThirdPartyProtocolsServlet(hs).register(http_server)
|
||||||
ThirdPartyUserServlet(hs).register(http_server)
|
ThirdPartyUserServlet(hs).register(http_server)
|
||||||
ThirdPartyLocationServlet(hs).register(http_server)
|
ThirdPartyLocationServlet(hs).register(http_server)
|
||||||
|
|
220
synapse/state.py
220
synapse/state.py
|
@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.auth import AuthEventTypes
|
from synapse.api.auth import AuthEventTypes
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.util.async import Linearizer
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
|
||||||
EVICTION_TIMEOUT_SECONDS = 60 * 60
|
EVICTION_TIMEOUT_SECONDS = 60 * 60
|
||||||
|
|
||||||
|
|
||||||
|
_NEXT_STATE_ID = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_state_id():
|
||||||
|
global _NEXT_STATE_ID
|
||||||
|
s = "X%d" % (_NEXT_STATE_ID,)
|
||||||
|
_NEXT_STATE_ID += 1
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
class _StateCacheEntry(object):
|
class _StateCacheEntry(object):
|
||||||
def __init__(self, state, state_group, ts):
|
__slots__ = ["state", "state_group", "state_id"]
|
||||||
|
|
||||||
|
def __init__(self, state, state_group):
|
||||||
self.state = state
|
self.state = state
|
||||||
self.state_group = state_group
|
self.state_group = state_group
|
||||||
|
|
||||||
|
# The `state_id` is a unique ID we generate that can be used as ID for
|
||||||
|
# this collection of state. Usually this would be the same as the
|
||||||
|
# state group, but on worker instances we can't generate a new state
|
||||||
|
# group each time we resolve state, so we generate a separate one that
|
||||||
|
# isn't persisted and is used solely for caches.
|
||||||
|
# `state_id` is either a state_group (and so an int) or a string. This
|
||||||
|
# ensures we don't accidentally persist a state_id as a stateg_group
|
||||||
|
if state_group:
|
||||||
|
self.state_id = state_group
|
||||||
|
else:
|
||||||
|
self.state_id = _gen_state_id()
|
||||||
|
|
||||||
|
|
||||||
class StateHandler(object):
|
class StateHandler(object):
|
||||||
""" Responsible for doing state conflict resolution.
|
""" Responsible for doing state conflict resolution.
|
||||||
|
@ -60,6 +85,7 @@ class StateHandler(object):
|
||||||
|
|
||||||
# dict of set of event_ids -> _StateCacheEntry.
|
# dict of set of event_ids -> _StateCacheEntry.
|
||||||
self._state_cache = None
|
self._state_cache = None
|
||||||
|
self.resolve_linearizer = Linearizer()
|
||||||
|
|
||||||
def start_caching(self):
|
def start_caching(self):
|
||||||
logger.debug("start_caching")
|
logger.debug("start_caching")
|
||||||
|
@ -93,8 +119,32 @@ class StateHandler(object):
|
||||||
if not latest_event_ids:
|
if not latest_event_ids:
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
res = yield self.resolve_state_groups(room_id, latest_event_ids)
|
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
state = res[1]
|
state = ret.state
|
||||||
|
|
||||||
|
if event_type:
|
||||||
|
event_id = state.get((event_type, state_key))
|
||||||
|
event = None
|
||||||
|
if event_id:
|
||||||
|
event = yield self.store.get_event(event_id, allow_none=True)
|
||||||
|
defer.returnValue(event)
|
||||||
|
return
|
||||||
|
|
||||||
|
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
|
||||||
|
state = {
|
||||||
|
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue(state)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_current_state_ids(self, room_id, event_type=None, state_key="",
|
||||||
|
latest_event_ids=None):
|
||||||
|
if not latest_event_ids:
|
||||||
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
|
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
|
state = ret.state
|
||||||
|
|
||||||
if event_type:
|
if event_type:
|
||||||
defer.returnValue(state.get((event_type, state_key)))
|
defer.returnValue(state.get((event_type, state_key)))
|
||||||
|
@ -102,6 +152,15 @@ class StateHandler(object):
|
||||||
|
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_current_user_in_room(self, room_id):
|
||||||
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
|
joined_users = yield self.store.get_joined_users_from_state(
|
||||||
|
room_id, entry.state_id, entry.state
|
||||||
|
)
|
||||||
|
defer.returnValue(joined_users)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def compute_event_context(self, event, old_state=None):
|
def compute_event_context(self, event, old_state=None):
|
||||||
""" Fills out the context with the `current state` of the graph. The
|
""" Fills out the context with the `current state` of the graph. The
|
||||||
|
@ -123,54 +182,75 @@ class StateHandler(object):
|
||||||
# state. Certainly store.get_current_state won't return any, and
|
# state. Certainly store.get_current_state won't return any, and
|
||||||
# persisting the event won't store the state group.
|
# persisting the event won't store the state group.
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
|
if event.is_state():
|
||||||
|
context.current_state_events = dict(context.prev_state_ids)
|
||||||
|
key = (event.type, event.state_key)
|
||||||
|
context.current_state_events[key] = event.event_id
|
||||||
else:
|
else:
|
||||||
context.current_state = {}
|
context.current_state_events = context.prev_state_ids
|
||||||
|
else:
|
||||||
|
context.current_state_ids = {}
|
||||||
|
context.prev_state_ids = {}
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
context.state_group = None
|
context.state_group = self.store.get_next_state_group()
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
context.state_group = None
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state:
|
if key in context.prev_state_ids:
|
||||||
replaces = context.current_state[key]
|
replaces = context.prev_state_ids[key]
|
||||||
if replaces.event_id != event.event_id: # Paranoia check
|
if replaces != event.event_id: # Paranoia check
|
||||||
event.unsigned["replaces_state"] = replaces.event_id
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
|
context.current_state_ids[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_ids = context.prev_state_ids
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
ret = yield self.resolve_state_groups(
|
entry = yield self.resolve_state_groups(
|
||||||
event.room_id, [e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
event_type=event.type,
|
event_type=event.type,
|
||||||
state_key=event.state_key,
|
state_key=event.state_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ret = yield self.resolve_state_groups(
|
entry = yield self.resolve_state_groups(
|
||||||
event.room_id, [e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
)
|
)
|
||||||
|
|
||||||
group, curr_state, prev_state = ret
|
curr_state = entry.state
|
||||||
|
|
||||||
context.current_state = curr_state
|
context.prev_state_ids = curr_state
|
||||||
context.state_group = group if not event.is_state() else None
|
if event.is_state():
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
else:
|
||||||
|
if entry.state_group is None:
|
||||||
|
entry.state_group = self.store.get_next_state_group()
|
||||||
|
entry.state_id = entry.state_group
|
||||||
|
context.state_group = entry.state_group
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state:
|
if key in context.prev_state_ids:
|
||||||
replaces = context.current_state[key]
|
replaces = context.prev_state_ids[key]
|
||||||
event.unsigned["replaces_state"] = replaces.event_id
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
|
context.current_state_ids[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_ids = context.prev_state_ids
|
||||||
|
|
||||||
context.prev_state_events = prev_state
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -187,72 +267,88 @@ class StateHandler(object):
|
||||||
"""
|
"""
|
||||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups_ids = yield self.store.get_state_groups_ids(
|
||||||
room_id, event_ids
|
room_id, event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"resolve_state_groups state_groups %s",
|
"resolve_state_groups state_groups %s",
|
||||||
state_groups.keys()
|
state_groups_ids.keys()
|
||||||
)
|
)
|
||||||
|
|
||||||
group_names = frozenset(state_groups.keys())
|
group_names = frozenset(state_groups_ids.keys())
|
||||||
if len(group_names) == 1:
|
if len(group_names) == 1:
|
||||||
name, state_list = state_groups.items().pop()
|
name, state_list = state_groups_ids.items().pop()
|
||||||
state = {
|
|
||||||
(e.type, e.state_key): e
|
|
||||||
for e in state_list
|
|
||||||
}
|
|
||||||
prev_state = state.get((event_type, state_key), None)
|
|
||||||
if prev_state:
|
|
||||||
prev_state = prev_state.event_id
|
|
||||||
prev_states = [prev_state]
|
|
||||||
else:
|
|
||||||
prev_states = []
|
|
||||||
|
|
||||||
defer.returnValue((name, state, prev_states))
|
defer.returnValue(_StateCacheEntry(
|
||||||
|
state=state_list,
|
||||||
|
state_group=name,
|
||||||
|
))
|
||||||
|
|
||||||
|
with (yield self.resolve_linearizer.queue(group_names)):
|
||||||
if self._state_cache is not None:
|
if self._state_cache is not None:
|
||||||
cache = self._state_cache.get(group_names, None)
|
cache = self._state_cache.get(group_names, None)
|
||||||
if cache:
|
if cache:
|
||||||
cache.ts = self.clock.time_msec()
|
defer.returnValue(cache)
|
||||||
|
|
||||||
event_dict = yield self.store.get_events(cache.state.values())
|
logger.info(
|
||||||
state = {(e.type, e.state_key): e for e in event_dict.values()}
|
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
||||||
|
)
|
||||||
|
|
||||||
prev_state = state.get((event_type, state_key), None)
|
state = {}
|
||||||
if prev_state:
|
for st in state_groups_ids.values():
|
||||||
prev_state = prev_state.event_id
|
for key, e_id in st.items():
|
||||||
prev_states = [prev_state]
|
state.setdefault(key, set()).add(e_id)
|
||||||
|
|
||||||
|
conflicted_state = {
|
||||||
|
k: list(v)
|
||||||
|
for k, v in state.items()
|
||||||
|
if len(v) > 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if conflicted_state:
|
||||||
|
logger.info("Resolving conflicted state for %r", room_id)
|
||||||
|
state_map = yield self.store.get_events(
|
||||||
|
[e_id for st in state_groups_ids.values() for e_id in st.values()],
|
||||||
|
get_prev_content=False
|
||||||
|
)
|
||||||
|
state_sets = [
|
||||||
|
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
|
||||||
|
for st in state_groups_ids.values()
|
||||||
|
]
|
||||||
|
new_state, _ = self._resolve_events(
|
||||||
|
state_sets, event_type, state_key
|
||||||
|
)
|
||||||
|
new_state = {
|
||||||
|
key: e.event_id for key, e in new_state.items()
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
prev_states = []
|
new_state = {
|
||||||
defer.returnValue(
|
key: e_ids.pop() for key, e_ids in state.items()
|
||||||
(cache.state_group, state, prev_states)
|
}
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
|
|
||||||
|
|
||||||
new_state, prev_states = self._resolve_events(
|
|
||||||
state_groups.values(), event_type, state_key
|
|
||||||
)
|
|
||||||
|
|
||||||
state_group = None
|
state_group = None
|
||||||
new_state_event_ids = frozenset(e.event_id for e in new_state.values())
|
new_state_event_ids = frozenset(new_state.values())
|
||||||
for sg, events in state_groups.items():
|
for sg, events in state_groups_ids.items():
|
||||||
if new_state_event_ids == frozenset(e.event_id for e in events):
|
if new_state_event_ids == frozenset(e_id for e_id in events):
|
||||||
state_group = sg
|
state_group = sg
|
||||||
break
|
break
|
||||||
|
if state_group is None:
|
||||||
|
# Worker instances don't have access to this method, but we want
|
||||||
|
# to set the state_group on the main instance to increase cache
|
||||||
|
# hits.
|
||||||
|
if hasattr(self.store, "get_next_state_group"):
|
||||||
|
state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
if self._state_cache is not None:
|
|
||||||
cache = _StateCacheEntry(
|
cache = _StateCacheEntry(
|
||||||
state={key: event.event_id for key, event in new_state.items()},
|
state=new_state,
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
ts=self.clock.time_msec()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._state_cache is not None:
|
||||||
self._state_cache[group_names] = cache
|
self._state_cache[group_names] = cache
|
||||||
|
|
||||||
defer.returnValue((state_group, new_state, prev_states))
|
defer.returnValue(cache)
|
||||||
|
|
||||||
def resolve_events(self, state_sets, event):
|
def resolve_events(self, state_sets, event):
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -36,6 +36,7 @@ from .push_rule import PushRuleStore
|
||||||
from .media_repository import MediaRepositoryStore
|
from .media_repository import MediaRepositoryStore
|
||||||
from .rejections import RejectionsStore
|
from .rejections import RejectionsStore
|
||||||
from .event_push_actions import EventPushActionsStore
|
from .event_push_actions import EventPushActionsStore
|
||||||
|
from .deviceinbox import DeviceInboxStore
|
||||||
|
|
||||||
from .state import StateStore
|
from .state import StateStore
|
||||||
from .signatures import SignatureStore
|
from .signatures import SignatureStore
|
||||||
|
@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
OpenIdStore,
|
OpenIdStore,
|
||||||
ClientIpStore,
|
ClientIpStore,
|
||||||
DeviceStore,
|
DeviceStore,
|
||||||
|
DeviceInboxStore,
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
|
@ -108,9 +110,12 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
self._presence_id_gen = StreamIdGenerator(
|
self._presence_id_gen = StreamIdGenerator(
|
||||||
db_conn, "presence_stream", "stream_id"
|
db_conn, "presence_stream", "stream_id"
|
||||||
)
|
)
|
||||||
|
self._device_inbox_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "device_inbox", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||||
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
|
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
||||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
|
|
184
synapse/storage/deviceinbox.py
Normal file
184
synapse/storage/deviceinbox.py
Normal file
|
@ -0,0 +1,184 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import ujson
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceInboxStore(SQLBaseStore):
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def add_messages_to_device_inbox(self, messages_by_user_then_device):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
messages_by_user_and_device(dict):
|
||||||
|
Dictionary of user_id to device_id to message.
|
||||||
|
Returns:
|
||||||
|
A deferred stream_id that resolves when the messages have been
|
||||||
|
inserted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def select_devices_txn(txn, user_id, devices):
|
||||||
|
if not devices:
|
||||||
|
return []
|
||||||
|
sql = (
|
||||||
|
"SELECT user_id, device_id FROM devices"
|
||||||
|
" WHERE user_id = ? AND device_id IN ("
|
||||||
|
+ ",".join("?" * len(devices))
|
||||||
|
+ ")"
|
||||||
|
)
|
||||||
|
# TODO: Maybe this needs to be done in batches if there are
|
||||||
|
# too many local devices for a given user.
|
||||||
|
args = [user_id] + devices
|
||||||
|
txn.execute(sql, args)
|
||||||
|
return [tuple(row) for row in txn.fetchall()]
|
||||||
|
|
||||||
|
def add_messages_to_device_inbox_txn(txn, stream_id):
|
||||||
|
local_users_and_devices = set()
|
||||||
|
for user_id, messages_by_device in messages_by_user_then_device.items():
|
||||||
|
local_users_and_devices.update(
|
||||||
|
select_devices_txn(txn, user_id, messages_by_device.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"INSERT INTO device_inbox"
|
||||||
|
" (user_id, device_id, stream_id, message_json)"
|
||||||
|
" VALUES (?,?,?,?)"
|
||||||
|
)
|
||||||
|
rows = []
|
||||||
|
for user_id, messages_by_device in messages_by_user_then_device.items():
|
||||||
|
for device_id, message in messages_by_device.items():
|
||||||
|
message_json = ujson.dumps(message)
|
||||||
|
# Only insert into the local inbox if the device exists on
|
||||||
|
# this server
|
||||||
|
if (user_id, device_id) in local_users_and_devices:
|
||||||
|
rows.append((user_id, device_id, stream_id, message_json))
|
||||||
|
|
||||||
|
txn.executemany(sql, rows)
|
||||||
|
|
||||||
|
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||||
|
yield self.runInteraction(
|
||||||
|
"add_messages_to_device_inbox",
|
||||||
|
add_messages_to_device_inbox_txn,
|
||||||
|
stream_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(self._device_inbox_id_gen.get_current_token())
|
||||||
|
|
||||||
|
def get_new_messages_for_device(
|
||||||
|
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
user_id(str): The recipient user_id.
|
||||||
|
device_id(str): The recipient device_id.
|
||||||
|
current_stream_id(int): The current position of the to device
|
||||||
|
message stream.
|
||||||
|
Returns:
|
||||||
|
Deferred ([dict], int): List of messages for the device and where
|
||||||
|
in the stream the messages got to.
|
||||||
|
"""
|
||||||
|
def get_new_messages_for_device_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_id, message_json FROM device_inbox"
|
||||||
|
" WHERE user_id = ? AND device_id = ?"
|
||||||
|
" AND ? < stream_id AND stream_id <= ?"
|
||||||
|
" ORDER BY stream_id ASC"
|
||||||
|
" LIMIT ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (
|
||||||
|
user_id, device_id, last_stream_id, current_stream_id, limit
|
||||||
|
))
|
||||||
|
messages = []
|
||||||
|
for row in txn.fetchall():
|
||||||
|
stream_pos = row[0]
|
||||||
|
messages.append(ujson.loads(row[1]))
|
||||||
|
if len(messages) < limit:
|
||||||
|
stream_pos = current_stream_id
|
||||||
|
return (messages, stream_pos)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_new_messages_for_device", get_new_messages_for_device_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
user_id(str): The recipient user_id.
|
||||||
|
device_id(str): The recipient device_id.
|
||||||
|
up_to_stream_id(int): Where to delete messages up to.
|
||||||
|
Returns:
|
||||||
|
A deferred that resolves when the messages have been deleted.
|
||||||
|
"""
|
||||||
|
def delete_messages_for_device_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"DELETE FROM device_inbox"
|
||||||
|
" WHERE user_id = ? AND device_id = ?"
|
||||||
|
" AND stream_id <= ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"delete_messages_for_device", delete_messages_for_device_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_all_new_device_messages(self, last_pos, current_pos, limit):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
last_pos(int):
|
||||||
|
current_pos(int):
|
||||||
|
limit(int):
|
||||||
|
Returns:
|
||||||
|
A deferred list of rows from the device inbox
|
||||||
|
"""
|
||||||
|
if last_pos == current_pos:
|
||||||
|
return defer.succeed([])
|
||||||
|
|
||||||
|
def get_all_new_device_messages_txn(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_id FROM device_inbox"
|
||||||
|
" WHERE ? < stream_id AND stream_id <= ?"
|
||||||
|
" GROUP BY stream_id"
|
||||||
|
" ORDER BY stream_id ASC"
|
||||||
|
" LIMIT ?"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (last_pos, current_pos, limit))
|
||||||
|
stream_ids = txn.fetchall()
|
||||||
|
if not stream_ids:
|
||||||
|
return []
|
||||||
|
max_stream_id_in_limit = stream_ids[-1]
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT stream_id, user_id, device_id, message_json"
|
||||||
|
" FROM device_inbox"
|
||||||
|
" WHERE ? < stream_id AND stream_id <= ?"
|
||||||
|
" ORDER BY stream_id ASC"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (last_pos, max_stream_id_in_limit))
|
||||||
|
return txn.fetchall()
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_to_device_stream_token(self):
|
||||||
|
return self._device_inbox_id_gen.get_current_token()
|
|
@ -271,22 +271,11 @@ class EventsStore(SQLBaseStore):
|
||||||
len(events_and_contexts)
|
len(events_and_contexts)
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
|
|
||||||
len(events_and_contexts)
|
|
||||||
)
|
|
||||||
with stream_ordering_manager as stream_orderings:
|
with stream_ordering_manager as stream_orderings:
|
||||||
with state_group_id_manager as state_group_ids:
|
for (event, context), stream, in zip(
|
||||||
for (event, context), stream, state_group_id in zip(
|
events_and_contexts, stream_orderings
|
||||||
events_and_contexts, stream_orderings, state_group_ids
|
|
||||||
):
|
):
|
||||||
event.internal_metadata.stream_ordering = stream
|
event.internal_metadata.stream_ordering = stream
|
||||||
# Assign a state group_id in case a new id is needed for
|
|
||||||
# this context. In theory we only need to assign this
|
|
||||||
# for contexts that have current_state and aren't outliers
|
|
||||||
# but that make the code more complicated. Assigning an ID
|
|
||||||
# per event only causes the state_group_ids to grow as fast
|
|
||||||
# as the stream_ordering so in practise shouldn't be a problem.
|
|
||||||
context.new_state_group_id = state_group_id
|
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
events_and_contexts[x:x + 100]
|
events_and_contexts[x:x + 100]
|
||||||
|
@ -312,9 +301,7 @@ class EventsStore(SQLBaseStore):
|
||||||
delete_existing=False):
|
delete_existing=False):
|
||||||
try:
|
try:
|
||||||
with self._stream_id_gen.get_next() as stream_ordering:
|
with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
with self._state_groups_id_gen.get_next() as state_group_id:
|
|
||||||
event.internal_metadata.stream_ordering = stream_ordering
|
event.internal_metadata.stream_ordering = stream_ordering
|
||||||
context.new_state_group_id = state_group_id
|
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"persist_event",
|
"persist_event",
|
||||||
self._persist_event_txn,
|
self._persist_event_txn,
|
||||||
|
@ -393,7 +380,6 @@ class EventsStore(SQLBaseStore):
|
||||||
txn.call_after(self._get_current_state_for_key.invalidate_all)
|
txn.call_after(self._get_current_state_for_key.invalidate_all)
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
|
||||||
|
|
||||||
# Add an entry to the current_state_resets table to record the point
|
# Add an entry to the current_state_resets table to record the point
|
||||||
# where we clobbered the current state
|
# where we clobbered the current state
|
||||||
|
@ -529,7 +515,7 @@ class EventsStore(SQLBaseStore):
|
||||||
# Add an entry to the ex_outlier_stream table to replicate the
|
# Add an entry to the ex_outlier_stream table to replicate the
|
||||||
# change in outlier status to our workers.
|
# change in outlier status to our workers.
|
||||||
stream_order = event.internal_metadata.stream_ordering
|
stream_order = event.internal_metadata.stream_ordering
|
||||||
state_group_id = context.state_group or context.new_state_group_id
|
state_group_id = context.state_group
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="ex_outlier_stream",
|
table="ex_outlier_stream",
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||||
from synapse.push.baserules import list_with_base_rules
|
from synapse.push.baserules import list_with_base_rules
|
||||||
from synapse.api.constants import EventTypes, Membership
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -124,7 +123,8 @@ class PushRuleStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
|
def bulk_get_push_rules_for_room(self, event, context):
|
||||||
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
@ -132,11 +132,13 @@ class PushRuleStore(SQLBaseStore):
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
|
return self._bulk_get_push_rules_for_room(
|
||||||
|
event.room_id, state_group, context.current_state_ids, event=event
|
||||||
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
|
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
|
||||||
cache_context):
|
cache_context, event=None):
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
# on it. However, its important that its never None, since two current_state's
|
# on it. However, its important that its never None, since two current_state's
|
||||||
# with a state_group of None are likely to be different.
|
# with a state_group of None are likely to be different.
|
||||||
|
@ -147,12 +149,15 @@ class PushRuleStore(SQLBaseStore):
|
||||||
# their unread countss are correct in the event stream, but to avoid
|
# their unread countss are correct in the event stream, but to avoid
|
||||||
# generating them for bot / AS users etc, we only do so for people who've
|
# generating them for bot / AS users etc, we only do so for people who've
|
||||||
# sent a read receipt into the room.
|
# sent a read receipt into the room.
|
||||||
local_users_in_room = set(
|
|
||||||
e.state_key for e in current_state.values()
|
users_in_room = yield self._get_joined_users_from_context(
|
||||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
room_id, state_group, current_state_ids,
|
||||||
and self.hs.is_mine_id(e.state_key)
|
on_invalidate=cache_context.invalidate,
|
||||||
|
event=event,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u))
|
||||||
|
|
||||||
# users in the room who have pushers need to get push rules run because
|
# users in the room who have pushers need to get push rules run because
|
||||||
# that's how their pushers work
|
# that's how their pushers work
|
||||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||||
|
|
|
@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue([ev for res in results.values() for ev in res])
|
defer.returnValue([ev for res in results.values() for ev in res])
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
|
@cachedInlineCallbacks(num_args=3, tree=True)
|
||||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||||
"""Get receipts for a single room for sending to clients.
|
"""Get receipts for a single room for sending to clients.
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from collections import namedtuple
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
|
||||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._membership_stream_cache.entity_has_changed,
|
self._membership_stream_cache.entity_has_changed,
|
||||||
|
@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=5000)
|
|
||||||
def get_joined_hosts_for_room(self, room_id):
|
|
||||||
user_ids = yield self.get_users_in_room(room_id)
|
|
||||||
defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
|
|
||||||
|
|
||||||
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
|
||||||
where_clause = "c.room_id = ?"
|
where_clause = "c.room_id = ?"
|
||||||
where_values = [room_id]
|
where_values = [room_id]
|
||||||
|
@ -325,7 +319,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3)
|
@cachedInlineCallbacks(num_args=3)
|
||||||
def was_forgotten_at(self, user_id, room_id, event_id):
|
def was_forgotten_at(self, user_id, room_id, event_id):
|
||||||
"""Returns whether user_id has elected to discard history for room_id at event_id.
|
"""Returns whether user_id has elected to discard history for room_id at
|
||||||
|
event_id.
|
||||||
|
|
||||||
event_id must be a membership event."""
|
event_id must be a membership event."""
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -358,3 +353,98 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
},
|
},
|
||||||
desc="who_forgot"
|
desc="who_forgot"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_joined_users_from_context(self, event, context):
|
||||||
|
state_group = context.state_group
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._get_joined_users_from_context(
|
||||||
|
event.room_id, state_group, context.current_state_ids, event=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_joined_users_from_state(self, room_id, state_group, state_ids):
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._get_joined_users_from_context(
|
||||||
|
room_id, state_group, state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
|
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
|
||||||
|
cache_context, event=None):
|
||||||
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
|
# on it. However, its important that its never None, since two current_state's
|
||||||
|
# with a state_group of None are likely to be different.
|
||||||
|
# See bulk_get_push_rules_for_room for how we work around this.
|
||||||
|
assert state_group is not None
|
||||||
|
|
||||||
|
member_event_ids = [
|
||||||
|
e_id
|
||||||
|
for key, e_id in current_state_ids.iteritems()
|
||||||
|
if key[0] == EventTypes.Member
|
||||||
|
]
|
||||||
|
|
||||||
|
rows = yield self._simple_select_many_batch(
|
||||||
|
table="room_memberships",
|
||||||
|
column="event_id",
|
||||||
|
iterable=member_event_ids,
|
||||||
|
retcols=['user_id'],
|
||||||
|
keyvalues={
|
||||||
|
"membership": Membership.JOIN,
|
||||||
|
},
|
||||||
|
batch_size=1000,
|
||||||
|
desc="_get_joined_users_from_context",
|
||||||
|
)
|
||||||
|
|
||||||
|
users_in_room = set(row["user_id"] for row in rows)
|
||||||
|
if event is not None and event.type == EventTypes.Member:
|
||||||
|
if event.membership == Membership.JOIN:
|
||||||
|
if event.event_id in member_event_ids:
|
||||||
|
users_in_room.add(event.state_key)
|
||||||
|
|
||||||
|
defer.returnValue(users_in_room)
|
||||||
|
|
||||||
|
def is_host_joined(self, room_id, host, state_group, state_ids):
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._is_host_joined(
|
||||||
|
room_id, host, state_group, state_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=3)
|
||||||
|
def _is_host_joined(self, room_id, host, state_group, current_state_ids):
|
||||||
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
|
# on it. However, its important that its never None, since two current_state's
|
||||||
|
# with a state_group of None are likely to be different.
|
||||||
|
# See bulk_get_push_rules_for_room for how we work around this.
|
||||||
|
assert state_group is not None
|
||||||
|
|
||||||
|
for (etype, state_key), event_id in current_state_ids.items():
|
||||||
|
if etype == EventTypes.Member:
|
||||||
|
try:
|
||||||
|
if get_domain_from_id(state_key) != host:
|
||||||
|
continue
|
||||||
|
except:
|
||||||
|
logger.warn("state_key not user_id: %s", state_key)
|
||||||
|
continue
|
||||||
|
|
||||||
|
event = yield self.get_event(event_id, allow_none=True)
|
||||||
|
if event and event.content["membership"] == Membership.JOIN:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
24
synapse/storage/schema/delta/34/device_inbox.sql
Normal file
24
synapse/storage/schema/delta/34/device_inbox.sql
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE device_inbox (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
|
message_json TEXT NOT NULL -- {"type":, "sender":, "content",}
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id);
|
||||||
|
CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id);
|
32
synapse/storage/schema/delta/34/sent_txn_purge.py
Normal file
32
synapse/storage/schema/delta/34/sent_txn_purge.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def run_create(cur, database_engine, *args, **kwargs):
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
cur.execute("TRUNCATE sent_transactions")
|
||||||
|
else:
|
||||||
|
cur.execute("DELETE FROM sent_transactions")
|
||||||
|
|
||||||
|
cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||||
|
pass
|
|
@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_groups(self, room_id, event_ids):
|
def get_state_groups_ids(self, room_id, event_ids):
|
||||||
""" Get the state groups for the given list of event_ids
|
|
||||||
|
|
||||||
The return value is a dict mapping group names to lists of events.
|
|
||||||
"""
|
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
defer.returnValue({})
|
defer.returnValue({})
|
||||||
|
|
||||||
|
@ -59,36 +55,64 @@ class StateStore(SQLBaseStore):
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = yield self._get_state_for_groups(groups)
|
group_to_state = yield self._get_state_for_groups(groups)
|
||||||
|
|
||||||
|
defer.returnValue(group_to_state)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_groups(self, room_id, event_ids):
|
||||||
|
""" Get the state groups for the given list of event_ids
|
||||||
|
|
||||||
|
The return value is a dict mapping group names to lists of events.
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
|
||||||
|
|
||||||
|
state_event_map = yield self.get_events(
|
||||||
|
[
|
||||||
|
ev_id for group_ids in group_to_ids.values()
|
||||||
|
for ev_id in group_ids.values()
|
||||||
|
],
|
||||||
|
get_prev_content=False
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
group: state_map.values()
|
group: [
|
||||||
for group, state_map in group_to_state.items()
|
state_event_map[v] for v in event_id_map.values() if v in state_event_map
|
||||||
|
]
|
||||||
|
for group, event_id_map in group_to_ids.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||||
|
txn.execute(
|
||||||
|
"SELECT count(*) FROM state_groups WHERE id = ?",
|
||||||
|
(state_group,)
|
||||||
|
)
|
||||||
|
row = txn.fetchone()
|
||||||
|
return row and row[0]
|
||||||
|
|
||||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||||
state_groups = {}
|
state_groups = {}
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
if event.internal_metadata.is_outlier():
|
if event.internal_metadata.is_outlier():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.current_state is None:
|
if context.current_state_ids is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.state_group is not None:
|
|
||||||
state_groups[event.event_id] = context.state_group
|
state_groups[event.event_id] = context.state_group
|
||||||
|
|
||||||
|
if self._have_persisted_state_group_txn(txn, context.state_group):
|
||||||
|
logger.info("Already persisted state_group: %r", context.state_group)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
state_events = dict(context.current_state)
|
state_event_ids = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
|
||||||
state_events[(event.type, event.state_key)] = event
|
|
||||||
|
|
||||||
state_group = context.new_state_group_id
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="state_groups",
|
table="state_groups",
|
||||||
values={
|
values={
|
||||||
"id": state_group,
|
"id": context.state_group,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"event_id": event.event_id,
|
"event_id": event.event_id,
|
||||||
},
|
},
|
||||||
|
@ -99,16 +123,15 @@ class StateStore(SQLBaseStore):
|
||||||
table="state_groups_state",
|
table="state_groups_state",
|
||||||
values=[
|
values=[
|
||||||
{
|
{
|
||||||
"state_group": state_group,
|
"state_group": context.state_group,
|
||||||
"room_id": state.room_id,
|
"room_id": event.room_id,
|
||||||
"type": state.type,
|
"type": key[0],
|
||||||
"state_key": state.state_key,
|
"state_key": key[1],
|
||||||
"event_id": state.event_id,
|
"event_id": state_id,
|
||||||
}
|
}
|
||||||
for state in state_events.values()
|
for key, state_id in state_event_ids.items()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
state_groups[event.event_id] = state_group
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -248,6 +271,31 @@ class StateStore(SQLBaseStore):
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||||
|
|
||||||
|
state_event_map = yield self.get_events(
|
||||||
|
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||||
|
get_prev_content=False
|
||||||
|
)
|
||||||
|
|
||||||
|
event_to_state = {
|
||||||
|
event_id: {
|
||||||
|
k: state_event_map[v]
|
||||||
|
for k, v in group_to_state[group].items()
|
||||||
|
if v in state_event_map
|
||||||
|
}
|
||||||
|
for event_id, group in event_to_groups.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_ids_for_events(self, event_ids, types):
|
||||||
|
event_to_groups = yield self._get_state_group_for_events(
|
||||||
|
event_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
groups = set(event_to_groups.values())
|
||||||
|
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||||
|
|
||||||
event_to_state = {
|
event_to_state = {
|
||||||
event_id: group_to_state[group]
|
event_id: group_to_state[group]
|
||||||
for event_id, group in event_to_groups.items()
|
for event_id, group in event_to_groups.items()
|
||||||
|
@ -272,6 +320,23 @@ class StateStore(SQLBaseStore):
|
||||||
state_map = yield self.get_state_for_events([event_id], types)
|
state_map = yield self.get_state_for_events([event_id], types)
|
||||||
defer.returnValue(state_map[event_id])
|
defer.returnValue(state_map[event_id])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_ids_for_event(self, event_id, types=None):
|
||||||
|
"""
|
||||||
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id(str): event whose state should be returned
|
||||||
|
types(list[(str, str)]|None): List of (type, state_key) tuples
|
||||||
|
which are used to filter the state fetched. May be None, which
|
||||||
|
matches any key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A deferred dict from (type, state_key) -> state_event
|
||||||
|
"""
|
||||||
|
state_map = yield self.get_state_ids_for_events([event_id], types)
|
||||||
|
defer.returnValue(state_map[event_id])
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=10000)
|
@cached(num_args=2, max_entries=10000)
|
||||||
def _get_state_group_for_event(self, room_id, event_id):
|
def _get_state_group_for_event(self, room_id, event_id):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
|
@ -428,20 +493,13 @@ class StateStore(SQLBaseStore):
|
||||||
full=(types is None),
|
full=(types is None),
|
||||||
)
|
)
|
||||||
|
|
||||||
state_events = yield self._get_events(
|
|
||||||
[ev_id for sd in results.values() for ev_id in sd.values()],
|
|
||||||
get_prev_content=False
|
|
||||||
)
|
|
||||||
|
|
||||||
state_events = {e.event_id: e for e in state_events}
|
|
||||||
|
|
||||||
# Remove all the entries with None values. The None values were just
|
# Remove all the entries with None values. The None values were just
|
||||||
# used for bookkeeping in the cache.
|
# used for bookkeeping in the cache.
|
||||||
for group, state_dict in results.items():
|
for group, state_dict in results.items():
|
||||||
results[group] = {
|
results[group] = {
|
||||||
key: state_events[event_id]
|
key: event_id
|
||||||
for key, event_id in state_dict.items()
|
for key, event_id in state_dict.items()
|
||||||
if event_id and event_id in state_events
|
if event_id
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
@ -473,5 +531,5 @@ class StateStore(SQLBaseStore):
|
||||||
"get_all_new_state_groups", get_all_new_state_groups_txn
|
"get_all_new_state_groups", get_all_new_state_groups_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_state_stream_token(self):
|
def get_next_state_group(self):
|
||||||
return self._state_groups_id_gen.get_current_token()
|
return self._state_groups_id_gen.get_next()
|
||||||
|
|
|
@ -245,7 +245,7 @@ class TransactionStore(SQLBaseStore):
|
||||||
|
|
||||||
return self.cursor_to_dict(txn)
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
@cached()
|
@cached(max_entries=10000)
|
||||||
def get_destination_retry_timings(self, destination):
|
def get_destination_retry_timings(self, destination):
|
||||||
"""Gets the current retry timings (if any) for a given destination.
|
"""Gets the current retry timings (if any) for a given destination.
|
||||||
|
|
||||||
|
@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
|
||||||
def _cleanup_transactions(self):
|
def _cleanup_transactions(self):
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
||||||
|
six_hours_ago = now - 6 * 60 * 60 * 1000
|
||||||
|
|
||||||
def _cleanup_transactions_txn(txn):
|
def _cleanup_transactions_txn(txn):
|
||||||
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
||||||
|
txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
|
||||||
|
|
||||||
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|
||||||
|
|
|
@ -43,6 +43,7 @@ class EventSources(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_token(self, direction='f'):
|
def get_current_token(self, direction='f'):
|
||||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||||
|
to_device_key = self.store.get_to_device_stream_token()
|
||||||
|
|
||||||
token = StreamToken(
|
token = StreamToken(
|
||||||
room_key=(
|
room_key=(
|
||||||
|
@ -61,5 +62,6 @@ class EventSources(object):
|
||||||
yield self.sources["account_data"].get_current_key()
|
yield self.sources["account_data"].get_current_key()
|
||||||
),
|
),
|
||||||
push_rules_key=push_rules_key,
|
push_rules_key=push_rules_key,
|
||||||
|
to_device_key=to_device_key,
|
||||||
)
|
)
|
||||||
defer.returnValue(token)
|
defer.returnValue(token)
|
||||||
|
|
|
@ -154,6 +154,7 @@ class StreamToken(
|
||||||
"receipt_key",
|
"receipt_key",
|
||||||
"account_data_key",
|
"account_data_key",
|
||||||
"push_rules_key",
|
"push_rules_key",
|
||||||
|
"to_device_key",
|
||||||
))
|
))
|
||||||
):
|
):
|
||||||
_SEPARATOR = "_"
|
_SEPARATOR = "_"
|
||||||
|
@ -190,6 +191,7 @@ class StreamToken(
|
||||||
or (int(other.receipt_key) < int(self.receipt_key))
|
or (int(other.receipt_key) < int(self.receipt_key))
|
||||||
or (int(other.account_data_key) < int(self.account_data_key))
|
or (int(other.account_data_key) < int(self.account_data_key))
|
||||||
or (int(other.push_rules_key) < int(self.push_rules_key))
|
or (int(other.push_rules_key) < int(self.push_rules_key))
|
||||||
|
or (int(other.to_device_key) < int(self.to_device_key))
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy_and_advance(self, key, new_value):
|
def copy_and_advance(self, key, new_value):
|
||||||
|
@ -269,10 +271,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
||||||
return "t%d-%d" % (self.topological, self.stream)
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
else:
|
else:
|
||||||
return "s%d" % (self.stream,)
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
# Some arbitrary constants used for internal API enumerations. Don't rely on
|
|
||||||
# exact values; always pass or compare symbolically
|
|
||||||
class ThirdPartyEntityKind(object):
|
|
||||||
USER = 'user'
|
|
||||||
LOCATION = 'location'
|
|
||||||
|
|
|
@ -180,6 +180,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
|
||||||
|
user_ids = set(u[0] for u in user_tuples)
|
||||||
|
event_id_to_state = {}
|
||||||
|
for event_id, context in event_id_to_context.items():
|
||||||
|
state = yield store.get_events([
|
||||||
|
e_id
|
||||||
|
for key, e_id in context.current_state_ids.iteritems()
|
||||||
|
if key == (EventTypes.RoomHistoryVisibility, "")
|
||||||
|
or (key[0] == EventTypes.Member and key[1] in user_ids)
|
||||||
|
])
|
||||||
|
event_id_to_state[event_id] = state
|
||||||
|
|
||||||
|
res = yield filter_events_for_clients(
|
||||||
|
store, user_tuples, events, event_id_to_state
|
||||||
|
)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def filter_events_for_client(store, user_id, events, is_peeking=False):
|
def filter_events_for_client(store, user_id, events, is_peeking=False):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
||||||
),
|
),
|
||||||
], any_order=True)
|
], any_order=True)
|
||||||
|
|
||||||
|
def test_online_to_online_last_active_noop(self):
|
||||||
|
wheel_timer = Mock()
|
||||||
|
user_id = "@foo:bar"
|
||||||
|
now = 5000000
|
||||||
|
|
||||||
|
prev_state = UserPresenceState.default(user_id)
|
||||||
|
prev_state = prev_state.copy_and_replace(
|
||||||
|
state=PresenceState.ONLINE,
|
||||||
|
last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
|
||||||
|
currently_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_state = prev_state.copy_and_replace(
|
||||||
|
state=PresenceState.ONLINE,
|
||||||
|
last_active_ts=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
|
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(persist_and_notify)
|
||||||
|
self.assertTrue(federation_ping)
|
||||||
|
self.assertTrue(state.currently_active)
|
||||||
|
self.assertEquals(new_state.state, state.state)
|
||||||
|
self.assertEquals(new_state.status_msg, state.status_msg)
|
||||||
|
self.assertEquals(state.last_federation_update_ts, now)
|
||||||
|
|
||||||
|
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||||
|
wheel_timer.insert.assert_has_calls([
|
||||||
|
call(
|
||||||
|
now=now,
|
||||||
|
obj=user_id,
|
||||||
|
then=new_state.last_active_ts + IDLE_TIMER
|
||||||
|
),
|
||||||
|
call(
|
||||||
|
now=now,
|
||||||
|
obj=user_id,
|
||||||
|
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
||||||
|
),
|
||||||
|
call(
|
||||||
|
now=now,
|
||||||
|
obj=user_id,
|
||||||
|
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
||||||
|
),
|
||||||
|
], any_order=True)
|
||||||
|
|
||||||
def test_online_to_online_last_active(self):
|
def test_online_to_online_last_active(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
|
|
|
@ -62,6 +62,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
self.on_new_event = mock_notifier.on_new_event
|
self.on_new_event = mock_notifier.on_new_event
|
||||||
|
|
||||||
self.auth = Mock(spec=[])
|
self.auth = Mock(spec=[])
|
||||||
|
self.state_handler = Mock()
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
"test",
|
"test",
|
||||||
|
@ -75,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
"set_received_txn_response",
|
"set_received_txn_response",
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
]),
|
]),
|
||||||
|
state_handler=self.state_handler,
|
||||||
handlers=None,
|
handlers=None,
|
||||||
notifier=mock_notifier,
|
notifier=mock_notifier,
|
||||||
resource_for_client=Mock(),
|
resource_for_client=Mock(),
|
||||||
|
@ -113,6 +115,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
return set(member.domain for member in self.room_members)
|
return set(member.domain for member in self.room_members)
|
||||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||||
|
|
||||||
|
def get_current_user_in_room(room_id):
|
||||||
|
return set(str(u) for u in self.room_members)
|
||||||
|
self.state_handler.get_current_user_in_room = get_current_user_in_room
|
||||||
|
|
||||||
self.auth.check_joined_room = check_joined_room
|
self.auth.check_joined_room = check_joined_room
|
||||||
|
|
||||||
# Some local users to test with
|
# Some local users to test with
|
||||||
|
|
|
@ -305,7 +305,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
self.event_id += 1
|
self.event_id += 1
|
||||||
|
|
||||||
context = EventContext(current_state=state)
|
if state is not None:
|
||||||
|
state_ids = {
|
||||||
|
key: e.event_id for key, e in state.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
state_ids = None
|
||||||
|
|
||||||
|
context = EventContext()
|
||||||
|
context.current_state_ids = state_ids
|
||||||
|
context.prev_state_ids = state_ids
|
||||||
context.push_actions = push_actions
|
context.push_actions = push_actions
|
||||||
|
|
||||||
ordering = None
|
ordering = None
|
||||||
|
|
|
@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
self.assertEquals(body, {})
|
self.assertEquals(body, {})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_events_and_state(self):
|
def test_events(self):
|
||||||
get = self.get(events="-1", state="-1", timeout="0")
|
get = self.get(events="-1", timeout="0")
|
||||||
yield self.hs.get_handlers().room_creation_handler.create_room(
|
yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||||
synapse.types.create_requester(self.user), {}
|
synapse.types.create_requester(self.user), {}
|
||||||
)
|
)
|
||||||
|
@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
|
||||||
self.assertEquals(body["events"]["field_names"], [
|
self.assertEquals(body["events"]["field_names"], [
|
||||||
"position", "internal", "json", "state_group"
|
"position", "internal", "json", "state_group"
|
||||||
])
|
])
|
||||||
self.assertEquals(body["state_groups"]["field_names"], [
|
|
||||||
"position", "room_id", "event_id"
|
|
||||||
])
|
|
||||||
self.assertEquals(body["state_group_state"]["field_names"], [
|
|
||||||
"position", "type", "state_key", "event_id"
|
|
||||||
])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_presence(self):
|
def test_presence(self):
|
||||||
|
|
|
@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_topo_token_is_accepted(self):
|
def test_topo_token_is_accepted(self):
|
||||||
token = "t1-0_0_0_0_0_0"
|
token = "t1-0_0_0_0_0_0_0"
|
||||||
(code, response) = yield self.mock_resource.trigger_get(
|
(code, response) = yield self.mock_resource.trigger_get(
|
||||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||||
(self.room_id, token))
|
(self.room_id, token))
|
||||||
|
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
||||||
token = "s0_0_0_0_0_0"
|
token = "s0_0_0_0_0_0_0"
|
||||||
(code, response) = yield self.mock_resource.trigger_get(
|
(code, response) = yield self.mock_resource.trigger_get(
|
||||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||||
(self.room_id, token))
|
(self.room_id, token))
|
||||||
|
|
|
@ -78,44 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_room_hosts(self):
|
|
||||||
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
|
|
||||||
|
|
||||||
self.assertEquals(
|
|
||||||
{"test"},
|
|
||||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should still have just one host after second join from it
|
|
||||||
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
|
|
||||||
|
|
||||||
self.assertEquals(
|
|
||||||
{"test"},
|
|
||||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should now have two hosts after join from other host
|
|
||||||
yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
|
|
||||||
|
|
||||||
self.assertEquals(
|
|
||||||
{"test", "elsewhere"},
|
|
||||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should still have both hosts
|
|
||||||
yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
|
|
||||||
|
|
||||||
self.assertEquals(
|
|
||||||
{"test", "elsewhere"},
|
|
||||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should have only one host after other leaves
|
|
||||||
yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
|
|
||||||
|
|
||||||
self.assertEquals(
|
|
||||||
{"test"},
|
|
||||||
(yield self.store.get_joined_hosts_for_room(self.room.to_string()))
|
|
||||||
)
|
|
||||||
|
|
|
@ -67,9 +67,11 @@ class StateGroupStore(object):
|
||||||
self._event_to_state_group = {}
|
self._event_to_state_group = {}
|
||||||
self._group_to_state = {}
|
self._group_to_state = {}
|
||||||
|
|
||||||
|
self._event_id_to_event = {}
|
||||||
|
|
||||||
self._next_group = 1
|
self._next_group = 1
|
||||||
|
|
||||||
def get_state_groups(self, room_id, event_ids):
|
def get_state_groups_ids(self, room_id, event_ids):
|
||||||
groups = {}
|
groups = {}
|
||||||
for event_id in event_ids:
|
for event_id in event_ids:
|
||||||
group = self._event_to_state_group.get(event_id)
|
group = self._event_to_state_group.get(event_id)
|
||||||
|
@ -79,22 +81,23 @@ class StateGroupStore(object):
|
||||||
return defer.succeed(groups)
|
return defer.succeed(groups)
|
||||||
|
|
||||||
def store_state_groups(self, event, context):
|
def store_state_groups(self, event, context):
|
||||||
if context.current_state is None:
|
if context.current_state_ids is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
state_events = context.current_state
|
state_events = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
self._group_to_state[context.state_group] = state_events
|
||||||
state_events[(event.type, event.state_key)] = event
|
self._event_to_state_group[event.event_id] = context.state_group
|
||||||
|
|
||||||
state_group = context.state_group
|
def get_events(self, event_ids, **kwargs):
|
||||||
if not state_group:
|
return {
|
||||||
state_group = self._next_group
|
e_id: self._event_id_to_event[e_id] for e_id in event_ids
|
||||||
self._next_group += 1
|
if e_id in self._event_id_to_event
|
||||||
|
}
|
||||||
|
|
||||||
self._group_to_state[state_group] = state_events.values()
|
def register_events(self, events):
|
||||||
|
for e in events:
|
||||||
self._event_to_state_group[event.event_id] = state_group
|
self._event_id_to_event[e.event_id] = e
|
||||||
|
|
||||||
|
|
||||||
class DictObj(dict):
|
class DictObj(dict):
|
||||||
|
@ -136,8 +139,10 @@ class StateTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.store = Mock(
|
self.store = Mock(
|
||||||
spec_set=[
|
spec_set=[
|
||||||
"get_state_groups",
|
"get_state_groups_ids",
|
||||||
"add_event_hashes",
|
"add_event_hashes",
|
||||||
|
"get_events",
|
||||||
|
"get_next_state_group",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
hs = Mock(spec_set=[
|
hs = Mock(spec_set=[
|
||||||
|
@ -148,6 +153,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
hs.get_clock.return_value = MockClock()
|
hs.get_clock.return_value = MockClock()
|
||||||
hs.get_auth.return_value = Auth(hs)
|
hs.get_auth.return_value = Auth(hs)
|
||||||
|
|
||||||
|
self.store.get_next_state_group.side_effect = Mock
|
||||||
|
|
||||||
self.state = StateHandler(hs)
|
self.state = StateHandler(hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
|
@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
store.store_state_groups(event, context)
|
store.store_state_groups(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertEqual(2, len(context_store["D"].current_state))
|
self.assertEqual(2, len(context_store["D"].prev_state_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_branch_basic_conflict(self):
|
def test_branch_basic_conflict(self):
|
||||||
|
@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
store.register_events(graph.walk())
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "C"},
|
{"START", "A", "C"},
|
||||||
{e.event_id for e in context_store["D"].current_state.values()}
|
{e_id for e_id in context_store["D"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
store.register_events(graph.walk())
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "B", "C"},
|
{"START", "A", "B", "C"},
|
||||||
{e.event_id for e in context_store["E"].current_state.values()}
|
{e for e in context_store["E"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase):
|
||||||
graph = Graph(nodes, edges)
|
graph = Graph(nodes, edges)
|
||||||
|
|
||||||
store = StateGroupStore()
|
store = StateGroupStore()
|
||||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
store.register_events(graph.walk())
|
||||||
|
|
||||||
context_store = {}
|
context_store = {}
|
||||||
|
|
||||||
|
@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"A1", "A2", "A3", "A5", "B"},
|
{"A1", "A2", "A3", "A5", "B"},
|
||||||
{e.event_id for e in context_store["D"].current_state.values()}
|
{e for e in context_store["D"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_depths(self, nodes, edges):
|
def _add_depths(self, nodes, edges):
|
||||||
|
@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(old_state), set(context.current_state.values())
|
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_annotate_with_old_state(self):
|
def test_annotate_with_old_state(self):
|
||||||
|
@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase):
|
||||||
event, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(old_state),
|
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
|
||||||
set(context.current_state.values())
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_trivial_annotate_message(self):
|
def test_trivial_annotate_message(self):
|
||||||
event = create_event(type="test_message", name="event")
|
event = create_event(type="test_message", name="event")
|
||||||
|
@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
group_name = "group_name_1"
|
group_name = "group_name_1"
|
||||||
|
|
||||||
self.store.get_state_groups.return_value = {
|
self.store.get_state_groups_ids.return_value = {
|
||||||
group_name: old_state,
|
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||||
}
|
}
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set([e.event_id for e in context.current_state.values()])
|
set(context.current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(group_name, context.state_group)
|
self.assertEqual(group_name, context.state_group)
|
||||||
|
@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
group_name = "group_name_1"
|
group_name = "group_name_1"
|
||||||
|
|
||||||
self.store.get_state_groups.return_value = {
|
self.store.get_state_groups_ids.return_value = {
|
||||||
group_name: old_state,
|
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||||
}
|
}
|
||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
|
||||||
type, state_key = k
|
|
||||||
self.assertEqual(type, v.type)
|
|
||||||
self.assertEqual(state_key, v.state_key)
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set([e.event_id for e in context.current_state.values()])
|
set(context.prev_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_message_conflict(self):
|
def test_resolve_message_conflict(self):
|
||||||
|
@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test4", state_key=""),
|
create_event(type="test4", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store = StateGroupStore()
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_state_conflict(self):
|
def test_resolve_state_conflict(self):
|
||||||
|
@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test4", state_key=""),
|
create_event(type="test4", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store = StateGroupStore()
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_standard_depth_conflict(self):
|
def test_standard_depth_conflict(self):
|
||||||
|
@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test1", state_key="1", depth=2),
|
create_event(type="test1", state_key="1", depth=2),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store = StateGroupStore()
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
self.store.get_events = store.get_events
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
|
self.assertEqual(
|
||||||
|
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
|
||||||
|
)
|
||||||
|
|
||||||
# Reverse the depth to make sure we are actually using the depths
|
# Reverse the depth to make sure we are actually using the depths
|
||||||
# during state resolution.
|
# during state resolution.
|
||||||
|
@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase):
|
||||||
create_event(type="test1", state_key="1", depth=1),
|
create_event(type="test1", state_key="1", depth=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
store.register_events(old_state_1)
|
||||||
|
store.register_events(old_state_2)
|
||||||
|
|
||||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||||
|
|
||||||
self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
|
self.assertEqual(
|
||||||
|
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
|
||||||
|
)
|
||||||
|
|
||||||
def _get_context(self, event, old_state_1, old_state_2):
|
def _get_context(self, event, old_state_1, old_state_2):
|
||||||
group_name_1 = "group_name_1"
|
group_name_1 = "group_name_1"
|
||||||
group_name_2 = "group_name_2"
|
group_name_2 = "group_name_2"
|
||||||
|
|
||||||
self.store.get_state_groups.return_value = {
|
self.store.get_state_groups_ids.return_value = {
|
||||||
group_name_1: old_state_1,
|
group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
|
||||||
group_name_2: old_state_2,
|
group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.state.compute_event_context(event)
|
return self.state.compute_event_context(event)
|
||||||
|
|
Loading…
Reference in a new issue