Merge branch 'develop' into markjh/direct_to_device_synchrotron

This commit is contained in:
Mark Haines 2016-09-02 10:59:24 +01:00
commit 965168a842
28 changed files with 431 additions and 235 deletions

View file

@ -134,6 +134,12 @@ Installing prerequisites on Raspbian::
sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv
Installing prerequisites on openSUSE::
sudo zypper in -t pattern devel_basis
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
python-devel libffi-devel libopenssl-devel libjpeg62-devel
To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse
@ -230,9 +236,6 @@ The advantages of Postgres include:
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
For information on how to install and use PostgreSQL, please see
`docs/postgres.rst <docs/postgres.rst>`_.

View file

@ -66,7 +66,7 @@ class Auth(object):
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
event, context.current_state_ids, for_verification=True,
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@ -281,11 +281,13 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, curr_state_ids = yield self.state.resolve_state_groups(
entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids
)
ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids)
ret = yield self.store.is_host_joined(
room_id, host, entry.state_group, entry.state
)
defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events):
@ -852,7 +854,7 @@ class Auth(object):
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids

View file

@ -67,6 +67,8 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def query_user(self, service, user_id):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/users/%s" % urllib.quote(user_id))
response = None
try:
@ -86,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def query_alias(self, service, alias):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
response = None
try:
@ -113,6 +117,8 @@ class ApplicationServiceApi(SimpleHttpClient):
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
if service.url is None:
defer.returnValue([])
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
@ -145,6 +151,9 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue([])
def get_3pe_protocol(self, service, protocol):
if service.url is None:
defer.returnValue({})
@defer.inlineCallbacks
def _get():
uri = "%s%s/thirdparty/protocol/%s" % (
@ -166,6 +175,9 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
if service.url is None:
defer.returnValue(True)
events = self._serialize(events)
if txn_id is None:

View file

@ -86,7 +86,7 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart"
"id", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename):
field, config_filename,
))
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
if (not isinstance(as_info.get("url"), basestring) and
as_info.get("url", "") is not None):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename):
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
if as_info["url"] is None:
logger.info(
"(%s) Explicitly empty 'url' provided. This application service"
" will not receive events or queries.",
config_filename,
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],

View file

@ -15,8 +15,9 @@
class EventContext(object):
def __init__(self, current_state_ids=None):
self.current_state_ids = current_state_ids
def __init__(self):
self.current_state_ids = None
self.prev_state_ids = None
self.state_group = None
self.rejected = False
self.push_actions = []

View file

@ -269,7 +269,7 @@ class FederationClient(FederationBase):
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
pdu = None
signed_pdu = None
for destination in destinations:
now = self._clock.time_msec()
last_attempt = pdu_attempts.get(destination, 0)
@ -299,7 +299,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hashes([pdu])[0]
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@ -322,10 +322,10 @@ class FederationClient(FederationBase):
)
continue
if self._get_pdu_cache is not None and pdu:
self._get_pdu_cache[event_id] = pdu
if self._get_pdu_cache is not None and signed_pdu:
self._get_pdu_cache[event_id] = signed_pdu
defer.returnValue(pdu)
defer.returnValue(signed_pdu)
@defer.inlineCallbacks
@log_function

View file

@ -222,7 +222,7 @@ class FederationHandler(BaseHandler):
# joined the room. Don't bother if the user is just
# changing their profile info.
newly_joined = True
prev_state_id = context.current_state_ids.get(
prev_state_id = context.prev_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
@ -835,12 +835,12 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations)
state_ids = context.current_state_ids.values()
state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
))
state = yield self.store.get_events(context.current_state_ids.values())
state = yield self.store.get_events(context.prev_state_ids.values())
defer.returnValue({
"state": state.values(),
@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler):
if not auth_events:
auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True,
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event.is_state():
event_key = (event.type, event.state_key)
else:
event_key = None
if event_auth_events - current_state:
have_events = yield self.store.have_events(
event_auth_events - current_state
@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler):
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.state_group = None
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler):
if do_resolution:
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids
event, context.prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler):
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.state_group = None
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
try:
self.auth.check(event, auth_events=auth_events)
@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"]
)
original_invite = None
original_invite_id = context.current_state_ids.get(key)
original_invite_id = context.prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
invite_event_id = context.current_state_ids.get(
invite_event_id = context.prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,)
)

View file

@ -272,7 +272,7 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.current_state_ids.get((event.type, event.state_key))
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@ -808,8 +808,8 @@ class MessageHandler(BaseHandler):
event = builder.build()
logger.debug(
"Created event %s with current state: %s",
event.event_id, context.current_state_ids,
"Created event %s with state: %s",
event.event_id, context.prev_state_ids,
)
defer.returnValue(
@ -904,7 +904,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Redaction:
auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True,
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@ -924,7 +924,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events"
)
if event.type == EventTypes.Create and context.current_state_ids:
if event.type == EventTypes.Create and context.prev_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",

