Merge branch 'develop' into matthew/gin_work_mem

commit a9b712e9dc
Richard van der Hoff, 2018-02-13 12:16:01 +00:00
40 changed files with 1298 additions and 671 deletions

.gitignore vendored
View file

@@ -46,3 +46,5 @@ static/client/register/register_config.js
 env/
 *.config
+.vscode/

View file

@@ -0,0 +1,23 @@
# List all media in a room

This API gets a list of known media in a room.

The API is:
```
GET /_matrix/client/r0/admin/room/<room_id>/media
```
including an `access_token` of a server admin.

It returns a JSON body like the following:
```
{
    "local": [
        "mxc://localhost/xwvutsrqponmlkjihgfedcba",
        "mxc://localhost/abcdefghijklmnopqrstuvwx"
    ],
    "remote": [
        "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
        "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
    ]
}
```
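For illustration, a request against this endpoint might look like the following Python sketch; the homeserver URL, room ID, and token are placeholders rather than values from this commit.

```python
# Minimal sketch of calling the new media-listing endpoint.
import requests
from urllib.parse import quote

room_id = "!roomid:localhost"  # placeholder; must be URL-encoded in the path
resp = requests.get(
    "http://localhost:8008/_matrix/client/r0/admin/room/%s/media" % quote(room_id),
    params={"access_token": "ADMIN_TOKEN"},  # token of a server admin
)
resp.raise_for_status()
media = resp.json()
print(len(media["local"]), "local,", len(media["remote"]), "remote")
```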

View file

@@ -4,8 +4,6 @@ Purge History API
 The purge history API allows server admins to purge historic events from their
 database, reclaiming disk space.

-**NB!** This will not delete local events (locally sent messages content etc) from the
-database, but will remove lots of the metadata about them and does dramatically reduce the on disk space usage
 Depending on the amount of history being purged a call to the API may take
 several minutes or longer. During this period users will not be able to
 paginate further back in the room from the point being purged from.
@@ -15,3 +13,15 @@ The API is simply:
 ``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``

 including an ``access_token`` of a server admin.
+
+By default, events sent by local users are not deleted, as they may represent
+the only copies of this content in existence. (Events sent by remote users are
+deleted, and room state data before the cutoff is always removed.)
+
+To delete local events as well, set ``delete_local_events`` in the body:
+
+.. code:: json
+
+   {
+       "delete_local_events": true
+   }
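A request that exercises the new flag might look like this sketch; the URL and identifiers are placeholders, and the token must belong to a server admin.

```python
# Minimal sketch of calling the purge API with delete_local_events set.
import requests
from urllib.parse import quote

room_id = "!roomid:localhost"
event_id = "$event_at_cutoff:localhost"
resp = requests.post(
    "http://localhost:8008/_matrix/client/r0/admin/purge_history/%s/%s"
    % (quote(room_id), quote(event_id)),
    params={"access_token": "ADMIN_TOKEN"},
    json={"delete_local_events": True},  # opt in to purging locally-sent events
)
resp.raise_for_status()
```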

View file

@@ -25,7 +25,9 @@ class EventContext(object):
         The current state map excluding the current event.
         (type, state_key) -> event_id

-        state_group (int): state group id
+        state_group (int|None): state group id, if the state has been stored
+            as a state group. This is usually only None if e.g. the event is
+            an outlier.

         rejected (bool|str): A rejection reason if the event was rejected, else
             False

View file

@@ -34,6 +34,7 @@ class DirectoryHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()

         self.federation = hs.get_replication_layer()
         self.federation.register_query_handler(
@@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
     def send_room_alias_update_event(self, requester, user_id, room_id):
         aliases = yield self.store.get_aliases_for_room(room_id)

-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Aliases,
@@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
         if not alias_event or alias_event.content.get("alias", "") != alias_str:
             return

-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.CanonicalAlias,

View file

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,6 +76,7 @@ class FederationHandler(BaseHandler):
         self.is_mine_id = hs.is_mine_id
         self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()

         self.replication_layer.set_handler(self)
@@ -808,13 +810,12 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())

         logger.debug("calling resolve_state_groups in _maybe_backfill")
+        resolve = logcontext.preserve_fn(
+            self.state_handler.resolve_state_groups_for_events
+        )
         states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
-                    room_id, [e]
-                )
-                for e in event_ids
-            ], consumeErrors=True,
+            [resolve(room_id, [e]) for e in event_ids],
+            consumeErrors=True,
         ))
         states = dict(zip(event_ids, [s.state for s in states]))
@@ -1008,8 +1009,7 @@ class FederationHandler(BaseHandler):
         })

         try:
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder,
             )
         except AuthError as e:
@@ -1249,8 +1249,7 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })

-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
@@ -1832,8 +1831,8 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         different_auth = event_auth_events - current_state

-        self._update_context_for_auth_events(
-            context, auth_events, event_key,
+        yield self._update_context_for_auth_events(
+            event, context, auth_events, event_key,
         )

         if different_auth and not event.internal_metadata.is_outlier():
@@ -1914,8 +1913,8 @@ class FederationHandler(BaseHandler):
         # 4. Look at rejects and their proofs.
         # TODO.

-        self._update_context_for_auth_events(
-            context, auth_events, event_key,
+        yield self._update_context_for_auth_events(
+            event, context, auth_events, event_key,
         )

         try:
@@ -1924,11 +1923,15 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed auth resolution for %r because %s", event, e)
             raise e

-    def _update_context_for_auth_events(self, context, auth_events,
-                                        event_key):
-        """Update the state_ids in an event context after auth event resolution
+    @defer.inlineCallbacks
+    def _update_context_for_auth_events(self, event, context, auth_events,
+                                        event_key):
+        """Update the state_ids in an event context after auth event resolution,
+        storing the changes as a new state group.

         Args:
+            event (Event): The event we're handling the context for
+
             context (synapse.events.snapshot.EventContext): event context
                 to be updated
@@ -1951,7 +1954,13 @@ class FederationHandler(BaseHandler):
         context.prev_state_ids.update({
             k: a.event_id for k, a in auth_events.iteritems()
         })
-        context.state_group = self.store.get_next_state_group()
+        context.state_group = yield self.store.store_state_group(
+            event.event_id,
+            event.room_id,
+            prev_group=context.prev_group,
+            delta_ids=context.delta_ids,
+            current_state_ids=context.current_state_ids,
+        )

     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
@@ -2121,8 +2130,7 @@ class FederationHandler(BaseHandler):
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
             builder = self.event_builder_factory.new(event_dict)
             EventValidator().validate_new(builder)

-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
@@ -2160,8 +2168,7 @@ class FederationHandler(BaseHandler):
         """
         builder = self.event_builder_factory.new(event_dict)

-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
@@ -2211,8 +2218,9 @@ class FederationHandler(BaseHandler):
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)

-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(builder=builder)
+        event, context = yield self.event_creation_handler.create_new_client_event(
+            builder=builder,
+        )

         defer.returnValue((event, context))

     @defer.inlineCallbacks
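The `_maybe_backfill` change above keeps the same fan-out shape while switching to the renamed resolver: wrap it once, call it per extremity, and wait on the gathered results. A toy Twisted sketch of that shape, with Synapse's logcontext bookkeeping deliberately elided:

```python
# Toy fan-out-and-gather sketch; slow_lookup stands in for
# resolve_state_groups_for_events, and the preserve_fn /
# make_deferred_yieldable wrappers are elided for brevity.
from twisted.internet import defer, reactor, task


@defer.inlineCallbacks
def slow_lookup(room_id, event_id):
    yield task.deferLater(reactor, 0.01, lambda: None)  # simulate async I/O
    defer.returnValue({"room": room_id, "event": event_id})


@defer.inlineCallbacks
def gather_states(room_id, event_ids):
    states = yield defer.gatherResults(
        [slow_lookup(room_id, e) for e in event_ids],
        consumeErrors=True,  # fail the gather with the first error
    )
    defer.returnValue(dict(zip(event_ids, states)))
```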

View file

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd
+# Copyright 2017 - 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -47,23 +47,11 @@ class MessageHandler(BaseHandler):
         self.hs = hs
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
-        self.validator = EventValidator()
-        self.profile_handler = hs.get_profile_handler()
         self.pagination_lock = ReadWriteLock()
-        self.pusher_pool = hs.get_pusherpool()
-
-        # We arbitrarily limit concurrent event creation for a room to 5.
-        # This is to stop us from diverging history *too* much.
-        self.limiter = Limiter(max_count=5)
-
-        self.action_generator = hs.get_action_generator()
-
-        self.spam_checker = hs.get_spam_checker()

     @defer.inlineCallbacks
-    def purge_history(self, room_id, event_id):
+    def purge_history(self, room_id, event_id, delete_local_events=False):
         event = yield self.store.get_event(event_id)

         if event.room_id != room_id:
@@ -72,7 +60,7 @@ class MessageHandler(BaseHandler):
         depth = event.depth

         with (yield self.pagination_lock.write(room_id)):
-            yield self.store.delete_old_state(room_id, depth)
+            yield self.store.purge_history(room_id, depth, delete_local_events)

     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
@@ -182,166 +170,6 @@ class MessageHandler(BaseHandler):
         defer.returnValue(chunk)
-    @defer.inlineCallbacks
-    def create_event(self, requester, event_dict, token_id=None, txn_id=None,
-                     prev_event_ids=None):
-        """
-        Given a dict from a client, create a new event.
-
-        Creates an FrozenEvent object, filling out auth_events, prev_events,
-        etc.
-
-        Adds display names to Join membership events.
-
-        Args:
-            requester
-            event_dict (dict): An entire event
-            token_id (str)
-            txn_id (str)
-            prev_event_ids (list): The prev event ids to use when creating the event
-
-        Returns:
-            Tuple of created event (FrozenEvent), Context
-        """
-        builder = self.event_builder_factory.new(event_dict)
-
-        with (yield self.limiter.queue(builder.room_id)):
-            self.validator.validate_new(builder)
-
-            if builder.type == EventTypes.Member:
-                membership = builder.content.get("membership", None)
-                target = UserID.from_string(builder.state_key)
-
-                if membership in {Membership.JOIN, Membership.INVITE}:
-                    # If event doesn't include a display name, add one.
-                    profile = self.profile_handler
-                    content = builder.content
-
-                    try:
-                        if "displayname" not in content:
-                            content["displayname"] = yield profile.get_displayname(target)
-                        if "avatar_url" not in content:
-                            content["avatar_url"] = yield profile.get_avatar_url(target)
-                    except Exception as e:
-                        logger.info(
-                            "Failed to get profile information for %r: %s",
-                            target, e
-                        )
-
-            if token_id is not None:
-                builder.internal_metadata.token_id = token_id
-
-            if txn_id is not None:
-                builder.internal_metadata.txn_id = txn_id
-
-            event, context = yield self._create_new_client_event(
-                builder=builder,
-                requester=requester,
-                prev_event_ids=prev_event_ids,
-            )
-
-        defer.returnValue((event, context))
-
-    @defer.inlineCallbacks
-    def send_nonmember_event(self, requester, event, context, ratelimit=True):
-        """
-        Persists and notifies local clients and federation of an event.
-
-        Args:
-            event (FrozenEvent) the event to send.
-            context (Context) the context of the event.
-            ratelimit (bool): Whether to rate limit this send.
-            is_guest (bool): Whether the sender is a guest.
-        """
-        if event.type == EventTypes.Member:
-            raise SynapseError(
-                500,
-                "Tried to send member event through non-member codepath"
-            )
-
-        # We check here if we are currently being rate limited, so that we
-        # don't do unnecessary work. We check again just before we actually
-        # send the event.
-        yield self.ratelimit(requester, update=False)
-
-        user = UserID.from_string(event.sender)
-
-        assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
-
-        if event.is_state():
-            prev_state = yield self.deduplicate_state_event(event, context)
-            if prev_state is not None:
-                defer.returnValue(prev_state)
-
-        yield self.handle_new_client_event(
-            requester=requester,
-            event=event,
-            context=context,
-            ratelimit=ratelimit,
-        )
-
-        if event.type == EventTypes.Message:
-            presence = self.hs.get_presence_handler()
-            # We don't want to block sending messages on any presence code. This
-            # matters as sometimes presence code can take a while.
-            preserve_fn(presence.bump_presence_active_time)(user)
-
-    @defer.inlineCallbacks
-    def deduplicate_state_event(self, event, context):
-        """
-        Checks whether event is in the latest resolved state in context.
-
-        If so, returns the version of the event in context.
-        Otherwise, returns None.
-        """
-        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
-        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
-        if not prev_event:
-            return
-
-        if prev_event and event.user_id == prev_event.user_id:
-            prev_content = encode_canonical_json(prev_event.content)
-            next_content = encode_canonical_json(event.content)
-            if prev_content == next_content:
-                defer.returnValue(prev_event)
-        return
-
-    @defer.inlineCallbacks
-    def create_and_send_nonmember_event(
-        self,
-        requester,
-        event_dict,
-        ratelimit=True,
-        txn_id=None
-    ):
-        """
-        Creates an event, then sends it.
-
-        See self.create_event and self.send_nonmember_event.
-        """
-        event, context = yield self.create_event(
-            requester,
-            event_dict,
-            token_id=requester.access_token_id,
-            txn_id=txn_id
-        )
-
-        spam_error = self.spam_checker.check_event_for_spam(event)
-        if spam_error:
-            if not isinstance(spam_error, basestring):
-                spam_error = "Spam is not permitted here"
-            raise SynapseError(
-                403, spam_error, Codes.FORBIDDEN
-            )
-
-        yield self.send_nonmember_event(
-            requester,
-            event,
-            context,
-            ratelimit=ratelimit,
-        )
-        defer.returnValue(event)
     @defer.inlineCallbacks
     def get_room_data(self, user_id=None, room_id=None,
                       event_type=None, state_key="", is_guest=False):
@@ -470,9 +298,192 @@ class MessageHandler(BaseHandler):
             for user_id, profile in users_with_profile.iteritems()
         })
@measure_func("_create_new_client_event")
class EventCreationHandler(object):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
-    @defer.inlineCallbacks
-    def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
+    @defer.inlineCallbacks
+    def create_event(self, requester, event_dict, token_id=None, txn_id=None,
+                     prev_event_ids=None):
+        """
+        Given a dict from a client, create a new event.
+
+        Creates an FrozenEvent object, filling out auth_events, prev_events,
+        etc.
+
+        Adds display names to Join membership events.
+
+        Args:
+            requester
+            event_dict (dict): An entire event
+            token_id (str)
+            txn_id (str)
+            prev_event_ids (list): The prev event ids to use when creating the event
+
+        Returns:
+            Tuple of created event (FrozenEvent), Context
+        """
+        builder = self.event_builder_factory.new(event_dict)
+
+        with (yield self.limiter.queue(builder.room_id)):
+            self.validator.validate_new(builder)
+
+            if builder.type == EventTypes.Member:
+                membership = builder.content.get("membership", None)
+                target = UserID.from_string(builder.state_key)
+
+                if membership in {Membership.JOIN, Membership.INVITE}:
+                    # If event doesn't include a display name, add one.
+                    profile = self.profile_handler
+                    content = builder.content
+
+                    try:
+                        if "displayname" not in content:
+                            content["displayname"] = yield profile.get_displayname(target)
+                        if "avatar_url" not in content:
+                            content["avatar_url"] = yield profile.get_avatar_url(target)
+                    except Exception as e:
+                        logger.info(
+                            "Failed to get profile information for %r: %s",
+                            target, e
+                        )
+
+            if token_id is not None:
+                builder.internal_metadata.token_id = token_id
+
+            if txn_id is not None:
+                builder.internal_metadata.txn_id = txn_id
+
+            event, context = yield self.create_new_client_event(
+                builder=builder,
+                requester=requester,
+                prev_event_ids=prev_event_ids,
+            )
+
+        defer.returnValue((event, context))
+    @defer.inlineCallbacks
+    def send_nonmember_event(self, requester, event, context, ratelimit=True):
+        """
+        Persists and notifies local clients and federation of an event.
+
+        Args:
+            event (FrozenEvent) the event to send.
+            context (Context) the context of the event.
+            ratelimit (bool): Whether to rate limit this send.
+            is_guest (bool): Whether the sender is a guest.
+        """
+        if event.type == EventTypes.Member:
+            raise SynapseError(
+                500,
+                "Tried to send member event through non-member codepath"
+            )
+
+        user = UserID.from_string(event.sender)
+
+        assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
+
+        if event.is_state():
+            prev_state = yield self.deduplicate_state_event(event, context)
+            if prev_state is not None:
+                defer.returnValue(prev_state)
+
+        yield self.handle_new_client_event(
+            requester=requester,
+            event=event,
+            context=context,
+            ratelimit=ratelimit,
+        )
+
+        if event.type == EventTypes.Message:
+            presence = self.hs.get_presence_handler()
+            # We don't want to block sending messages on any presence code. This
+            # matters as sometimes presence code can take a while.
+            preserve_fn(presence.bump_presence_active_time)(user)
+
+    @defer.inlineCallbacks
+    def deduplicate_state_event(self, event, context):
+        """
+        Checks whether event is in the latest resolved state in context.
+
+        If so, returns the version of the event in context.
+        Otherwise, returns None.
+        """
+        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
+        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+        if not prev_event:
+            return
+
+        if prev_event and event.user_id == prev_event.user_id:
+            prev_content = encode_canonical_json(prev_event.content)
+            next_content = encode_canonical_json(event.content)
+            if prev_content == next_content:
+                defer.returnValue(prev_event)
+        return
+
+    @defer.inlineCallbacks
+    def create_and_send_nonmember_event(
+        self,
+        requester,
+        event_dict,
+        ratelimit=True,
+        txn_id=None
+    ):
+        """
+        Creates an event, then sends it.
+
+        See self.create_event and self.send_nonmember_event.
+        """
+        event, context = yield self.create_event(
+            requester,
+            event_dict,
+            token_id=requester.access_token_id,
+            txn_id=txn_id
+        )
+
+        spam_error = self.spam_checker.check_event_for_spam(event)
+        if spam_error:
+            if not isinstance(spam_error, basestring):
+                spam_error = "Spam is not permitted here"
+            raise SynapseError(
+                403, spam_error, Codes.FORBIDDEN
+            )
+
+        yield self.send_nonmember_event(
+            requester,
+            event,
+            context,
+            ratelimit=ratelimit,
+        )
+        defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
if prev_event_ids: if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids) prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@ -509,9 +520,7 @@ class MessageHandler(BaseHandler):
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
state_handler = self.state_handler context = yield self.state.compute_event_context(builder)
context = yield state_handler.compute_event_context(builder)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
@ -551,7 +560,7 @@ class MessageHandler(BaseHandler):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit: if ratelimit:
yield self.ratelimit(requester) yield self.base_handler.ratelimit(requester)
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(event, context)
@ -567,7 +576,7 @@ class MessageHandler(BaseHandler):
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise
yield self.maybe_kick_guest_users(event, context) yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)
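The `deduplicate_state_event` check that moved into the new handler relies on canonical JSON: two state events are considered duplicates when their canonical encodings are byte-identical and the sender is unchanged. A small illustration using the `canonicaljson` library that Synapse already depends on:

```python
# Key order does not affect the canonical encoding, so semantically equal
# content dicts compare equal after encoding.
from canonicaljson import encode_canonical_json

prev_content = {"name": "My Room", "topic": "state dedup"}
new_content = {"topic": "state dedup", "name": "My Room"}

assert encode_canonical_json(prev_content) == encode_canonical_json(new_content)
```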

View file

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
         super(RoomCreationHandler, self).__init__(hs)

         self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()

     @defer.inlineCallbacks
     def create_room(self, requester, config, ratelimit=True):
@@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):

         creation_content = config.get("creation_content", {})

-        msg_handler = self.hs.get_handlers().message_handler
         room_member_handler = self.hs.get_handlers().room_member_handler

         yield self._send_events_for_new_room(
             requester,
             room_id,
-            msg_handler,
             room_member_handler,
             preset_config=preset_config,
             invite_list=invite_list,
@@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
         if "name" in config:
             name = config["name"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Name,
@@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
         if "topic" in config:
             topic = config["topic"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Topic,
@@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
         self,
         creator,  # A Requester object.
         room_id,
-        msg_handler,
         room_member_handler,
         preset_config,
         invite_list,
@@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
         @defer.inlineCallbacks
         def send(etype, content, **kwargs):
             event = create(etype, content, **kwargs)
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 creator,
                 event,
                 ratelimit=False

View file

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +47,7 @@ class RoomMemberHandler(BaseHandler):
         super(RoomMemberHandler, self).__init__(hs)

         self.profile_handler = hs.get_profile_handler()
+        self.event_creation_hander = hs.get_event_creation_handler()

         self.member_linearizer = Linearizer(name="member")
@@ -66,13 +68,12 @@ class RoomMemberHandler(BaseHandler):
     ):
         if content is None:
             content = {}
-        msg_handler = self.hs.get_handlers().message_handler

         content["membership"] = membership
         if requester.is_guest:
             content["kind"] = "guest"

-        event, context = yield msg_handler.create_event(
+        event, context = yield self.event_creation_hander.create_event(
             requester,
             {
                 "type": EventTypes.Member,
@@ -90,12 +91,14 @@ class RoomMemberHandler(BaseHandler):
         )

         # Check if this event matches the previous membership event for the user.
-        duplicate = yield msg_handler.deduplicate_state_event(event, context)
+        duplicate = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if duplicate is not None:
             # Discard the new event since this membership change is a no-op.
             defer.returnValue(duplicate)

-        yield msg_handler.handle_new_client_event(
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -394,8 +397,9 @@ class RoomMemberHandler(BaseHandler):
         else:
             requester = synapse.types.create_requester(target_user)

-        message_handler = self.hs.get_handlers().message_handler
-        prev_event = yield message_handler.deduplicate_state_event(event, context)
+        prev_event = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if prev_event is not None:
             return
@@ -412,7 +416,7 @@ class RoomMemberHandler(BaseHandler):
         if is_blocked:
             raise SynapseError(403, "This room has been blocked on this server")

-        yield message_handler.handle_new_client_event(
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -644,8 +648,7 @@ class RoomMemberHandler(BaseHandler):
             )
         )

-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.ThirdPartyInvite,

View file

@@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
     return default

-def parse_json_value_from_request(request):
+def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.

     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into None

     Returns:
         The JSON value.
@@ -165,6 +167,9 @@ def parse_json_value_from_request(request):
     except Exception:
         raise SynapseError(400, "Error reading JSON content.")

+    if not content_bytes and allow_empty_body:
+        return None
+
     try:
         content = simplejson.loads(content_bytes)
     except Exception as e:
@@ -174,17 +179,24 @@ def parse_json_value_from_request(request):
     return content

-def parse_json_object_from_request(request):
+def parse_json_object_from_request(request, allow_empty_body=False):
     """Parse a JSON object from the body of a twisted HTTP request.

     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into an empty dict.

     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
             if it wasn't a JSON object.
     """
-    content = parse_json_value_from_request(request)
+    content = parse_json_value_from_request(
+        request, allow_empty_body=allow_empty_body,
+    )
+
+    if allow_empty_body and content is None:
+        return {}

     if type(content) != dict:
         message = "Content must be a JSON object."
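The combined behaviour of the two helpers is easy to pin down outside of twisted; the sketch below stubs the request out with a plain bytes body (the helper names here are illustrative, not Synapse APIs):

```python
# Standalone sketch of the allow_empty_body semantics added above.
import simplejson


def parse_json_value(content_bytes, allow_empty_body=False):
    if not content_bytes and allow_empty_body:
        return None
    return simplejson.loads(content_bytes)


def parse_json_object(content_bytes, allow_empty_body=False):
    content = parse_json_value(content_bytes, allow_empty_body=allow_empty_body)
    if allow_empty_body and content is None:
        return {}  # an absent body becomes an empty object
    if not isinstance(content, dict):
        raise ValueError("Content must be a JSON object.")
    return content


assert parse_json_object(b"", allow_empty_body=True) == {}
assert parse_json_object(b'{"delete_local_events": true}') == {
    "delete_local_events": True,
}
```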

View file

@@ -193,7 +193,9 @@ class DistributionMetric(object):

 class CacheMetric(object):
-    __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
+    __slots__ = (
+        "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
+    )

     def __init__(self, name, size_callback, cache_name):
         self.name = name
@@ -201,6 +203,7 @@ class CacheMetric(object):
         self.hits = 0
         self.misses = 0
+        self.evicted_size = 0
         self.size_callback = size_callback
@@ -210,6 +213,9 @@ class CacheMetric(object):
     def inc_misses(self):
         self.misses += 1

+    def inc_evictions(self, size=1):
+        self.evicted_size += size
+
     def render(self):
         size = self.size_callback()
         hits = self.hits
@@ -219,6 +225,9 @@ class CacheMetric(object):
             """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
             """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
             """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+            """%s:evicted_size{name="%s"} %d""" % (
+                self.name, self.cache_name, self.evicted_size
+            ),
         ]
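A self-contained sketch of how the new counter surfaces in the rendered metrics; the class below mirrors only the fields used by `evicted_size` and is not the real `CacheMetric`:

```python
class CacheMetricSketch(object):
    def __init__(self, name, size_callback, cache_name):
        self.name = name
        self.cache_name = cache_name
        self.evicted_size = 0
        self.size_callback = size_callback

    def inc_evictions(self, size=1):
        # evictions are tracked by total evicted size, not call count
        self.evicted_size += size

    def render(self):
        return [
            '%s:evicted_size{name="%s"} %d'
            % (self.name, self.cache_name, self.evicted_size),
        ]


metric = CacheMetricSketch("cache", lambda: 0, "state_cache")
metric.inc_evictions(size=3)
print(metric.render())  # ['cache:evicted_size{name="state_cache"} 3']
```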

View file

@@ -19,7 +19,7 @@ from synapse.storage import DataStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.event_push_actions import EventPushActionsStore
 from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.state import StateGroupReadStore
+from synapse.storage.state import StateGroupWorkerStore
 from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from ._base import BaseSlavedStore
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 # the method descriptor on the DataStore and chuck them into our class.

-class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
+class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):

     def __init__(self, db_conn, hs):
         super(SlavedEventStore, self).__init__(db_conn, hs)

View file

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -128,7 +129,16 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
         if not is_admin:
             raise AuthError(403, "You are not a server admin")

-        yield self.handlers.message_handler.purge_history(room_id, event_id)
+        body = parse_json_object_from_request(request, allow_empty_body=True)
+
+        delete_local_events = bool(
+            body.get("delete_local_events", False)
+        )
+
+        yield self.handlers.message_handler.purge_history(
+            room_id, event_id,
+            delete_local_events=delete_local_events,
+        )

         defer.returnValue((200, {}))
@@ -171,6 +181,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
         self.store = hs.get_datastore()
         self.handlers = hs.get_handlers()
         self.state = hs.get_state_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()

     @defer.inlineCallbacks
     def on_POST(self, request, room_id):
@@ -203,8 +214,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
         )
         new_room_id = info["room_id"]

-        msg_handler = self.handlers.message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             room_creator_requester,
             {
                 "type": "m.room.message",
@@ -289,6 +299,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
         defer.returnValue((200, {"num_quarantined": num_quarantined}))


+class ListMediaInRoom(ClientV1RestServlet):
+    """Lists all of the media in a given room.
+    """
+    PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+    def __init__(self, hs):
+        super(ListMediaInRoom, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+        defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
 class ResetPasswordRestServlet(ClientV1RestServlet):
     """Post request to allow an administrator reset password for a user.
     This needs user to have administrator access in Synapse.
@@ -487,3 +518,4 @@ def register_servlets(hs, http_server):
     SearchUsersRestServlet(hs).register(http_server)
     ShutdownRoomRestServlet(hs).register(http_server)
     QuarantineMediaInRoom(hs).register(http_server)
+    ListMediaInRoom(hs).register(http_server)

View file

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -82,6 +83,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomStateEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()

     def register(self, http_server):
         # /room/$roomid/state/$eventtype
@@ -162,15 +164,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
                 content=content,
             )
         else:
-            msg_handler = self.handlers.message_handler
-            event, context = yield msg_handler.create_event(
+            event, context = yield self.event_creation_hander.create_event(
                 requester,
                 event_dict,
                 token_id=requester.access_token_id,
                 txn_id=txn_id,
             )

-            yield msg_handler.send_nonmember_event(requester, event, context)
+            yield self.event_creation_hander.send_nonmember_event(
+                requester, event, context,
+            )

         ret = {}
         if event:
@@ -184,6 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomSendEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()

     def register(self, http_server):
         # /rooms/$roomid/send/$event_type[/$txn_id]
@@ -205,8 +209,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
         if 'ts' in request.args and requester.app_service:
             event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)

-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event = yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
             event_dict,
             txn_id=txn_id,
@@ -670,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomRedactEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_handler = hs.get_event_creation_handler()

     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -680,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)

-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event = yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Redaction,

View file

@@ -472,8 +472,10 @@ class MediaRepository(object):

     @defer.inlineCallbacks
     def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
-                                       t_method, t_type):
-        input_path = self.filepaths.local_media_filepath(media_id)
+                                       t_method, t_type, url_cache):
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            None, media_id, url_cache=url_cache,
+        ))

         thumbnailer = Thumbnailer(input_path)
         t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -486,6 +488,7 @@ class MediaRepository(object):
         file_info = FileInfo(
             server_name=None,
             file_id=media_id,
+            url_cache=url_cache,
             thumbnail=True,
             thumbnail_width=t_width,
             thumbnail_height=t_height,
@@ -512,7 +515,9 @@ class MediaRepository(object):
     @defer.inlineCallbacks
     def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                         t_width, t_height, t_method, t_type):
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=False,
+        ))

         thumbnailer = Thumbnailer(input_path)
         t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -570,12 +575,9 @@ class MediaRepository(object):
         if not requirements:
             return

-        if server_name:
-            input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-        elif url_cache:
-            input_path = self.filepaths.url_cache_filepath(media_id)
-        else:
-            input_path = self.filepaths.local_media_filepath(media_id)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=url_cache,
+        ))

         thumbnailer = Thumbnailer(input_path)
         m_width = thumbnailer.width

View file

@@ -18,6 +18,7 @@ from twisted.protocols.basic import FileSender

 from ._base import Responder

+from synapse.util.file_consumer import BackgroundFileConsumer
 from synapse.util.logcontext import make_deferred_yieldable

 import contextlib
@@ -26,6 +27,7 @@ import logging
 import shutil
 import sys

 logger = logging.getLogger(__name__)
@@ -68,6 +70,12 @@ class MediaStorage(object):
             _write_file_synchronously, source, fname,
         ))

+        # Tell the storage providers about the new file. They'll decide
+        # if they should upload it and whether to do so synchronously
+        # or not.
+        for provider in self.storage_providers:
+            yield provider.store_file(path, file_info)
+
         defer.returnValue(fname)

     @contextlib.contextmanager
@@ -151,6 +159,37 @@ class MediaStorage(object):

         defer.returnValue(None)

+    @defer.inlineCallbacks
+    def ensure_media_is_in_local_cache(self, file_info):
+        """Ensures that the given file is in the local cache. Attempts to
+        download it from storage providers if it isn't.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            Deferred[str]: Full path to local file
+        """
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            defer.returnValue(local_path)
+
+        dirname = os.path.dirname(local_path)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        for provider in self.storage_providers:
+            res = yield provider.fetch(path, file_info)
+            if res:
+                with res:
+                    consumer = BackgroundFileConsumer(open(local_path, "wb"))
+                    yield res.write_to_consumer(consumer)
+                    yield consumer.wait()
+                defer.returnValue(local_path)
+
+        raise Exception("file could not be found")
+
     def _file_info_to_path(self, file_info):
         """Converts file_info into a relative path.
@@ -228,9 +267,8 @@ class FileResponder(Responder):
     def __init__(self, open_file):
         self.open_file = open_file

-    @defer.inlineCallbacks
     def write_to_consumer(self, consumer):
-        yield FileSender().beginFileTransfer(self.open_file, consumer)
+        return FileSender().beginFileTransfer(self.open_file, consumer)

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.open_file.close()
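The provider loops above assume a small duck-typed interface: `store_file(path, file_info)` to mirror a newly written file, and `fetch(path, file_info)` returning a Responder (or None). A hypothetical directory-backed provider, purely illustrative and not the actual storage-provider API, reusing the `FileResponder` defined in this module:

```python
# Hypothetical provider sketch; the class name and behaviour are assumptions
# made for illustration, grounded only in how MediaStorage calls providers.
import os
import shutil

from twisted.internet import defer


class DirectoryBackupProvider(object):
    def __init__(self, base_directory, backup_directory):
        self.base_directory = base_directory
        self.backup_directory = backup_directory

    def store_file(self, path, file_info):
        """Copy a newly-stored file into the backup directory."""
        src = os.path.join(self.base_directory, path)
        dest = os.path.join(self.backup_directory, path)
        dirname = os.path.dirname(dest)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        shutil.copyfile(src, dest)
        return defer.succeed(None)

    def fetch(self, path, file_info):
        """Return a FileResponder for the backup copy, or None if absent."""
        backup_path = os.path.join(self.backup_directory, path)
        if os.path.exists(backup_path):
            return defer.succeed(FileResponder(open(backup_path, "rb")))
        return defer.succeed(None)
```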

View file

@@ -12,6 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import ujson as json
+import urlparse

 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
@@ -33,18 +46,6 @@ from synapse.http.server import (
 from synapse.util.async import ObservableDeferred
 from synapse.util.stringutils import is_ascii

-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
-import datetime
-import errno
-import shutil
-import logging

 logger = logging.getLogger(__name__)
@@ -286,17 +287,28 @@ class PreviewUrlResource(Resource):
             url_cache=True,
         )

-        try:
-            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            try:
                 logger.debug("Trying to get url '%s'" % url)
                 length, headers, uri, code = yield self.client.get_file(
                     url, output_stream=f, max_size=self.max_spider_size,
                 )
+            except Exception as e:
                 # FIXME: pass through 404s and other error messages nicely
+                logger.warn("Error downloading %s: %r", url, e)
+                raise SynapseError(
+                    500, "Failed to download content: %s" % (
+                        traceback.format_exception_only(sys.exc_type, e),
+                    ),
+                    Codes.UNKNOWN,
+                )

-            yield finish()
+            yield finish()

+        try:
+            if "Content-Type" in headers:
                 media_type = headers["Content-Type"][0]
+            else:
+                media_type = "application/octet-stream"
             time_now_ms = self.clock.time_msec()

             content_disposition = headers.get("Content-Disposition", None)
@@ -336,10 +348,11 @@ class PreviewUrlResource(Resource):
             )

         except Exception as e:
-            raise SynapseError(
-                500, ("Failed to download content: %s" % e),
-                Codes.UNKNOWN
-            )
+            logger.error("Error handling downloaded %s: %r", url, e)
+            # TODO: we really ought to delete the downloaded file in this
+            # case, since we won't have recorded it in the db, and will
+            # therefore not expire it.
+            raise

         defer.returnValue({
             "media_type": media_type,

View file

@@ -164,7 +164,8 @@ class ThumbnailResource(Resource):
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_local_exact_thumbnail(
-            media_id, desired_width, desired_height, desired_method, desired_type
+            media_id, desired_width, desired_height, desired_method, desired_type,
+            url_cache=media_info["url_cache"],
         )

         if file_path:

View file

@@ -55,6 +55,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.user_directory import UserDirectoryHandler
 from synapse.handlers.groups_local import GroupsLocalHandler
 from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.message import EventCreationHandler
 from synapse.groups.groups_server import GroupsServerHandler
 from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@@ -66,7 +67,7 @@ from synapse.rest.media.v1.media_repository import (
     MediaRepository,
     MediaRepositoryResource,
 )
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import DataStore
 from synapse.streams.events import EventSources
 from synapse.util import Clock
@@ -102,6 +103,7 @@ class HomeServer(object):
         'v1auth',
         'auth',
         'state_handler',
+        'state_resolution_handler',
         'presence_handler',
         'sync_handler',
         'typing_handler',
@@ -117,6 +119,7 @@ class HomeServer(object):
         'application_service_handler',
         'device_message_handler',
         'profile_handler',
+        'event_creation_handler',
         'deactivate_account_handler',
         'set_password_handler',
         'notifier',
@@ -224,6 +227,9 @@ class HomeServer(object):
     def build_state_handler(self):
         return StateHandler(self)

+    def build_state_resolution_handler(self):
+        return StateResolutionHandler(self)
+
     def build_presence_handler(self):
         return PresenceHandler(self)
@@ -272,6 +278,9 @@ class HomeServer(object):
     def build_profile_handler(self):
         return ProfileHandler(self)

+    def build_event_creation_handler(self):
+        return EventCreationHandler(self)
+
     def build_deactivate_account_handler(self):
         return DeactivateAccountHandler(self)
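These `build_` methods lean on the HomeServer convention that every name in the registry above gets a lazily-cached `get_<name>()` accessor. A simplified sketch of that pattern; the real class generates the accessors dynamically rather than writing them out by hand:

```python
# Illustrative reduction of the HomeServer dependency-injection pattern.
class TinyServer(object):
    DEPENDENCIES = ["event_creation_handler"]

    def __init__(self):
        self._cache = {}

    def _get(self, name):
        # build each dependency once, on first use
        if name not in self._cache:
            self._cache[name] = getattr(self, "build_%s" % name)()
        return self._cache[name]

    def get_event_creation_handler(self):
        return self._get("event_creation_handler")

    def build_event_creation_handler(self):
        return object()  # stand-in for EventCreationHandler(self)
```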

View file

@@ -34,6 +34,9 @@ class HomeServer(object):
     def get_state_handler(self) -> synapse.state.StateHandler:
         pass

+    def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+        pass
+
     def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
         pass

View file

@ -58,7 +58,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None): def __init__(self, state, state_group, prev_group=None, delta_ids=None):
# dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state) self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
# otherwise, None otherwise?
self.state_group = state_group self.state_group = state_group
self.prev_group = prev_group self.prev_group = prev_group
@ -81,31 +85,19 @@ class _StateCacheEntry(object):
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """Fetches bits of state from the stores, and does state resolution
where necessary
""" """
def __init__(self, hs): def __init__(self, hs):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self): def start_caching(self):
logger.debug("start_caching") # TODO: remove this shim
self._state_resolution_handler.start_caching()
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="", def get_current_state(self, room_id, event_type=None, state_key="",
@ -127,7 +119,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state") logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state state = ret.state
if event_type: if event_type:
@ -164,7 +156,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state state = ret.state
defer.returnValue(state) defer.returnValue(state)
@ -174,7 +166,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room") logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry) joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users) defer.returnValue(joined_users)
@ -183,7 +175,7 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room") logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry) joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts) defer.returnValue(joined_hosts)
@ -191,8 +183,15 @@ class StateHandler(object):
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event. """Build an EventContext structure for the event.
This works out what the current state should be for the event, and
generates a new state group if necessary.
Args: Args:
event (synapse.events.EventBase): event (synapse.events.EventBase):
old_state (dict|None): The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event over federation for which we don't have
the prev events, e.g. when backfilling.
Returns: Returns:
synapse.events.snapshot.EventContext: synapse.events.snapshot.EventContext:
""" """
@ -216,15 +215,22 @@ class StateHandler(object):
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_ids = {} context.prev_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = self.store.get_next_state_group()
# We don't store state for outliers, so we don't generate a state
# group for it.
context.state_group = None
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
# We already have the state, so we don't need to calculate it.
# Let's just correctly fill out the context and create a
# new state group for it.
context = EventContext() context = EventContext()
context.prev_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = self.store.get_next_state_group()
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
@ -237,11 +243,19 @@ class StateHandler(object):
else: else:
context.current_state_ids = context.prev_state_ids context.current_state_ids = context.prev_state_ids
context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=context.current_state_ids,
)
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups( entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
@ -250,7 +264,8 @@ class StateHandler(object):
context = EventContext() context = EventContext()
context.prev_state_ids = curr_state context.prev_state_ids = curr_state
if event.is_state(): if event.is_state():
context.state_group = self.store.get_next_state_group() # If this is a state event then we need to create a new state
# group for the state after this event.
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.prev_state_ids: if key in context.prev_state_ids:
@ -261,38 +276,57 @@ class StateHandler(object):
context.current_state_ids[key] = event.event_id context.current_state_ids[key] = event.event_id
if entry.state_group: if entry.state_group:
# If the state at the event has a state group assigned then
# we can use that as the prev group
context.prev_group = entry.state_group context.prev_group = entry.state_group
context.delta_ids = { context.delta_ids = {
key: event.event_id key: event.event_id
} }
elif entry.prev_group: elif entry.prev_group:
# If the state at the event only has a prev group, then we can
# use that as a prev group too.
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = dict(entry.delta_ids) context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id context.delta_ids[key] = event.event_id
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
else:
context.current_state_ids = context.prev_state_ids context.current_state_ids = context.prev_state_ids
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids context.delta_ids = entry.delta_ids
if entry.state_group is None:
entry.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=entry.prev_group,
delta_ids=entry.delta_ids,
current_state_ids=context.current_state_ids,
)
entry.state_id = entry.state_group
context.state_group = entry.state_group
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function def resolve_state_groups_for_events(self, room_id, event_ids):
def resolve_state_groups(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Args:
room_id (str):
event_ids (list[str]):
Returns: Returns:
a Deferred tuple of (`state_group`, `state`, `prev_state`). Deferred[_StateCacheEntry]: resolved state
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
@ -303,13 +337,7 @@ class StateHandler(object):
room_id, event_ids room_id, event_ids
) )
logger.debug( if len(state_groups_ids) == 1:
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name) prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@ -321,6 +349,92 @@ class StateHandler(object):
delta_ids=delta_ids, delta_ids=delta_ids,
)) ))
result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, self._state_map_factory,
)
defer.returnValue(result)
def _state_map_factory(self, ev_ids):
return self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
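The `state_map_factory` contract is small: take a collection of event IDs and return a (possibly deferred) map of event ID to event. A minimal sketch of an alternative factory over an in-memory dict (hypothetical, e.g. for tests) could look like:

```
from twisted.internet import defer

def make_state_map_factory(events_by_id):
    # events_by_id: hypothetical dict of event_id -> event object.
    def state_map_factory(ev_ids):
        # Return a deferred map of event_id -> event, like
        # store.get_events does, but without touching the database.
        return defer.succeed({
            ev_id: events_by_id[ev_id]
            for ev_id in ev_ids
            if ev_id in events_by_id
        })
    return state_map_factory
```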
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
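To make the shapes concrete: given two forks of a room that disagree on one piece of state, `resolve_events` picks a single winner per (type, state_key). A hedged illustration with stand-in event objects:

```
from collections import namedtuple

# Stand-in for a state event; purely illustrative.
FakeEvent = namedtuple("FakeEvent", ["event_id", "type", "state_key"])

alice = FakeEvent("$e1", "m.room.member", "@alice:example.com")
topic_a = FakeEvent("$e2", "m.room.topic", "")
topic_b = FakeEvent("$e3", "m.room.topic", "")

# Two state sets that agree on the membership but not the topic:
state_sets = [[alice, topic_a], [alice, topic_b]]

# resolve_events(state_sets, event) would return a map such as
# {("m.room.member", "@alice:example.com"): alice,
#  ("m.room.topic", ""): <topic_a or topic_b, per the resolution rules>}
```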
class StateResolutionHandler(object):
"""Responsible for doing state conflict resolution.
Note that the storage layer depends on this handler, so all functions must
be storage-independent.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self):
logger.debug("start_caching")
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
logger.debug(
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
with (yield self.resolve_linearizer.queue(group_names)): with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
@ -351,15 +465,17 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory( new_state = yield resolve_events_with_factory(
state_groups_ids.values(), state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events( state_map_factory=state_map_factory,
ev_ids, get_prev_content=False, check_redacted=False,
),
) )
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.items()
} }
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None state_group = None
new_state_event_ids = frozenset(new_state.values()) new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items(): for sg, events in state_groups_ids.items():
@ -396,30 +512,6 @@ class StateHandler(object):
defer.returnValue(cache) defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
def _ordered_events(events): def _ordered_events(events):
def key_func(e): def key_func(e):
@ -437,8 +529,8 @@ def resolve_events_with_state_map(state_sets, state_map):
state_sets. state_sets.
Returns Returns
dict[(str, str), synapse.events.FrozenEvent]: dict[(str, str), str]:
a map from (type, state_key) to event. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if len(state_sets) == 1:
return state_sets[0] return state_sets[0]
@ -460,6 +552,21 @@ def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and """Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated which aren't. i.e., which have multiple different event_ids associated
with them in different state sets. with them in different state sets.
Args:
state_sets(list[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
""" """
unconflicted_state = dict(state_sets[0]) unconflicted_state = dict(state_sets[0])
conflicted_state = {} conflicted_state = {}
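A small worked example of the split (values hypothetical):

```
state_sets = [
    {("m.room.member", "@alice:test"): "$a", ("m.room.topic", ""): "$t1"},
    {("m.room.member", "@alice:test"): "$a", ("m.room.topic", ""): "$t2"},
]

# The membership entry is identical everywhere, so it is unconflicted:
#   unconflicted_state == {("m.room.member", "@alice:test"): "$a"}
# The topic differs between the sets, so it is conflicted:
#   conflicted_state == {("m.room.topic", ""): {"$t1", "$t2"}}
```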
@ -500,8 +607,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
a Deferred of dict of event_id to event. a Deferred of dict of event_id to event.
Returns Returns
Deferred[dict[(str, str), synapse.events.FrozenEvent]]: Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event. a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if len(state_sets) == 1:
defer.returnValue(state_sets[0]) defer.returnValue(state_sets[0])

View file

@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
) )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")

View file

@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table): def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
txn.execute("SELECT nextval('state_group_id_seq')")
return txn.fetchone()[0]

View file

@ -16,6 +16,7 @@
from synapse.storage.prepare_database import prepare_database from synapse.storage.prepare_database import prepare_database
import struct import struct
import threading
class Sqlite3Engine(object): class Sqlite3Engine(object):
@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config): def __init__(self, database_module, database_config):
self.module = database_module self.module = database_module
# The current max state_group, or None if we haven't looked
# in the DB yet.
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
def check_database(self, txn): def check_database(self, txn):
pass pass
@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table): def lock_table(self, txn, table):
return return
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
# We do application-level locking here since, if we're using sqlite,
# we are a single-process synapse.
with self._current_state_group_id_lock:
if self._current_state_group_id is None:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
self._current_state_group_id = txn.fetchone()[0]
self._current_state_group_id += 1
return self._current_state_group_id
# Following functions taken from: https://github.com/coleifer/peewee # Following functions taken from: https://github.com/coleifer/peewee

View file

@ -342,8 +342,20 @@ class EventsStore(SQLBaseStore):
# NB: Assumes that we are only persisting events for one room # NB: Assumes that we are only persisting events for one room
# at a time. # at a time.
# map room_id->list[event_ids] giving the new forward
# extremities in each room
new_forward_extremeties = {} new_forward_extremeties = {}
# map room_id->(type,state_key)->event_id tracking the full
# state in each room after adding these events
current_state_for_room = {} current_state_for_room = {}
# map room_id->(to_delete, to_insert) where each entry is
# a map (type,key)->event_id giving the state delta in each
# room
state_delta_for_room = {}
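Illustrative shapes for these three maps, since the comments only describe them in prose (room and event IDs are made up):

```
new_forward_extremeties = {"!room:example.com": ["$new_extremity"]}

current_state_for_room = {
    "!room:example.com": {("m.room.topic", ""): "$new_topic"},
}

state_delta_for_room = {
    "!room:example.com": (
        {("m.room.topic", ""): "$old_topic"},  # to_delete
        {("m.room.topic", ""): "$new_topic"},  # to_insert
    ),
}
```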
if not backfilled: if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"): with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room. # Work out the new "current state" for each room.
@ -386,11 +398,19 @@ class EventsStore(SQLBaseStore):
if all_single_prev_not_state: if all_single_prev_not_state:
continue continue
state = yield self._calculate_state_delta( logger.info(
room_id, ev_ctx_rm, new_latest_event_ids "Calculating state delta for room %s", room_id,
) )
if state: current_state = yield self._get_new_state_after_events(
current_state_for_room[room_id] = state ev_ctx_rm, new_latest_event_ids,
)
if current_state is not None:
current_state_for_room[room_id] = current_state
delta = yield self._calculate_state_delta(
room_id, current_state,
)
if delta is not None:
state_delta_for_room[room_id] = delta
yield self.runInteraction( yield self.runInteraction(
"persist_events", "persist_events",
@ -398,7 +418,7 @@ class EventsStore(SQLBaseStore):
events_and_contexts=chunk, events_and_contexts=chunk,
backfilled=backfilled, backfilled=backfilled,
delete_existing=delete_existing, delete_existing=delete_existing,
current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties, new_forward_extremeties=new_forward_extremeties,
) )
persist_event_counter.inc_by(len(chunk)) persist_event_counter.inc_by(len(chunk))
@ -415,7 +435,7 @@ class EventsStore(SQLBaseStore):
event_counter.inc(event.type, origin_type, origin_entity) event_counter.inc(event.type, origin_type, origin_entity)
for room_id, (_, _, new_state) in current_state_for_room.iteritems(): for room_id, new_state in current_state_for_room.iteritems():
self.get_current_state_ids.prefill( self.get_current_state_ids.prefill(
(room_id, ), new_state (room_id, ), new_state
) )
@ -467,20 +487,22 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids) defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): def _get_new_state_after_events(self, events_context, new_latest_event_ids):
"""Calculate the new state deltas for a room. """Calculate the current state dict after adding some new events to
a room
Assumes that we are only persisting events for one room at a time. Args:
events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room
new_latest_event_ids (iterable[str]):
the new forward extremities for the room.
Returns: Returns:
3-tuple (to_delete, to_insert, new_state) where both are state dicts, Deferred[dict[(str,str), str]|None]:
i.e. (type, state_key) -> event_id. `to_delete` are the entries to None if there are no changes to the room state, or
first be deleted from current_state_events, `to_insert` are entries a dict of (type, state_key) -> event_id.
to insert. `new_state` is the full set of state.
May return None if there are no changes to be applied.
""" """
# Now we need to work out the different state sets for
# each state extremities
state_sets = [] state_sets = []
state_groups = set() state_groups = set()
missing_event_ids = [] missing_event_ids = []
@ -523,12 +545,12 @@ class EventsStore(SQLBaseStore):
state_sets.extend(group_to_state.itervalues()) state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids: if not new_latest_event_ids:
current_state = {} defer.returnValue({})
elif was_updated: elif was_updated:
if len(state_sets) == 1: if len(state_sets) == 1:
# If there is only one state set, then we know what the current # If there is only one state set, then we know what the current
# state is. # state is.
current_state = state_sets[0] defer.returnValue(state_sets[0])
else: else:
# We work out the current state by passing the state sets to the # We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including # state resolution algorithm. It may ask for some events, including
@ -537,8 +559,7 @@ class EventsStore(SQLBaseStore):
# up in the db. # up in the db.
logger.info( logger.info(
"Resolving state for %s with %i state sets", "Resolving state with %i state sets", len(state_sets),
room_id, len(state_sets),
) )
events_map = {ev.event_id: ev for ev, _ in events_context} events_map = {ev.event_id: ev for ev, _ in events_context}
@ -567,9 +588,22 @@ class EventsStore(SQLBaseStore):
state_sets, state_sets,
state_map_factory=get_events, state_map_factory=get_events,
) )
defer.returnValue(current_state)
else: else:
return return
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, current_state):
"""Calculate the new state deltas for a room.
Assumes that we are only persisting events for one room at a time.
Returns:
2-tuple (to_delete, to_insert) where both are state dicts,
i.e. (type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries
to insert.
"""
existing_state = yield self.get_current_state_ids(room_id) existing_state = yield self.get_current_state_ids(room_id)
existing_events = set(existing_state.itervalues()) existing_events = set(existing_state.itervalues())
@ -589,7 +623,7 @@ class EventsStore(SQLBaseStore):
if ev_id in events_to_insert if ev_id in events_to_insert
} }
defer.returnValue((to_delete, to_insert, current_state)) defer.returnValue((to_delete, to_insert))
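The delta is essentially a set difference on event IDs; a minimal sketch of the idea (not the exact implementation above):

```
existing_state = {("m.room.topic", ""): "$old_topic"}
current_state = {
    ("m.room.topic", ""): "$new_topic",
    ("m.room.name", ""): "$name",
}

existing_events = set(existing_state.values())
new_events = set(current_state.values())

# Entries whose event is no longer part of the current state:
to_delete = {k: v for k, v in existing_state.items() if v not in new_events}
# Entries whose event is newly part of the current state:
to_insert = {k: v for k, v in current_state.items() if v not in existing_events}
```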
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
@ -649,7 +683,7 @@ class EventsStore(SQLBaseStore):
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False, current_state_for_room={}, delete_existing=False, state_delta_for_room={},
new_forward_extremeties={}): new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables. """Insert some number of room events into the necessary database tables.
@ -665,7 +699,7 @@ class EventsStore(SQLBaseStore):
delete_existing (bool): True to purge existing table rows for the delete_existing (bool): True to purge existing table rows for the
events from the database. This is useful when retrying due to events from the database. This is useful when retrying due to
IntegrityError. IntegrityError.
current_state_for_room (dict[str, (list[str], list[str])]): state_delta_for_room (dict[str, (list[str], list[str])]):
The current-state delta for each room. For each room, a tuple The current-state delta for each room. For each room, a tuple
(to_delete, to_insert), being a list of event ids to be removed (to_delete, to_insert), being a list of event ids to be removed
from the current state, and a list of event ids to be added to from the current state, and a list of event ids to be added to
@ -677,7 +711,7 @@ class EventsStore(SQLBaseStore):
""" """
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
self._update_current_state_txn(txn, current_state_for_room, max_stream_order) self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
self._update_forward_extremities_txn( self._update_forward_extremities_txn(
txn, txn,
@ -721,9 +755,8 @@ class EventsStore(SQLBaseStore):
events_and_contexts=events_and_contexts, events_and_contexts=events_and_contexts,
) )
# Insert into the state_groups, state_groups_state, and # Insert into event_to_state_groups.
# event_to_state_groups tables. self._store_event_state_mappings_txn(txn, events_and_contexts)
self._store_mult_state_groups_txn(txn, events_and_contexts)
# _store_rejected_events_txn filters out any events which were # _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list. # rejected, and returns the filtered list.
@ -743,7 +776,7 @@ class EventsStore(SQLBaseStore):
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems(): for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert, _ = current_state_tuple to_delete, to_insert = current_state_tuple
txn.executemany( txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?", "DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()], [(ev_id,) for ev_id in to_delete.itervalues()],
@ -958,10 +991,9 @@ class EventsStore(SQLBaseStore):
# an outlier in the database. We now have some state at that # an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state. # so we need to update the state_groups table with that state.
# insert into the state_group, state_groups_state and # insert into event_to_state_groups.
# event_to_state_groups tables.
try: try:
self._store_mult_state_groups_txn(txn, ((event, context),)) self._store_event_state_mappings_txn(txn, ((event, context),))
except Exception: except Exception:
logger.exception("") logger.exception("")
raise raise
@ -2031,16 +2063,32 @@ class EventsStore(SQLBaseStore):
) )
return self.runInteraction("get_all_new_events", get_all_new_events_txn) return self.runInteraction("get_all_new_events", get_all_new_events_txn)
def delete_old_state(self, room_id, topological_ordering): def purge_history(
return self.runInteraction( self, room_id, topological_ordering, delete_local_events,
"delete_old_state", ):
self._delete_old_state_txn, room_id, topological_ordering """Deletes room history before a certain point
)
def _delete_old_state_txn(self, txn, room_id, topological_ordering): Args:
"""Deletes old room state room_id (str):
topological_ordering (int):
minimum topo ordering to preserve
delete_local_events (bool):
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
""" """
return self.runInteraction(
"purge_history",
self._purge_history_txn, room_id, topological_ordering,
delete_local_events,
)
def _purge_history_txn(
self, txn, room_id, topological_ordering, delete_local_events,
):
# Tables that should be pruned: # Tables that should be pruned:
# event_auth # event_auth
# event_backward_extremities # event_backward_extremities
@ -2081,7 +2129,7 @@ class EventsStore(SQLBaseStore):
400, "topological_ordering is greater than forward extremeties" 400, "topological_ordering is greater than forward extremeties"
) )
logger.debug("[purge] looking for events to delete") logger.info("[purge] looking for events to delete")
txn.execute( txn.execute(
"SELECT event_id, state_key FROM events" "SELECT event_id, state_key FROM events"
@ -2093,16 +2141,16 @@ class EventsStore(SQLBaseStore):
to_delete = [ to_delete = [
(event_id,) for event_id, state_key in event_rows (event_id,) for event_id, state_key in event_rows
if state_key is None and not self.hs.is_mine_id(event_id) if state_key is None and (
delete_local_events or not self.hs.is_mine_id(event_id)
)
] ]
logger.info( logger.info(
"[purge] found %i events before cutoff, of which %i are remote" "[purge] found %i events before cutoff, of which %i can be deleted",
" non-state events to delete", len(event_rows), len(to_delete)) len(event_rows), len(to_delete),
)
for event_id, state_key in event_rows: logger.info("[purge] Finding new backward extremities")
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
logger.debug("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremities by finding
# all events that point to events that are to be purged # all events that point to events that are to be purged
@ -2116,7 +2164,7 @@ class EventsStore(SQLBaseStore):
) )
new_backwards_extrems = txn.fetchall() new_backwards_extrems = txn.fetchall()
logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems) logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
txn.execute( txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?", "DELETE FROM event_backward_extremities WHERE room_id = ?",
@ -2132,7 +2180,7 @@ class EventsStore(SQLBaseStore):
] ]
) )
logger.debug("[purge] finding redundant state groups") logger.info("[purge] finding redundant state groups")
# Get all state groups that are only referenced by events that are # Get all state groups that are only referenced by events that are
# to be deleted. # to be deleted.
@ -2149,15 +2197,15 @@ class EventsStore(SQLBaseStore):
) )
state_rows = txn.fetchall() state_rows = txn.fetchall()
logger.debug("[purge] found %i redundant state groups", len(state_rows)) logger.info("[purge] found %i redundant state groups", len(state_rows))
# make a set of the redundant state groups, so that we can look them up # make a set of the redundant state groups, so that we can look them up
# efficiently # efficiently
state_groups_to_delete = set([sg for sg, in state_rows]) state_groups_to_delete = set([sg for sg, in state_rows])
# Now we get all the state groups that rely on these state groups # Now we get all the state groups that rely on these state groups
logger.debug("[purge] finding state groups which depend on redundant" logger.info("[purge] finding state groups which depend on redundant"
" state groups") " state groups")
remaining_state_groups = [] remaining_state_groups = []
for i in xrange(0, len(state_rows), 100): for i in xrange(0, len(state_rows), 100):
chunk = [sg for sg, in state_rows[i:i + 100]] chunk = [sg for sg, in state_rows[i:i + 100]]
@ -2182,7 +2230,7 @@ class EventsStore(SQLBaseStore):
# Now we turn the state groups that reference to-be-deleted state # Now we turn the state groups that reference to-be-deleted state
# groups to non delta versions. # groups to non delta versions.
for sg in remaining_state_groups: for sg in remaining_state_groups:
logger.debug("[purge] de-delta-ing remaining state group %s", sg) logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn( curr_state = self._get_state_groups_from_groups_txn(
txn, [sg], types=None txn, [sg], types=None
) )
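Conceptually, "de-delta-ing" materialises a group's full state so it no longer depends on a predecessor that is about to be deleted. A hedged sketch with a simplified in-memory structure (not the real storage schema):

```
def flatten_state(group_deltas, group):
    # group_deltas: dict of group -> (prev_group or None, delta dict),
    # a simplified stand-in for state_group_edges/state_groups_state.
    prev_group, delta = group_deltas[group]
    if prev_group is None:
        state = {}
    else:
        state = dict(flatten_state(group_deltas, prev_group))
    state.update(delta)
    return state
```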
@ -2219,7 +2267,7 @@ class EventsStore(SQLBaseStore):
], ],
) )
logger.debug("[purge] removing redundant state groups") logger.info("[purge] removing redundant state groups")
txn.executemany( txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?", "DELETE FROM state_groups_state WHERE state_group = ?",
state_rows state_rows
@ -2229,18 +2277,15 @@ class EventsStore(SQLBaseStore):
state_rows state_rows
) )
# Delete all non-state logger.info("[purge] removing events from event_to_state_groups")
logger.debug("[purge] removing events from event_to_state_groups")
txn.executemany( txn.executemany(
"DELETE FROM event_to_state_groups WHERE event_id = ?", "DELETE FROM event_to_state_groups WHERE event_id = ?",
[(event_id,) for event_id, _ in event_rows] [(event_id,) for event_id, _ in event_rows]
) )
for event_id, _ in event_rows:
logger.debug("[purge] updating room_depth") txn.call_after(self._get_state_group_for_event.invalidate, (
txn.execute( event_id,
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?", ))
(topological_ordering, room_id,)
)
# Delete all remote non-state events # Delete all remote non-state events
for table in ( for table in (
@ -2258,7 +2303,8 @@ class EventsStore(SQLBaseStore):
"event_signatures", "event_signatures",
"rejections", "rejections",
): ):
logger.debug("[purge] removing remote non-state events from %s", table) logger.info("[purge] removing remote non-state events from %s",
table)
txn.executemany( txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,), "DELETE FROM %s WHERE event_id = ?" % (table,),
@ -2266,16 +2312,30 @@ class EventsStore(SQLBaseStore):
) )
# Mark all state and own events as outliers # Mark all state and own events as outliers
logger.debug("[purge] marking remaining events as outliers") logger.info("[purge] marking remaining events as outliers")
txn.executemany( txn.executemany(
"UPDATE events SET outlier = ?" "UPDATE events SET outlier = ?"
" WHERE event_id = ?", " WHERE event_id = ?",
[ [
(True, event_id,) for event_id, state_key in event_rows (True, event_id,) for event_id, state_key in event_rows
if state_key is not None or self.hs.is_mine_id(event_id) if state_key is not None or (
not delete_local_events and self.hs.is_mine_id(event_id)
)
] ]
) )
# synapse tries to take out an exclusive lock on room_depth whenever it
# persists events (because upsert), and once we run this update, we
# will block that for the rest of our transaction.
#
# So, let's stick it at the end so that we don't block event
# persistence.
logger.info("[purge] updating room_depth")
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
(topological_ordering, room_id,)
)
logger.info("[purge] done") logger.info("[purge] done")
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -506,73 +506,114 @@ class RoomStore(SearchStore):
) )
self.is_room_blocked.invalidate((room_id,)) self.is_room_blocked.invalidate((room_id,))
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
Returns:
A tuple of two lists of MXC URIs: the local media and the
remote media found in the room.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by): def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines """For a room loops through all events with media and quarantines
the associated media the associated media
""" """
def _get_media_ids_in_room(txn): def _quarantine_media_in_room_txn(txn):
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
next_token = self.get_current_events_token() + 1
total_media_quarantined = 0 total_media_quarantined = 0
while next_token: # Now update all the tables to set the quarantined_by flag
sql = """
SELECT stream_ordering, content FROM events txn.executemany("""
WHERE room_id = ? UPDATE local_media_repository
AND stream_ordering < ? SET quarantined_by = ?
AND contains_url = ? AND outlier = ? WHERE media_id = ?
ORDER BY stream_ordering DESC """, ((quarantined_by, media_id) for media_id in local_mxcs))
LIMIT ?
txn.executemany(
""" """
txn.execute(sql, (room_id, next_token, True, False, 100)) UPDATE remote_media_cache
next_token = None
local_media_mxcs = []
remote_media_mxcs = []
for stream_ordering, content_json in txn:
next_token = stream_ordering
content = json.loads(content_json)
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
# Now update all the tables to set the quarantined_by flag
txn.executemany("""
UPDATE local_media_repository
SET quarantined_by = ? SET quarantined_by = ?
WHERE media_id = ? WHERE media_origin = ? AND media_id = ?
""", ((quarantined_by, media_id) for media_id in local_media_mxcs)) """,
(
txn.executemany( (quarantined_by, origin, media_id)
""" for origin, media_id in remote_mxcs
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_media_mxcs
)
) )
)
total_media_quarantined += len(local_media_mxcs) total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_media_mxcs) total_media_quarantined += len(remote_mxcs)
return total_media_quarantined return total_media_quarantined
return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room) return self.runInteraction(
"quarantine_media_in_room",
_quarantine_media_in_room_txn,
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
A tuple (local, remote): local is a list of media IDs, and
remote is a list of (hostname, media ID) tuples.
"""
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1
local_media_mxcs = []
remote_media_mxcs = []
while next_token:
sql = """
SELECT stream_ordering, content FROM events
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
"""
txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
content = json.loads(content_json)
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
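For reference, the `mxc_re` pattern splits an MXC URI into its server name and media ID (the URI below is made up):

```
import re

mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")

m = mxc_re.match("mxc://example.com/SomeMediaId12345")
if m:
    hostname = m.group(1)  # "example.com"
    media_id = m.group(2)  # "SomeMediaId12345"
```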

View file

@ -0,0 +1,37 @@
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
# if we already have some state groups, we want to start making new
# ones with a higher id.
cur.execute("SELECT max(id) FROM state_groups")
row = cur.fetchone()
if row[0] is None:
start_val = 1
else:
start_val = row[0] + 1
cur.execute(
"CREATE SEQUENCE state_group_id_seq START WITH %s",
(start_val, ),
)
def run_upgrade(*args, **kwargs):
pass

View file

@ -12,18 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import namedtuple from collections import namedtuple
import logging
import re
import sys import sys
import ujson as json
from twisted.internet import defer from twisted.internet import defer
from .background_updates import BackgroundUpdateStore from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
import re
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -280,10 +281,10 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (
"INSERT INTO event_search" "INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, " " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)" " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
) )
args = (( args = ((
entry.event_id, entry.room_id, entry.key, entry.value, entry.event_id, entry.room_id, entry.key, entry.value,
entry.stream_ordering, entry.origin_server_ts, entry.stream_ordering, entry.origin_server_ts,

View file

@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0 return len(self.delta_ids) if self.delta_ids else 0
class StateGroupReadStore(SQLBaseStore): class StateGroupWorkerStore(SQLBaseStore):
"""The read-only parts of StateGroupStore """The parts of StateGroupStore that can be called from workers.
None of these functions write to the state tables, so are suitable for
including in the SlavedStores.
""" """
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateGroupReadStore, self).__init__(db_conn, hs) super(StateGroupWorkerStore, self).__init__(db_conn, hs)
self._state_group_cache = DictionaryCache( self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@ -549,8 +546,117 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
"""Store a new set of state, returning a newly assigned state group.
class StateStore(StateGroupReadStore, BackgroundUpdateStore): Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
to event_id.
Returns:
Deferred[int]: The state group ID
"""
def _store_state_group_txn(txn):
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
state_group = self.database_engine.get_next_state_group_id(txn)
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": state_group,
"room_id": room_id,
"event_id": event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": state_group,
"prev_state_group": prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": state_group,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": state_group,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
value=dict(current_state_ids),
full=True,
)
return state_group
return self.runInteraction("store_state_group", _store_state_group_txn)
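A sketch of how a caller might drive this, mirroring `compute_event_context` above (the function name and wiring here are hypothetical):

```
from twisted.internet import defer

@defer.inlineCallbacks
def persist_state_for_event(store, event, prev_group, delta_ids,
                            current_state_ids):
    # Ask the store for a freshly allocated state group; passing
    # prev_group/delta_ids lets it persist a delta rather than a
    # full snapshot when the delta chain is short enough.
    state_group = yield store.store_state_group(
        event.event_id,
        event.room_id,
        prev_group=prev_group,
        delta_ids=delta_ids,
        current_state_ids=current_state_ids,
    )
    defer.returnValue(state_group)
```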
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event. """ Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is assigned This is done by the concept of `state groups`. Every event is assigned
@ -591,27 +697,12 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
where_clause="type='m.room.member'", where_clause="type='m.room.member'",
) )
def _have_persisted_state_group_txn(self, txn, state_group): def _store_event_state_mappings_txn(self, txn, events_and_contexts):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {} state_groups = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
continue continue
if context.current_state_ids is None:
# AFAIK, this can never happen
logger.error(
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its # if the event was rejected, just give it the same state as its
# predecessor. # predecessor.
if context.rejected: if context.rejected:
@ -620,90 +711,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
return count return count
def get_next_state_group(self):
return self._state_groups_id_gen.get_next()
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size): def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding """This background update will slowly deduplicate state by reencoding

View file

@ -75,6 +75,7 @@ class Cache(object):
self.cache = LruCache( self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type, max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None, size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
) )
self.name = name self.name = name
@ -83,6 +84,9 @@ class Cache(object):
self.thread = None self.thread = None
self.metrics = register_cache(name, self.cache) self.metrics = register_cache(name, self.cache)
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def check_thread(self): def check_thread(self):
expected_thread = self.thread expected_thread = self.thread
if expected_thread is None: if expected_thread is None:

View file

@ -79,7 +79,11 @@ class ExpiringCache(object):
while self._max_len and len(self) > self._max_len: while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False) _key, value = self._cache.popitem(last=False)
if self.iterable: if self.iterable:
self._size_estimate -= len(value.value) removed_len = len(value.value)
self.metrics.inc_evictions(removed_len)
self._size_estimate -= removed_len
else:
self.metrics.inc_evictions()
def __getitem__(self, key): def __getitem__(self, key):
try: try:

View file

@ -49,7 +49,24 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted. when that key gets invalidated/evicted.
""" """
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None): def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
evicted_callback=None):
"""
Args:
max_size (int):
keylen (int):
cache_type (type):
type of underlying cache to be used. Typically one of dict
or TreeCache.
size_callback (func(V) -> int | None):
evicted_callback (func(int)|None):
if not None, called on eviction with the size of the evicted
entry
"""
cache = cache_type() cache = cache_type()
self.cache = cache # Used for introspection. self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None) list_root = _Node(None, None, None, None)
@ -61,8 +78,10 @@ class LruCache(object):
def evict(): def evict():
while cache_len() > max_size: while cache_len() > max_size:
todelete = list_root.prev_node todelete = list_root.prev_node
delete_node(todelete) evicted_len = delete_node(todelete)
cache.pop(todelete.key, None) cache.pop(todelete.key, None)
if evicted_callback:
evicted_callback(evicted_len)
def synchronized(f): def synchronized(f):
@wraps(f) @wraps(f)
@ -111,12 +130,15 @@ class LruCache(object):
prev_node.next_node = next_node prev_node.next_node = next_node
next_node.prev_node = prev_node next_node.prev_node = prev_node
deleted_len = 1
if size_callback: if size_callback:
cached_cache_len[0] -= size_callback(node.value) deleted_len = size_callback(node.value)
cached_cache_len[0] -= deleted_len
for cb in node.callbacks: for cb in node.callbacks:
cb() cb()
node.callbacks.clear() node.callbacks.clear()
return deleted_len
@synchronized @synchronized
def cache_get(key, default=None, callbacks=[]): def cache_get(key, default=None, callbacks=[]):
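A minimal sketch of the new hook in use (assuming dict-style access on `LruCache`, with entry sizes counted via `size_callback`):

```
from synapse.util.caches.lrucache import LruCache

evicted_total = [0]

def on_evicted(evicted_len):
    # Called with the size of each evicted entry.
    evicted_total[0] += evicted_len

cache = LruCache(max_size=2, size_callback=len,
                 evicted_callback=on_evicted)
cache["k1"] = "ab"   # total size 2: at capacity
cache["k2"] = "cd"   # pushes size to 4, so "k1" (size 2) is evicted
# evicted_total[0] == 2
```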

View file

@ -141,6 +141,7 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 0', 'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 0', 'cache:total{name="cache_name"} 0',
'cache:size{name="cache_name"} 0', 'cache:size{name="cache_name"} 0',
'cache:evicted_size{name="cache_name"} 0',
]) ])
metric.inc_misses() metric.inc_misses()
@ -150,6 +151,7 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 0', 'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 1', 'cache:total{name="cache_name"} 1',
'cache:size{name="cache_name"} 1', 'cache:size{name="cache_name"} 1',
'cache:evicted_size{name="cache_name"} 0',
]) ])
metric.inc_hits() metric.inc_hits()
@ -158,4 +160,14 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 1', 'cache:hits{name="cache_name"} 1',
'cache:total{name="cache_name"} 2', 'cache:total{name="cache_name"} 2',
'cache:size{name="cache_name"} 1', 'cache:size{name="cache_name"} 1',
'cache:evicted_size{name="cache_name"} 0',
])
metric.inc_evictions(2)
self.assertEquals(metric.render(), [
'cache:hits{name="cache_name"} 1',
'cache:total{name="cache_name"} 2',
'cache:size{name="cache_name"} 1',
'cache:evicted_size{name="cache_name"} 2',
]) ])

View file

@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
context = EventContext() context = EventContext()
context.current_state_ids = state_ids context.current_state_ids = state_ids
context.prev_state_ids = state_ids context.prev_state_ids = state_ids
elif not backfill: else:
state_handler = self.hs.get_state_handler() state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event) context = yield state_handler.compute_event_context(event)
else:
context = EventContext()
context.push_actions = push_actions context.push_actions = push_actions

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+
+from mock import Mock
+
+from tests import unittest
+
+import os
+import shutil
+import tempfile
+
+
+class MediaStorageTests(unittest.TestCase):
+    def setUp(self):
+        self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
+
+        self.primary_base_path = os.path.join(self.test_dir, "primary")
+        self.secondary_base_path = os.path.join(self.test_dir, "secondary")
+
+        hs = Mock()
+        hs.config.media_store_path = self.primary_base_path
+
+        storage_providers = [FileStorageProviderBackend(
+            hs, self.secondary_base_path
+        )]
+
+        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.media_storage = MediaStorage(
+            self.primary_base_path, self.filepaths, storage_providers,
+        )
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir)
+
+    @defer.inlineCallbacks
+    def test_ensure_media_is_in_local_cache(self):
+        media_id = "some_media_id"
+        test_body = "Test\n"
+
+        # First we create a file that is in a storage provider but not in the
+        # local primary media store
+        rel_path = self.filepaths.local_media_filepath_rel(media_id)
+        secondary_path = os.path.join(self.secondary_base_path, rel_path)
+
+        os.makedirs(os.path.dirname(secondary_path))
+
+        with open(secondary_path, "w") as f:
+            f.write(test_body)
+
+        # Now we run ensure_media_is_in_local_cache, which should copy the file
+        # to the local cache.
+        file_info = FileInfo(None, media_id)
+        local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+        self.assertTrue(os.path.exists(local_path))
+
+        # Asserts the file is under the expected local cache directory
+        self.assertEquals(
+            os.path.commonprefix([self.primary_base_path, local_path]),
+            self.primary_base_path,
+        )
+
+        with open(local_path) as f:
+            body = f.read()
+
+        self.assertEqual(test_body, body)
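The test seeds the secondary (storage provider) store and asserts the media ends up under the primary path. A rough, synchronous sketch of the behaviour being asserted; this is illustrative only, with names of my choosing, since the real `MediaStorage.ensure_media_is_in_local_cache` is Deferred-based and integrated with the rest of the media repository:

```python
import errno
import os
import shutil

def ensure_in_local_cache(primary_base, provider_bases, rel_path):
    """If a file is absent from the primary media store, pull it in from the
    first storage provider that has a copy, and return the local path."""
    local_path = os.path.join(primary_base, rel_path)
    if os.path.exists(local_path):
        return local_path
    for base in provider_bases:
        candidate = os.path.join(base, rel_path)
        if os.path.exists(candidate):
            try:
                os.makedirs(os.path.dirname(local_path))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
            shutil.copyfile(candidate, local_path)
            return local_path
    raise IOError("media %r not found in any store" % (rel_path,))
```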

View file

@@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()

         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": content,
         })

-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
@@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
             "content": {"body": body, "msgtype": u"message"},
         })

-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
@@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
             "redacts": event_id,
         })

-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )

View file

@@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         # storage logic
         self.store = hs.get_datastore()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.handlers = hs.get_handlers()
-        self.message_handler = self.handlers.message_handler
+        self.event_creation_handler = hs.get_event_creation_handler()

         self.u_alice = UserID.from_string("@alice:test")
         self.u_bob = UserID.from_string("@bob:test")
@@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
             "content": {"membership": membership},
         })

-        event, context = yield self.message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
            builder
         )
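Both storage tests pick up the same mechanical rename: the private `_create_new_client_event` on the old message handler becomes the public `create_new_client_event` on the new `EventCreationHandler` (also wired into the directory handler earlier in this commit). A hedged sketch of the new call shape, assuming `hs` is a configured HomeServer fixture and `builder` an event builder:

```python
from twisted.internet import defer

@defer.inlineCallbacks
def build_event(hs, builder):
    # old spelling: hs.get_handlers().message_handler._create_new_client_event(...)
    handler = hs.get_event_creation_handler()
    event, context = yield handler.create_new_client_event(builder)
    defer.returnValue((event, context))
```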

View file

@@ -19,7 +19,7 @@ from twisted.internet import defer
 from synapse.events import FrozenEvent
 from synapse.api.auth import Auth
 from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler

 from .utils import MockClock
@@ -80,14 +80,14 @@ class StateGroupStore(object):
         return defer.succeed(groups)

-    def store_state_groups(self, event, context):
-        if context.current_state_ids is None:
-            return
-
-        state_events = dict(context.current_state_ids)
-
-        self._group_to_state[context.state_group] = state_events
-        self._event_to_state_group[event.event_id] = context.state_group
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        state_group = self._next_group
+        self._next_group += 1
+
+        self._group_to_state[state_group] = dict(current_state_ids)
+
+        return state_group

     def get_events(self, event_ids, **kwargs):
         return {
@@ -95,10 +95,19 @@
             if e_id in self._event_id_to_event
         }

+    def get_state_group_delta(self, name):
+        return (None, None)
+
     def register_events(self, events):
         for e in events:
             self._event_id_to_event[e.event_id] = e

+    def register_event_context(self, event, context):
+        self._event_to_state_group[event.event_id] = context.state_group
+
+    def register_event_id_state_group(self, event_id, state_group):
+        self._event_to_state_group[event_id] = state_group
+

 class DictObj(dict):
     def __init__(self, **kwargs):
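In prose: `store_state_group` now allocates a fresh group id itself and returns it, rather than trusting a `state_group` already stamped onto a context, and `register_event_id_state_group` then ties a prev event to that group. A self-contained toy mirroring the test double above; the names and the starting value of `_next_group` are my assumptions:

```python
class InMemoryStateGroupStore(object):
    """Toy mirror of the test's StateGroupStore; illustrative only."""
    def __init__(self):
        self._next_group = 1  # assumption: the real double seeds this elsewhere
        self._group_to_state = {}
        self._event_to_state_group = {}

    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
                          current_state_ids):
        # Allocate a new group id and snapshot the state map under it.
        state_group = self._next_group
        self._next_group += 1
        self._group_to_state[state_group] = dict(current_state_ids)
        return state_group

    def register_event_id_state_group(self, event_id, state_group):
        self._event_to_state_group[event_id] = state_group

store = InMemoryStateGroupStore()
sg = store.store_state_group(
    "$prev:test", "!room:test", None, None,
    {("m.room.create", ""): "$create:test"},
)
store.register_event_id_state_group("$prev:test", sg)
assert store._group_to_state[sg] == {("m.room.create", ""): "$create:test"}
```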
@@ -137,25 +146,16 @@
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = Mock(
-            spec_set=[
-                "get_state_groups_ids",
-                "add_event_hashes",
-                "get_events",
-                "get_next_state_group",
-                "get_state_group_delta",
-            ]
-        )
+        self.store = StateGroupStore()
         hs = Mock(spec_set=[
             "get_datastore", "get_auth", "get_state_handler", "get_clock",
+            "get_state_resolution_handler",
         ])
         hs.get_datastore.return_value = self.store
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
-
-        self.store.get_next_state_group.side_effect = Mock
-        self.store.get_state_group_delta.return_value = (None, None)
+        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)

         self.state = StateHandler(hs)
         self.event_id = 0
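One subtlety in the setUp change above: `spec_set` on a `Mock` rejects attributes it does not know about, so `get_state_resolution_handler` has to be named in the spec before the test can overwrite it with a lambda returning a real `StateResolutionHandler`. A tiny standalone demonstration of the pattern, with illustrative names:

```python
from mock import Mock

# spec_set whitelists the attributes that may be read or set on the mock...
hs = Mock(spec_set=["get_datastore", "get_resolver"])

# ...so this assignment is allowed only because "get_resolver" is in the list:
hs.get_resolver = lambda: "a real collaborator would go here"
assert hs.get_resolver() == "a real collaborator would go here"

# whereas setting an unlisted attribute raises AttributeError:
try:
    hs.get_other = lambda: None
except AttributeError:
    pass
```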
@@ -195,14 +195,13 @@
             }
         )

-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.register_events(graph.walk())

         context_store = {}

         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context

         self.assertEqual(2, len(context_store["D"].prev_state_ids))
@@ -247,16 +246,13 @@
             }
         )

-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())

         context_store = {}

         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context

         self.assertSetEqual(
@@ -313,16 +309,13 @@
             }
         )

-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())

         context_store = {}

         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context

         self.assertSetEqual(
@@ -396,16 +389,13 @@
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)

-        store = StateGroupStore()
-        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
-        self.store.get_events = store.get_events
-        store.register_events(graph.walk())
+        self.store.register_events(graph.walk())

         context_store = {}

         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
-            store.store_state_groups(event, context)
+            self.store.register_event_context(event, context)
             context_store[event.event_id] = context

         self.assertSetEqual(
@@ -465,7 +455,11 @@
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="test_message", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )

         old_state = [
             create_event(type="test1", state_key="1"),
@@ -473,11 +467,11 @@
             create_event(type="test2", state_key=""),
         ]

-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)

         context = yield self.state.compute_event_context(event)
@@ -490,7 +484,11 @@
     @defer.inlineCallbacks
     def test_trivial_annotate_state(self):
-        event = create_event(type="state", state_key="", name="event")
+        prev_event_id = "prev_event_id"
+        event = create_event(
+            type="state", state_key="", name="event2",
+            prev_events=[(prev_event_id, {})],
+        )

         old_state = [
             create_event(type="test1", state_key="1"),
@@ -498,11 +496,11 @@
             create_event(type="test2", state_key=""),
         ]

-        group_name = "group_name_1"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
-        }
+        group_name = self.store.store_state_group(
+            prev_event_id, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state},
+        )
+        self.store.register_event_id_state_group(prev_event_id, group_name)

         context = yield self.state.compute_event_context(event)
@@ -515,7 +513,12 @@
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
-        event = create_event(type="test_message", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test_message", name="event3",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )

         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -535,12 +538,12 @@
             create_event(type="test4", state_key=""),
         ]

-        store = StateGroupStore()
-        store.register_events(old_state_1)
-        store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.store.register_events(old_state_1)
+        self.store.register_events(old_state_2)

-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )

         self.assertEqual(len(context.current_state_ids), 6)
@@ -548,7 +551,12 @@
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
-        event = create_event(type="test4", state_key="", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", state_key="", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )

         creation = create_event(
             type=EventTypes.Create, state_key=""
@@ -573,7 +581,9 @@
         store.register_events(old_state_2)
         self.store.get_events = store.get_events

-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )

         self.assertEqual(len(context.current_state_ids), 6)
@@ -581,7 +591,12 @@
     @defer.inlineCallbacks
     def test_standard_depth_conflict(self):
-        event = create_event(type="test4", name="event")
+        prev_event_id1 = "event_id1"
+        prev_event_id2 = "event_id2"
+        event = create_event(
+            type="test4", name="event",
+            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+        )

         member_event = create_event(
             type=EventTypes.Member,
@@ -613,7 +628,9 @@
         store.register_events(old_state_2)
         self.store.get_events = store.get_events

-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )

         self.assertEqual(
             old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
@@ -637,19 +654,26 @@
         store.register_events(old_state_1)
         store.register_events(old_state_2)

-        context = yield self._get_context(event, old_state_1, old_state_2)
+        context = yield self._get_context(
+            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+        )

         self.assertEqual(
             old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
         )

-    def _get_context(self, event, old_state_1, old_state_2):
-        group_name_1 = "group_name_1"
-        group_name_2 = "group_name_2"
-
-        self.store.get_state_groups_ids.return_value = {
-            group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
-            group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
-        }
+    def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
+                     old_state_2):
+        sg1 = self.store.store_state_group(
+            prev_event_id_1, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_1},
+        )
+        self.store.register_event_id_state_group(prev_event_id_1, sg1)
+
+        sg2 = self.store.store_state_group(
+            prev_event_id_2, event.room_id, None, None,
+            {(e.type, e.state_key): e.event_id for e in old_state_2},
+        )
+        self.store.register_event_id_state_group(prev_event_id_2, sg2)

         return self.state.compute_event_context(event)
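Taken together, the `_get_context` rewrite means the conflict tests now exercise the real resolution path: each prev event is tied to a stored state group, and `compute_event_context` has to merge the two groups. For intuition only, here is a deliberately naive merge that is NOT Matrix's state resolution algorithm (the real one weighs depth, power levels, and event types); it merely shows the shape of the problem:

```python
# Two prev events whose state groups disagree on the same (type, state_key):
group_states = [
    {("test1", "1"): "$a:test", ("test4", ""): "$x:test"},
    {("test1", "1"): "$b:test", ("test4", ""): "$y:test"},
]

def naive_resolve(state_maps):
    """Last-writer-wins placeholder; real resolution is far more careful."""
    resolved = {}
    for state in state_maps:
        resolved.update(state)
    return resolved

print(naive_resolve(group_states))
# {('test1', '1'): '$b:test', ('test4', ''): '$y:test'}
```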