0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 19:53:51 +01:00

Merge branch 'release-v0.18.1' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2016-10-05 14:41:07 +01:00
commit e779ee0ee2
24 changed files with 947 additions and 616 deletions

View file

@ -1,3 +1,32 @@
Changes in synapse v0.18.1 (2016-10-0)
======================================
No changes since v0.18.1-rc1
Changes in synapse v0.18.1-rc1 (2016-09-30)
===========================================
Features:
* Add total_room_count_estimate to ``/publicRooms`` (PR #1133)
Changes:
* Time out typing over federation (PR #1140)
* Restructure LDAP authentication (PR #1153)
Bug fixes:
* Fix 3pid invites when server is already in the room (PR #1136)
* Fix upgrading with SQLite taking lots of CPU for a few days
after upgrade (PR #1144)
* Fix upgrading from very old database versions (PR #1145)
* Fix port script to work with recently added tables (PR #1146)
Changes in synapse v0.18.0 (2016-09-19)
=======================================
@ -6,7 +35,7 @@ significantly reduce database size. Synapse will attempt to upgrade the current
data in the background. Servers with large SQLite database may experience
degradation of performance while this upgrade is in progress, therefore you may
want to consider migrating to using Postgres before upgrading very large SQLite
daabases
databases
Changes:

View file

@ -39,6 +39,7 @@ BOOLEAN_COLUMNS = {
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],
"public_room_list_stream": ["visibility"],
}
@ -71,6 +72,14 @@ APPEND_ONLY_TABLES = [
"event_to_state_groups",
"rejections",
"event_search",
"presence_stream",
"push_rules_stream",
"current_state_resets",
"ex_outlier_stream",
"cache_invalidation_stream",
"public_room_list_stream",
"state_group_edges",
"stream_ordering_to_exterm",
]

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.18.0"
__version__ = "0.18.1"

View file

@ -72,7 +72,7 @@ class Auth(object):
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
self.check(event, auth_events=auth_events, do_sig_check=False)
self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
@ -91,11 +91,28 @@ class Auth(object):
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
sender_domain = get_domain_from_id(event.sender)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
# Check the sender's domain has signed the event
if do_sig_check and not event.signatures.get(sender_domain):
raise AuthError(403, "Event not signed by sending server")
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
@ -491,6 +508,9 @@ class Auth(object):
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False

View file

@ -27,6 +27,8 @@ from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@ -37,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
@ -74,6 +77,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore,
SlavedPresenceStore,
SlavedDeviceInboxStore,
RoomStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
@ -296,6 +300,8 @@ class SynchrotronServer(HomeServer):
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
RoomInitialSyncRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,

View file

@ -136,9 +136,7 @@ class FederationClient(FederationBase):
sent_edus_counter.inc()
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu, key=key)
return defer.succeed(None)
@log_function
def send_device_messages(self, destination):

View file

@ -31,6 +31,7 @@ import simplejson
try:
import ldap3
import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
@ -504,6 +505,144 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
def _ldap_simple_bind(self, server, localpart, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.
Returns True, LDAP3Connection
if the bind was successful
Returns False, None
if an error occured
"""
try:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established LDAP connection in simple bind mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
conn
)
if conn.bind():
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
return True, conn
# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
localpart, conn.result['description']
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _ldap_authenticated_search(self, server, localpart, password):
""" Attempt to login with the preconfigured bind_dn
and then continue searching and filtering within
the base_dn
Returns (True, LDAP3Connection)
if a single matching DN within the base was found
that matched the filter expression, and with which
a successful bind was achieved
The LDAP3Connection returned is the instance that was used to
verify the password not the one using the configured bind_dn.
Returns (False, None)
if an error occured
"""
try:
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established LDAP connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in search mode through StartTLS: %s",
conn
)
if not conn.bind():
logger.warn(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
conn.unbind()
return False, None
# construct search_filter like (uid=localpart)
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
# combine with the AND expression
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug(
"LDAP search filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query
)
if len(conn.response) == 1:
# GOOD: found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('LDAP search found dn: %s', user_dn)
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
conn.unbind()
return self._ldap_simple_bind(server, localpart, password)
else:
# BAD: found 0 or > 1 results, abort!
if len(conn.response) == 0:
logger.info(
"LDAP search returned no results for '%s'",
localpart
)
else:
logger.info(
"LDAP search returned too many (%s) results for '%s'",
len(conn.response), localpart
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
@ -516,106 +655,62 @@ class AuthHandler(BaseHandler):
if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
if self.ldap_mode not in LDAPMode.LIST:
raise RuntimeError(
'Invalid ldap mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting ldap connection with %s",
"Attempting LDAP connection with %s",
self.ldap_uri
)
localpart = UserID.from_string(user_id).localpart
if self.ldap_mode == LDAPMode.SIMPLE:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established ldap connection in simple mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in simple mode through StartTLS: %s",
conn
)
conn.bind()
elif self.ldap_mode == LDAPMode.SEARCH:
# connect with preconfigured credentials and search for local user
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
result, conn = self._ldap_simple_bind(
server=server, localpart=localpart, password=password
)
logger.debug(
"Established ldap connection in search mode: %s",
'LDAP authentication method simple bind returned: %s (conn: %s)',
result,
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in search mode through StartTLS: %s",
conn
)
conn.bind()
# find matching dn
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug("ldap search filter: %s", query)
result = conn.search(self.ldap_base, query)
if result and len(conn.response) == 1:
# found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('ldap search found dn: %s', user_dn)
# unbind and reconnect, rebind with found dn
conn.unbind()
conn = ldap3.Connection(
server,
user_dn,
password,
auto_bind=True
)
else:
# found 0 or > 1 results, abort!
logger.warn(
"ldap search returned unexpected (%d!=1) amount of results",
len(conn.response)
)
if not result:
defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
result, conn = self._ldap_authenticated_search(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP auth method authenticated search returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
else:
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
logger.info(
"User authenticated against ldap server: %s",
conn
)
try:
logger.info(
"User authenticated against LDAP server: %s",
conn
)
except NameError:
logger.warn("Authentication method yielded no LDAP connection, aborting!")
defer.returnValue(False)
# check for existing account, if none exists, create one
if not (yield self.check_user_exists(user_id)):
# query user metadata for account creation
# check if user with user_id exists
if (yield self.check_user_exists(user_id)):
# exists, authentication complete
conn.unbind()
defer.returnValue(True)
else:
# does not exist, fetch metadata for account creation from
# existing ldap connection
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
@ -626,9 +721,12 @@ class AuthHandler(BaseHandler):
filter=query,
user_filter=self.ldap_filter
)
logger.debug("ldap registration filter: %s", query)
logger.debug(
"ldap registration filter: %s",
query
)
result = conn.search(
conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
@ -651,20 +749,27 @@ class AuthHandler(BaseHandler):
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"ldap registration successful: %d: %s (%s, %)",
"Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
defer.returnValue(True)
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(conn.response)
)
if len(conn.response) == 0:
logger.warn("LDAP registration failed, no result.")
else:
logger.warn(
"LDAP registration failed, too many results (%s)",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(True)
defer.returnValue(False)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)

View file

@ -1922,15 +1922,18 @@ class FederationHandler(BaseHandler):
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
)
if not original_invite:
if original_invite:
display_name = original_invite.content["display_name"]
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(
"Could not find invite event for third_party_invite - "
"discarding: %s" % (event_dict,)
"Could not find invite event for third_party_invite: %r",
event_dict
)
return
# We don't discard here as this is not the appropriate place to do
# auth checks. If we need the invite and don't have it then the
# auth check code will explode appropriately.
display_name = original_invite.content["display_name"]
event_dict["content"]["third_party_invite"]["display_name"] = display_name
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler

View file

@ -0,0 +1,443 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, StreamToken,
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler):
def __init__(self, hs):
super(InitialSyncHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
key = (
user_id,
pagin_config.from_token,
pagin_config.to_token,
pagin_config.direction,
pagin_config.limit,
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
),
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = {
"rooms": rooms_ret,
"presence": presence,
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
}
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
)
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
[
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_peeking:
ret["membership"] = membership
defer.returnValue(ret)
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
return
except AuthError:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)

View file

@ -21,14 +21,11 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
UserID, RoomAlias, RoomStreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor, ReadWriteLock
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@ -49,7 +46,6 @@ class MessageHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
self.pagination_lock = ReadWriteLock()
@ -392,377 +388,6 @@ class MessageHandler(BaseHandler):
[serialize_event(c, now) for c in room_state.values()]
)
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
key = (
user_id,
pagin_config.from_token,
pagin_config.to_token,
pagin_config.direction,
pagin_config.limit,
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
),
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = {
"rooms": rooms_ret,
"presence": presence,
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
}
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
)
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
[
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_peeking:
ret["membership"] = membership
defer.returnValue(ret)
@measure_func("_create_new_client_event")
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):

View file

@ -125,6 +125,8 @@ class RoomListHandler(BaseHandler):
if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
]
total_room_count = len(rooms_to_scan)
if since_token:
# Filter out rooms we've already returned previously
# `since_token.current_limit` is the index of the last room we
@ -188,6 +190,7 @@ class RoomListHandler(BaseHandler):
results = {
"chunk": chunk,
"total_room_count_estimate": total_room_count,
}
if since_token:

View file

@ -16,10 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
import logging
@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
# How often we expect remote servers to resend us presence.
FEDERATION_TIMEOUT = 60 * 1000
# How often to resend typing across federation.
FEDERATION_PING_INTERVAL = 40 * 1000
class TypingHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@ -44,7 +50,10 @@ class TypingHandler(object):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.hs = hs
self.clock = hs.get_clock()
self.wheel_timer = WheelTimer(bucket_size=5000)
self.federation = hs.get_replication_layer()
@ -53,7 +62,7 @@ class TypingHandler(object):
hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop
self._member_typing_timer = {} # deferreds to manage theabove
self._member_last_federation_poke = {}
# map room IDs to serial numbers
self._room_serials = {}
@ -61,12 +70,41 @@ class TypingHandler(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
def tearDown(self):
"""Cancels all the pending timers.
Normally this shouldn't be needed, but it's required from unit tests
to avoid a "Reactor was unclean" warning."""
for t in self._member_typing_timer.values():
self.clock.cancel_call_later(t)
self.clock.looping_call(
self._handle_timeouts,
5000,
)
def _handle_timeouts(self):
logger.info("Checking for typing timeouts")
now = self.clock.time_msec()
members = set(self.wheel_timer.fetch(now))
for member in members:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
continue
until = self._member_typing_until.get(member, None)
if not until or until < now:
logger.info("Timing out typing for: %s", member.user_id)
preserve_fn(self._stopped_typing)(member)
continue
# Check if we need to resend a keep alive over federation for this
# user.
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
preserve_fn(self._push_remote)(
member=member,
typing=True
)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
@ -85,23 +123,17 @@ class TypingHandler(object):
"%s has started typing in %s", target_user_id, room_id
)
until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until
was_present = member.user_id in self._room_typing.get(room_id, set())
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
now = self.clock.time_msec()
self._member_typing_until[member] = now + timeout
def _cb():
logger.debug(
"%s has timed out in %s", target_user.to_string(), room_id
)
self._stopped_typing(member)
self._member_typing_until[member] = until
self._member_typing_timer[member] = self.clock.call_later(
timeout / 1000.0, _cb
self.wheel_timer.insert(
now=now,
obj=member,
then=now + timeout,
)
if was_present:
@ -109,8 +141,7 @@ class TypingHandler(object):
defer.returnValue(None)
yield self._push_update(
room_id=room_id,
user_id=target_user_id,
member=member,
typing=True,
)
@ -133,10 +164,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
del self._member_typing_timer[member]
yield self._stopped_typing(member)
@defer.inlineCallbacks
@ -148,57 +175,61 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _stopped_typing(self, member):
if member not in self._member_typing_until:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
defer.returnValue(None)
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
yield self._push_update(
room_id=member.room_id,
user_id=member.user_id,
member=member,
typing=False,
)
del self._member_typing_until[member]
@defer.inlineCallbacks
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
yield self._push_remote(member, typing)
if member in self._member_typing_timer:
# Don't cancel it - either it already expired, or the real
# stopped_typing() will cancel it
del self._member_typing_timer[member]
self._push_update_local(
member=member,
typing=typing
)
@defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing):
users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
def _push_remote(self, member, typing):
users = yield self.state.get_current_user_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
deferreds = []
for domain in domains:
if domain == self.server_name:
preserve_fn(self._push_update_local)(
room_id=room_id,
user_id=user_id,
typing=typing
)
else:
deferreds.append(preserve_fn(self.federation.send_edu)(
now = self.clock.time_msec()
self.wheel_timer.insert(
now=now,
obj=member,
then=now + FEDERATION_PING_INTERVAL,
)
for domain in set(get_domain_from_id(u) for u in users):
if domain != self.server_name:
self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user_id,
"room_id": member.room_id,
"user_id": member.user_id,
"typing": typing,
},
key=(room_id, user_id),
))
yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
key=member,
)
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
room_id = content["room_id"]
user_id = content["user_id"]
member = RoomMember(user_id=user_id, room_id=room_id)
# Check that the string is a valid user id
user = UserID.from_string(user_id)
@ -213,26 +244,32 @@ class TypingHandler(object):
domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains:
logger.info("Got typing update from %s: %r", user_id, content)
now = self.clock.time_msec()
self._member_typing_until[member] = now + FEDERATION_TIMEOUT
self.wheel_timer.insert(
now=now,
obj=member,
then=now + FEDERATION_TIMEOUT,
)
self._push_update_local(
room_id=room_id,
user_id=user_id,
member=member,
typing=content["typing"]
)
def _push_update_local(self, room_id, user_id, typing):
room_set = self._room_typing.setdefault(room_id, set())
def _push_update_local(self, member, typing):
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(user_id)
room_set.add(member.user_id)
else:
room_set.discard(user_id)
room_set.discard(member.user_id)
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
self._room_serials[member.room_id] = self._latest_room_serial
with PreserveLoggingContext():
self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id]
)
self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state.

View file

@ -25,16 +25,15 @@ class InitialSyncRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.initial_sync_handler = hs.get_initial_sync_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms(
content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,

View file

@ -456,13 +456,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.initial_sync_handler = hs.get_initial_sync_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
content = yield self.initial_sync_handler.room_initial_sync(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@ -705,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet):
yield self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]:
yield self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=content.get("timeout", 30000),
timeout=timeout,
)
else:
yield self.typing_handler.stopped_typing(

View file

@ -43,6 +43,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@ -98,6 +99,7 @@ class HomeServer(object):
'e2e_keys_handler',
'event_handler',
'event_stream_handler',
'initial_sync_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@ -228,6 +230,9 @@ class HomeServer(object):
def build_event_stream_handler(self):
return EventStreamHandler(self)
def build_initial_sync_handler(self):
return InitialSyncHandler(self)
def build_event_sources(self):
return EventSources(self)

View file

@ -398,12 +398,11 @@ class EventFederationStore(SQLBaseStore):
sql = ("""
DELETE FROM stream_ordering_to_exterm
WHERE
(
SELECT max(stream_ordering) AS stream_ordering
room_id IN (
SELECT room_id
FROM stream_ordering_to_exterm
WHERE room_id = stream_ordering_to_exterm.room_id
) > ?
AND stream_ordering < ?
WHERE stream_ordering > ?
) AND stream_ordering < ?
""")
txn.execute(
sql,

View file

@ -1355,39 +1355,53 @@ class EventsStore(SQLBaseStore):
min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids)
rows_to_update = []
rows = []
for event in events:
try:
event_id = event.event_id
origin_server_ts = event.origin_server_ts
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
chunks = [
event_ids[i:i + 100]
for i in xrange(0, len(event_ids), 100)
]
for chunk in chunks:
ev_rows = self._simple_select_many_txn(
txn,
table="event_json",
column="event_id",
iterable=chunk,
retcols=["event_id", "json"],
keyvalues={},
)
rows.append((origin_server_ts, event_id))
for row in ev_rows:
event_id = row["event_id"]
event_json = json.loads(row["json"])
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
rows_to_update.append((origin_server_ts, event_id))
sql = (
"UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
)
for index in range(0, len(rows), INSERT_CLUMP_SIZE):
clump = rows[index:index + INSERT_CLUMP_SIZE]
for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows)
"rows_inserted": rows_inserted + len(rows_to_update)
}
self._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows)
return len(rows_to_update)
result = yield self.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 35
SCHEMA_VERSION = 36
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -0,0 +1,26 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted
INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id)
SELECT
(SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering,
room_id,
event_id
FROM event_forward_extremities AS e
WHERE NOT EXISTS (
SELECT room_id FROM stream_ordering_to_exterm AS s
WHERE s.room_id = e.room_id
);

View file

@ -307,6 +307,9 @@ class StateStore(SQLBaseStore):
def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
results = {group: {} for group in groups}
if types is not None:
types = list(set(types)) # deduplicate types list
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
@ -375,10 +378,35 @@ class StateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
group_tree = [group]
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indicies (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
if types:
args.extend(i for typ in types for i in typ)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? %s" % (where_clause,),
args
)
rows = txn.fetchall()
results[group].update({
(typ, state_key): event_id
for typ, state_key, event_id in rows
if (typ, state_key) not in results[group]
})
# If the lengths match then we must have all the types,
# so no need to go walk further down the tree.
if types is not None and len(results[group]) == len(types):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
@ -386,28 +414,6 @@ class StateStore(SQLBaseStore):
retcol="prev_state_group",
allow_none=True,
)
if next_group:
group_tree.append(next_group)
sql = ("""
SELECT type, state_key, event_id FROM state_groups_state
INNER JOIN (
SELECT type, state_key, max(state_group) as state_group
FROM state_groups_state
WHERE state_group IN (%s) %s
GROUP BY type, state_key
) USING (type, state_key, state_group);
""") % (",".join("?" for _ in group_tree), where_clause,)
args = list(group_tree)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
return results

View file

@ -56,7 +56,7 @@ def get_domain_from_id(string):
try:
return string.split(":", 1)[1]
except IndexError:
raise SynapseError(400, "Invalid ID: %r", string)
raise SynapseError(400, "Invalid ID: %r" % (string,))
class DomainSpecificString(

View file

@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000
self.handler._member_typing_timer[member] = (
self.clock.call_later(1002, lambda: 0)
)
self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))
self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
self.assertEquals(self.event_source.get_current_key(), 0)
@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
},
}])
self.clock.advance_time(11)
self.clock.advance_time(16)
self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]),

View file

@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
# Need another user to make notifications actually work
yield self.join(self.room_id, user="@jim:red")
def tearDown(self):
self.hs.get_typing_handler().tearDown()
@defer.inlineCallbacks
def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger(
@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
self.clock.advance_time(31)
self.clock.advance_time(36)
self.assertEquals(self.event_source.get_current_key(), 2)

View file

@ -220,6 +220,7 @@ class MockClock(object):
# list of lists of [absolute_time, callback, expired] in no particular
# order
self.timers = []
self.loopers = []
def time(self):
return self.now
@ -240,7 +241,7 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
pass
self.loopers.append([function, interval / 1000., self.now])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@ -269,6 +270,12 @@ class MockClock(object):
else:
self.timers.append(t)
for looped in self.loopers:
func, interval, last = looped
if last + interval < self.now:
func()
looped[2] = self.now
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)