View file

@ -191,6 +191,13 @@ class PresenceHandler(object):
5000,
)
self.clock.call_later(
60,
self.clock.looping_call,
self._persist_unpersisted_changes,
60 * 1000,
)
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
@defer.inlineCallbacks
@ -216,6 +223,27 @@ class PresenceHandler(object):
])
logger.info("Finished _on_shutdown")
@defer.inlineCallbacks
def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
logger.info(
"Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
len(self.unpersisted_users_changes)
)
unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set()
if unpersisted:
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in unpersisted
])
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
@ -922,7 +950,12 @@ def should_notify(old_state, new_state):
if new_state.currently_active != old_state.currently_active:
return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently acive
if not (old_state.currently_active and new_state.currently_active):
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
return True

View file

@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
prev_member_event_id = context.current_state_ids.get(
prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, target.to_string()),
None
)
@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.current_state_ids)
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
prev_member_event_id = context.current_state_ids.get(
prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, event.state_key),
None
)

View file

@ -565,21 +565,26 @@ class SyncHandler(object):
if sync_result_builder.since_token is not None:
since_stream_id = int(sync_result_builder.since_token.to_device_key)
if since_stream_id:
if since_stream_id != int(now_token.to_device_key):
# We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point.
logger.debug("Deleting messages up to %d", since_stream_id)
yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug("Getting messages up to %d", now_token.to_device_key)
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, now_token.to_device_key
)
logger.debug("Got messages up to %d: %r", stream_id, messages)
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
)
sync_result_builder.to_device = messages
logger.debug("Getting messages up to %d", now_token.to_device_key)
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, now_token.to_device_key
)
logger.debug("Got messages up to %d: %r", stream_id, messages)
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
)
sync_result_builder.to_device = messages
else:
sync_result_builder.to_device = []
@defer.inlineCallbacks
def _generate_sync_entry_for_account_data(self, sync_result_builder):

View file

@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
)
room_members = yield self.store.get_joined_users_from_context(
event.room_id, context.state_group, context.current_state_ids
event, context
)
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))

View file

@ -338,7 +338,7 @@ class Mailer(object):
# want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room"
room_name = yield calculate_room_name(
state_by_room[room_id], user_id, fallback_to_members=False
self.store, state_by_room[room_id], user_id, fallback_to_members=False
)
my_member_event = state_by_room[room_id][("m.room.member", user_id)]

View file

@ -74,7 +74,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
alias_event = yield store.get_event(
alias_id, allow_none=True
)
if alias_event and alias_event.content and alias_event.get("aliases"):
if alias_event and alias_event.content.get("aliases"):
the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
defer.returnValue(the_aliases[0])

View file

@ -40,7 +40,6 @@ STREAM_NAMES = (
("backfill",),
("push_rules",),
("pushers",),
("state",),
("caches",),
("to_device",),
)
@ -131,7 +130,6 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken(
@ -143,7 +141,7 @@ class ReplicationResource(Resource):
backfill_token,
push_rules_token,
pushers_token,
state_token,
0, # State stream is no longer a thing
caches_token,
int(stream_token.to_device_key),
))
@ -193,7 +191,6 @@ class ReplicationResource(Resource):
yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
@ -368,25 +365,6 @@ class ReplicationResource(Resource):
"position", "user_id", "app_id", "pushkey"
))
@defer.inlineCallbacks
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
state = request_streams.get("state")
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
@defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches

View file

@ -123,6 +123,7 @@ class SlavedEventStore(BaseSlavedStore):
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = (
RoomMemberStore.__dict__["_get_joined_users_from_context"]

View file

@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
from collections import namedtuple
@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
def _gen_state_id():
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
return s
class _StateCacheEntry(object):
def __init__(self, state, state_group, ts):
__slots__ = ["state", "state_group", "state_id"]
def __init__(self, state, state_group):
self.state = state
self.state_group = state_group
# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
# group each time we resolve state, so we generate a separate one that
# isn't persisted and is used solely for caches.
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a stateg_group
if state_group:
self.state_id = state_group
else:
self.state_id = _gen_state_id()
class StateHandler(object):
""" Responsible for doing state conflict resolution.
@ -60,6 +85,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer()
def start_caching(self):
logger.debug("start_caching")
@ -93,7 +119,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
@ -116,7 +143,8 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type:
defer.returnValue(state.get((event_type, state_key)))
@ -127,9 +155,9 @@ class StateHandler(object):
@defer.inlineCallbacks
def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_context(
room_id, group, state_ids
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state
)
defer.returnValue(joined_users)
@ -154,52 +182,73 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
context.current_state_ids = {
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
if event.is_state():
context.current_state_events = dict(context.prev_state_ids)
key = (event.type, event.state_key)
context.current_state_events[key] = event.event_id
else:
context.current_state_events = context.prev_state_ids
else:
context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = []
context.state_group = None
context.state_group = self.store.get_next_state_group()
defer.returnValue(context)
if old_state:
context.current_state_ids = {
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
context.state_group = None
context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state_ids:
replaces = context.current_state_ids[key]
if key in context.prev_state_ids:
replaces = context.prev_state_ids[key]
if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)
if event.is_state():
ret = yield self.resolve_state_groups(
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
event_type=event.type,
state_key=event.state_key,
)
else:
ret = yield self.resolve_state_groups(
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
)
group, curr_state = ret
curr_state = entry.state
context.current_state_ids = curr_state
context.state_group = group if not event.is_state() else None
context.prev_state_ids = curr_state
if event.is_state():
context.state_group = self.store.get_next_state_group()
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state_ids:
replaces = context.current_state_ids[key]
if key in context.prev_state_ids:
replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = []
defer.returnValue(context)
@ -231,70 +280,75 @@ class StateHandler(object):
if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop()
defer.returnValue((name, state_list,))
defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
))
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
cache.ts = self.clock.time_msec()
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
defer.returnValue(cache)
defer.returnValue(
(cache.state_group, cache.state,)
logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
state = {}
for st in state_groups_ids.values():
for key, e_id in st.items():
state.setdefault(key, set()).add(e_id)
conflicted_state = {
k: list(v)
for k, v in state.items()
if len(v) > 1
}
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
}
logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
state_group = None
new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items():
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
if state_group is None:
# Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache
# hits.
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
state = {}
for st in state_groups_ids.values():
for key, e_id in st.items():
state.setdefault(key, set()).add(e_id)
conflicted_state = {
k: list(v)
for k, v in state.items()
if len(v) > 1
}
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
}
state_group = None
new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items():
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
if self._state_cache is not None:
cache = _StateCacheEntry(
state=new_state,
state_group=state_group,
ts=self.clock.time_msec()
)
self._state_cache[group_names] = cache
if self._state_cache is not None:
self._state_cache[group_names] = cache
defer.returnValue((state_group, new_state,))
defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(

View file

@ -115,7 +115,7 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")

View file

@ -271,39 +271,28 @@ class EventsStore(SQLBaseStore):
len(events_and_contexts)
)
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
with state_group_id_manager as state_group_ids:
for (event, context), stream, state_group_id in zip(
events_and_contexts, stream_orderings, state_group_ids
):
event.internal_metadata.stream_ordering = stream
# Assign a state group_id in case a new id is needed for
# this context. In theory we only need to assign this
# for contexts that have current_state and aren't outliers
# but that make the code more complicated. Assigning an ID
# per event only causes the state_group_ids to grow as fast
# as the stream_ordering so in practise shouldn't be a problem.
context.new_state_group_id = state_group_id
for (event, context), stream, in zip(
events_and_contexts, stream_orderings
):
event.internal_metadata.stream_ordering = stream
chunks = [
events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
chunks = [
events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc_by(len(chunk))
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks
@ -312,19 +301,17 @@ class EventsStore(SQLBaseStore):
delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
event.internal_metadata.stream_ordering = stream_ordering
context.new_state_group_id = state_group_id
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc()
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc()
except _RollbackButIsFineException:
pass
@ -528,7 +515,7 @@ class EventsStore(SQLBaseStore):
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
state_group_id = context.state_group
self._simple_insert_txn(
txn,
table="ex_outlier_stream",

View file

@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
@cachedInlineCallbacks(num_args=3, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.

View file

@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
desc="who_forgot"
)
def get_joined_users_from_context(self, room_id, state_group, state_ids):
def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_ids
event.room_id, state_group, context.current_state_ids, event=event,
)
def get_joined_users_from_state(self, room_id, state_group, state_ids):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_ids,
)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context):
cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
desc="_get_joined_users_from_context",
)
defer.returnValue(set(row["user_id"] for row in rows))
users_in_room = set(row["user_id"] for row in rows)
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
users_in_room.add(event.state_key)
defer.returnValue(users_in_room)
def is_host_joined(self, room_id, host, state_group, state_ids):
if not state_group:

View file

@ -0,0 +1,32 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
import logging
logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
cur.execute("TRUNCATE sent_transactions")
else:
cur.execute("DELETE FROM sent_transactions")
cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
def run_upgrade(cur, database_engine, *args, **kwargs):
pass

View file

@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.items()
})
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
@ -92,22 +100,19 @@ class StateStore(SQLBaseStore):
if context.current_state_ids is None:
continue
if context.state_group is not None:
state_groups[event.event_id] = context.state_group
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
logger.info("Already persisted state_group: %r", context.state_group)
continue
state_event_ids = dict(context.current_state_ids)
if event.is_state():
state_event_ids[(event.type, event.state_key)] = event.event_id
state_group = context.new_state_group_id
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": state_group,
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
@ -118,7 +123,7 @@ class StateStore(SQLBaseStore):
table="state_groups_state",
values=[
{
"state_group": state_group,
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
@ -127,7 +132,6 @@ class StateStore(SQLBaseStore):
for key, state_id in state_event_ids.items()
],
)
state_groups[event.event_id] = state_group
self._simple_insert_many_txn(
txn,
@ -527,5 +531,5 @@ class StateStore(SQLBaseStore):
"get_all_new_state_groups", get_all_new_state_groups_txn
)
def get_state_stream_token(self):
return self._state_groups_id_gen.get_current_token()
def get_next_state_group(self):
return self._state_groups_id_gen.get_next()

View file

@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions(self):
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
six_hours_ago = now - 6 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)

View file

@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
),
], any_order=True)
def test_online_to_online_last_active_noop(self):
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
currently_active=True,
)
new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=now,
)
state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
)
self.assertFalse(persist_and_notify)
self.assertTrue(federation_ping)
self.assertTrue(state.currently_active)
self.assertEquals(new_state.state, state.state)
self.assertEquals(new_state.status_msg, state.status_msg)
self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([
call(
now=now,
obj=user_id,
then=new_state.last_active_ts + IDLE_TIMER
),
call(
now=now,
obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
),
call(
now=now,
obj=user_id,
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
),
], any_order=True)
def test_online_to_online_last_active(self):
wheel_timer = Mock()
user_id = "@foo:bar"

View file

@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
else:
state_ids = None
context = EventContext(current_state_ids=state_ids)
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context.push_actions = push_actions
ordering = None

View file

@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body, {})
@defer.inlineCallbacks
def test_events_and_state(self):
get = self.get(events="-1", state="-1", timeout="0")
def test_events(self):
get = self.get(events="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
synapse.types.create_requester(self.user), {}
)
@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body["events"]["field_names"], [
"position", "internal", "json", "state_group"
])
self.assertEquals(body["state_groups"]["field_names"], [
"position", "room_id", "event_id"
])
self.assertEquals(body["state_group_state"]["field_names"], [
"position", "type", "state_key", "event_id"
])
@defer.inlineCallbacks
def test_presence(self):

View file

@ -86,17 +86,8 @@ class StateGroupStore(object):
state_events = dict(context.current_state_ids)
if event.is_state():
state_events[(event.type, event.state_key)] = event.event_id
state_group = context.state_group
if not state_group:
state_group = self._next_group
self._next_group += 1
self._group_to_state[state_group] = state_events
self._event_to_state_group[event.event_id] = state_group
self._group_to_state[context.state_group] = state_events
self._event_to_state_group[event.event_id] = context.state_group
def get_events(self, event_ids, **kwargs):
return {
@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase):
"get_state_groups_ids",
"add_event_hashes",
"get_events",
"get_next_state_group",
]
)
hs = Mock(spec_set=[
@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
self.store.get_next_state_group.side_effect = Mock
self.state = StateHandler(hs)
self.event_id = 0
@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state_ids))
self.assertEqual(2, len(context_store["D"].prev_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "C"},
{e_id for e_id in context_store["D"].current_state_ids.values()}
{e_id for e_id in context_store["D"].prev_state_ids.values()}
)
@defer.inlineCallbacks
@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "B", "C"},
{e for e in context_store["E"].current_state_ids.values()}
{e for e in context_store["E"].prev_state_ids.values()}
)
@defer.inlineCallbacks
@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
{e for e in context_store["D"].current_state_ids.values()}
{e for e in context_store["D"].prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase):
set(e.event_id for e in old_state), set(context.current_state_ids.values())
)
self.assertIsNone(context.state_group)
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase):
)
self.assertEqual(
set(e.event_id for e in old_state), set(context.current_state_ids.values())
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
)
self.assertIsNone(context.state_group)
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
event = create_event(type="test_message", name="event")
@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(
set([e.event_id for e in old_state]),
set(context.current_state_ids.values())
set(context.prev_state_ids.values())
)
self.assertIsNone(context.state_group)
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group)
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group)
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_standard_depth_conflict(self):