0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 11:33:51 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/remove_auth

This commit is contained in:
Erik Johnston 2016-10-17 11:08:23 +01:00
commit 816988baaa
80 changed files with 2948 additions and 1378 deletions

8
.gitignore vendored
View file

@ -24,10 +24,10 @@ homeserver*.yaml
.coverage .coverage
htmlcov htmlcov
demo/*.db demo/*/*.db
demo/*.log demo/*/*.log
demo/*.log.* demo/*/*.log.*
demo/*.pid demo/*/*.pid
demo/media_store.* demo/media_store.*
demo/etc demo/etc

View file

@ -1,3 +1,86 @@
Changes in synapse v0.18.1 (2016-10-0)
======================================
No changes since v0.18.1-rc1
Changes in synapse v0.18.1-rc1 (2016-09-30)
===========================================
Features:
* Add total_room_count_estimate to ``/publicRooms`` (PR #1133)
Changes:
* Time out typing over federation (PR #1140)
* Restructure LDAP authentication (PR #1153)
Bug fixes:
* Fix 3pid invites when server is already in the room (PR #1136)
* Fix upgrading with SQLite taking lots of CPU for a few days
after upgrade (PR #1144)
* Fix upgrading from very old database versions (PR #1145)
* Fix port script to work with recently added tables (PR #1146)
Changes in synapse v0.18.0 (2016-09-19)
=======================================
The release includes major changes to the state storage database schemas, which
significantly reduce database size. Synapse will attempt to upgrade the current
data in the background. Servers with large SQLite database may experience
degradation of performance while this upgrade is in progress, therefore you may
want to consider migrating to using Postgres before upgrading very large SQLite
databases
Changes:
* Make public room search case insensitive (PR #1127)
Bug fixes:
* Fix and clean up publicRooms pagination (PR #1129)
Changes in synapse v0.18.0-rc1 (2016-09-16)
===========================================
Features:
* Add ``only=highlight`` on ``/notifications`` (PR #1081)
* Add server param to /publicRooms (PR #1082)
* Allow clients to ask for the whole of a single state event (PR #1094)
* Add is_direct param to /createRoom (PR #1108)
* Add pagination support to publicRooms (PR #1121)
* Add very basic filter API to /publicRooms (PR #1126)
* Add basic direct to device messaging support for E2E (PR #1074, #1084, #1104,
#1111)
Changes:
* Move to storing state_groups_state as deltas, greatly reducing DB size (PR
#1065)
* Reduce amount of state pulled out of the DB during common requests (PR #1069)
* Allow PDF to be rendered from media repo (PR #1071)
* Reindex state_groups_state after pruning (PR #1085)
* Clobber EDUs in send queue (PR #1095)
* Conform better to the CAS protocol specification (PR #1100)
* Limit how often we ask for keys from dead servers (PR #1114)
Bug fixes:
* Fix /notifications API when used with ``from`` param (PR #1080)
* Fix backfill when cannot find an event. (PR #1107)
Changes in synapse v0.17.3 (2016-09-09) Changes in synapse v0.17.3 (2016-09-09)
======================================= =======================================

View file

@ -42,6 +42,7 @@ The current available worker applications are:
* synapse.app.appservice - handles output traffic to Application Services * synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API) * synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository. * synapse.app.media_repository - handles the media repository.
* synapse.app.client_reader - handles client API endpoints like /publicRooms
Each worker configuration file inherits the configuration of the main homeserver Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker, configuration file. You can then override configuration specific to that worker,

View file

@ -20,3 +20,5 @@ export SYNAPSE_CACHE_FACTOR=1
--pusher \ --pusher \
--synchrotron \ --synchrotron \
--federation-reader \ --federation-reader \
--client-reader \
--appservice \

View file

@ -18,7 +18,9 @@
<div class="summarytext">{{ summary_text }}</div> <div class="summarytext">{{ summary_text }}</div>
</td> </td>
<td class="logo"> <td class="logo">
{% if app_name == "Vector" %} {% if app_name == "Riot" %}
<img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% else %} {% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>

View file

@ -39,6 +39,7 @@ BOOLEAN_COLUMNS = {
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],
"presence_stream": ["currently_active"], "presence_stream": ["currently_active"],
"public_room_list_stream": ["visibility"],
} }
@ -71,6 +72,14 @@ APPEND_ONLY_TABLES = [
"event_to_state_groups", "event_to_state_groups",
"rejections", "rejections",
"event_search", "event_search",
"presence_stream",
"push_rules_stream",
"current_state_resets",
"ex_outlier_stream",
"cache_invalidation_stream",
"public_room_list_stream",
"state_group_edges",
"stream_ordering_to_exterm",
] ]

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.17.3" __version__ = "0.18.1"

View file

@ -72,7 +72,7 @@ class Auth(object):
auth_events = { auth_events = {
(e.type, e.state_key): e for e in auth_events.values() (e.type, e.state_key): e for e in auth_events.values()
} }
self.check(event, auth_events=auth_events, do_sig_check=False) self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
def check(self, event, auth_events, do_sig_check=True): def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -91,11 +91,28 @@ class Auth(object):
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event) raise AuthError(500, "Event has no room_id: %s" % event)
sender_domain = get_domain_from_id(event.sender) if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
# Check the sender's domain has signed the event is_invite_via_3pid = (
if do_sig_check and not event.signatures.get(sender_domain): event.type == EventTypes.Member
raise AuthError(403, "Event not signed by sending server") and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None: if auth_events is None:
# Oh, we don't know what the state of the room was, so we # Oh, we don't know what the state of the room was, so we
@ -491,6 +508,9 @@ class Auth(object):
if not invite_event: if not invite_event:
return False return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id: if event.user_id != invite_event.user_id:
return False return False
@ -633,7 +653,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_appservice_user_id(self, request): def _get_appservice_user_id(self, request):
app_service = yield self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
get_access_token_from_request( get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
@ -835,13 +855,12 @@ class Auth(object):
} }
defer.returnValue(user_info) defer.returnValue(user_info)
@defer.inlineCallbacks
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: try:
token = get_access_token_from_request( token = get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS request, self.TOKEN_NOT_FOUND_HTTP_STATUS
) )
service = yield self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,)) logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError( raise AuthError(
@ -850,7 +869,7 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
request.authenticated_entity = service.sender request.authenticated_entity = service.sender
defer.returnValue(service) return defer.succeed(service)
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
@ -982,16 +1001,6 @@ class Auth(object):
403, 403,
"You are not allowed to set others state" "You are not allowed to set others state"
) )
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True return True

View file

@ -187,6 +187,7 @@ def start(config_options):
def start(): def start():
ps.replicate() ps.replicate()
ps.get_datastore().start_profiling() ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View file

@ -0,0 +1,216 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.client_reader")
class ClientReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
pass
class ClientReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
PublicRoomListRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse client reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse client reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.client_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-client-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View file

@ -182,6 +182,7 @@ def start(config_options):
reactor.run() reactor.run()
def start(): def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate() ss.replicate()

View file

@ -188,6 +188,7 @@ def start(config_options):
reactor.run() reactor.run()
def start(): def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate() ss.replicate()

View file

@ -197,7 +197,7 @@ class PusherServer(HomeServer):
yield start_pusher(user_id, app_id, pushkey) yield start_pusher(user_id, app_id, pushkey)
stream = results.get("events") stream = results.get("events")
if stream: if stream and stream["rows"]:
min_stream_id = stream["rows"][0][0] min_stream_id = stream["rows"][0][0]
max_stream_id = stream["position"] max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_notifications)( preserve_fn(pusher_pool.on_new_notifications)(
@ -205,7 +205,7 @@ class PusherServer(HomeServer):
) )
stream = results.get("receipts") stream = results.get("receipts")
if stream: if stream and stream["rows"]:
rows = stream["rows"] rows = stream["rows"]
affected_room_ids = set(row[1] for row in rows) affected_room_ids = set(row[1] for row in rows)
min_stream_id = rows[0][0] min_stream_id = rows[0][0]
@ -276,6 +276,7 @@ def start(config_options):
ps.replicate() ps.replicate()
ps.get_pusherpool().start() ps.get_pusherpool().start()
ps.get_datastore().start_profiling() ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View file

@ -27,6 +27,8 @@ from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events from synapse.rest.client.v1 import events
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@ -37,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
@ -74,6 +77,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
RoomStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
@ -296,6 +300,8 @@ class SynchrotronServer(HomeServer):
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource) sync.register_servlets(self, resource)
events.register_servlets(self, resource) events.register_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
RoomInitialSyncRestServlet(self).register(resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
@ -465,6 +471,7 @@ def start(config_options):
def start(): def start():
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate() ss.replicate()
ss.get_state_handler().start_caching()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View file

@ -24,7 +24,7 @@ import subprocess
import sys import sys
import yaml import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
RED = "\x1b[1;31m" RED = "\x1b[1;31m"

View file

@ -30,7 +30,7 @@ from .saml2 import SAML2Config
from .cas import CasConfig from .cas import CasConfig
from .password import PasswordConfig from .password import PasswordConfig
from .jwt import JWTConfig from .jwt import JWTConfig
from .ldap import LDAPConfig from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig from .emailconfig import EmailConfig
from .workers import WorkerConfig from .workers import WorkerConfig
@ -39,8 +39,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig, JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig,): WorkerConfig, PasswordAuthProviderConfig,):
pass pass

View file

@ -1,100 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
MISSING_LDAP3 = (
"Missing ldap3 library. This is required for LDAP Authentication."
)
class LDAPMode(object):
SIMPLE = "simple",
SEARCH = "search",
LIST = (SIMPLE, SEARCH)
class LDAPConfig(Config):
def read_config(self, config):
ldap_config = config.get("ldap_config", {})
self.ldap_enabled = ldap_config.get("enabled", False)
if self.ldap_enabled:
# verify dependencies are available
try:
import ldap3
ldap3 # to stop unused lint
except ImportError:
raise ConfigError(MISSING_LDAP3)
self.ldap_mode = LDAPMode.SIMPLE
# verify config sanity
self.require_keys(ldap_config, [
"uri",
"base",
"attributes",
])
self.ldap_uri = ldap_config["uri"]
self.ldap_start_tls = ldap_config.get("start_tls", False)
self.ldap_base = ldap_config["base"]
self.ldap_attributes = ldap_config["attributes"]
if "bind_dn" in ldap_config:
self.ldap_mode = LDAPMode.SEARCH
self.require_keys(ldap_config, [
"bind_dn",
"bind_password",
])
self.ldap_bind_dn = ldap_config["bind_dn"]
self.ldap_bind_password = ldap_config["bind_password"]
self.ldap_filter = ldap_config.get("filter", None)
# verify attribute lookup
self.require_keys(ldap_config['attributes'], [
"uid",
"name",
"mail",
])
def require_keys(self, config, required):
missing = [key for key in required if key not in config]
if missing:
raise ConfigError(
"LDAP enabled but missing required config values: {}".format(
", ".join(missing)
)
)
def default_config(self, **kwargs):
return """\
# ldap_config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
"""

View file

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright 2016 Openmarket
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
import importlib
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
self.ldap_enabled = ldap_config.get("enabled", False)
if self.ldap_enabled:
from synapse.util.ldap_auth_provider import LdapAuthProvider
parsed_config = LdapAuthProvider.parse_config(ldap_config)
self.password_providers.append((LdapAuthProvider, parsed_config))
providers = config.get("password_providers", [])
for provider in providers:
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
module, clz = provider['module'].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
provider_config = provider_class.parse_config(provider["config"])
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
return """\
# password_providers:
# - module: "synapse.util.ldap_auth_provider.LdapAuthProvider"
# config:
# enabled: true
# uri: "ldap://ldap.example.com:389"
# start_tls: true
# base: "ou=users,dc=example,dc=com"
# attributes:
# uid: "cn"
# mail: "email"
# name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
"""

View file

@ -29,7 +29,6 @@ class ServerConfig(Config):
self.user_agent_suffix = config.get("user_agent_suffix") self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl") self.public_baseurl = config.get("public_baseurl")
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/': if self.public_baseurl[-1] != '/':
@ -142,14 +141,6 @@ class ServerConfig(Config):
# The GC threshold parameters to pass to `gc.set_threshold`, if defined # The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10] # gc_thresholds: [700, 10, 10]
# A list of other Home Servers to fetch the public room directory from
# and include in the public room directory of this home server
# This is a temporary stopgap solution to populate new server with a
# list of rooms until there exists a good solution of a decentralized
# room directory.
# secondary_directory_servers:
# - matrix.org
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
listeners: listeners:

View file

@ -19,6 +19,9 @@ from OpenSSL import crypto
import subprocess import subprocess
import os import os
from hashlib import sha256
from unpaddedbase64 import encode_base64
GENERATE_DH_PARAMS = False GENERATE_DH_PARAMS = False
@ -42,6 +45,19 @@ class TlsConfig(Config):
config.get("tls_dh_params_path"), "tls_dh_params" config.get("tls_dh_params_path"), "tls_dh_params"
) )
self.tls_fingerprints = config["tls_fingerprints"]
# Check that our own certificate is included in the list of fingerprints
# and include it if it is not.
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
# This config option applies to non-federation HTTP clients # This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such) # (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for # It should never be used in production, and is intended for
@ -73,6 +89,28 @@ class TlsConfig(Config):
# Don't bind to the https port # Don't bind to the https port
no_tls: False no_tls: False
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
# make HTTPS requests to this server will check that the TLS
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds its the fingerprint of its own certificate
# to the list. So if federation traffic is handle directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it
# will be necessary to add the fingerprints of the certificates used by
# the loadbalancers to this list if they are different to the one
# synapse is using.
#
# Homeservers are permitted to cache the list of TLS fingerprints
# returned in the key responses up to the "valid_until_ts" returned in
# key. It may be necessary to publish the fingerprints of a new
# certificate and wait until the "valid_until_ts" of the previous key
# responses have passed before deploying it.
tls_fingerprints: []
# tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
""" % locals() """ % locals()
def read_tls_certificate(self, cert_path): def read_tls_certificate(self, cert_path):

View file

@ -24,7 +24,6 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -137,9 +136,7 @@ class FederationClient(FederationBase):
sent_edus_counter.inc() sent_edus_counter.inc()
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu, key=key) self._transaction_queue.enqueue_edu(edu, key=key)
return defer.succeed(None)
@log_function @log_function
def send_device_messages(self, destination): def send_device_messages(self, destination):
@ -719,24 +716,14 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks def get_public_rooms(self, destination, limit=None, since_token=None,
def get_public_rooms(self, destinations): search_filter=None):
results_by_server = {} if destination == self.server_name:
return
@defer.inlineCallbacks return self.transport_layer.get_public_rooms(
def _get_result(s): destination, limit, since_token, search_filter
if s == self.server_name: )
defer.returnValue()
try:
result = yield self.transport_layer.get_public_rooms(s)
results_by_server[s] = result
except:
logger.exception("Error getting room list from server %r", s)
yield concurrently_execute(_get_result, destinations, 3)
defer.returnValue(results_by_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):

View file

@ -248,12 +248,22 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_public_rooms(self, remote_server): def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None):
path = PREFIX + "/publicRooms" path = PREFIX + "/publicRooms"
args = {}
if limit:
args["limit"] = [str(limit)]
if since_token:
args["since"] = [since_token]
# TODO(erikj): Actually send the search_filter across federation.
response = yield self.client.get_json( response = yield self.client.get_json(
destination=remote_server, destination=remote_server,
path=path, path=path,
args=args,
) )
defer.returnValue(response) defer.returnValue(response)

View file

@ -18,7 +18,9 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
)
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
@ -554,7 +556,11 @@ class PublicRoomList(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query): def on_GET(self, origin, content, query):
data = yield self.room_list_handler.get_local_public_room_list() limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
data = yield self.room_list_handler.get_local_public_room_list(
limit, since_token
)
defer.returnValue((200, data)) defer.returnValue((200, data))

View file

@ -55,8 +55,14 @@ class BaseHandler(object):
def ratelimit(self, requester): def ratelimit(self, requester):
time_now = self.clock.time() time_now = self.clock.time()
user_id = requester.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service is not None:
return # do not ratelimit app service senders
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
requester.user.to_string(), time_now, user_id, time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count, burst_count=self.hs.config.rc_message_burst_count,
) )

View file

@ -59,7 +59,7 @@ class ApplicationServicesHandler(object):
Args: Args:
current_id(int): The current maximum ID. current_id(int): The current maximum ID.
""" """
services = yield self.store.get_app_services() services = self.store.get_app_services()
if not services or not self.notify_appservices: if not services or not self.notify_appservices:
return return
@ -142,7 +142,7 @@ class ApplicationServicesHandler(object):
association can be found. association can be found.
""" """
room_alias_str = room_alias.to_string() room_alias_str = room_alias.to_string()
services = yield self.store.get_app_services() services = self.store.get_app_services()
alias_query_services = [ alias_query_services = [
s for s in services if ( s for s in services if (
s.is_interested_in_alias(room_alias_str) s.is_interested_in_alias(room_alias_str)
@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_3pe_protocols(self, only_protocol=None): def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services() services = self.store.get_app_services()
protocols = {} protocols = {}
# Collect up all the individual protocol responses out of the ASes # Collect up all the individual protocol responses out of the ASes
@ -224,7 +224,7 @@ class ApplicationServicesHandler(object):
list<ApplicationService>: A list of services interested in this list<ApplicationService>: A list of services interested in this
event based on the service regex. event based on the service regex.
""" """
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if ( s for s in services if (
yield s.is_interested(event, self.store) yield s.is_interested(event, self.store)
@ -232,23 +232,21 @@ class ApplicationServicesHandler(object):
] ]
defer.returnValue(interested_list) defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_user(self, user_id): def _get_services_for_user(self, user_id):
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if ( s for s in services if (
s.is_interested_in_user(user_id) s.is_interested_in_user(user_id)
) )
] ]
defer.returnValue(interested_list) return defer.succeed(interested_list)
@defer.inlineCallbacks
def _get_services_for_3pn(self, protocol): def _get_services_for_3pn(self, protocol):
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_list = [ interested_list = [
s for s in services if s.is_interested_in_protocol(protocol) s for s in services if s.is_interested_in_protocol(protocol)
] ]
defer.returnValue(interested_list) return defer.succeed(interested_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_unknown_user(self, user_id): def _is_unknown_user(self, user_id):
@ -264,7 +262,7 @@ class ApplicationServicesHandler(object):
return return
# user not found; could be the AS though, so check. # user not found; could be the AS though, so check.
services = yield self.store.get_app_services() services = self.store.get_app_services()
service_list = [s for s in services if s.sender == user_id] service_list = [s for s in services if s.sender == user_id]
defer.returnValue(len(service_list) == 0) defer.returnValue(len(service_list) == 0)

View file

@ -20,7 +20,6 @@ from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -29,12 +28,6 @@ import bcrypt
import pymacaroons import pymacaroons
import simplejson import simplejson
try:
import ldap3
except ImportError:
ldap3 = None
pass
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -58,23 +51,15 @@ class AuthHandler(BaseHandler):
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled account_handler = _AccountHandler(
if self.ldap_enabled: hs, check_user_exists=self.check_user_exists
if not ldap3: )
raise RuntimeError(
'Missing ldap3 library. This is required for LDAP Authentication.' self.password_providers = [
) module(config=config, account_handler=account_handler)
self.ldap_mode = hs.config.ldap_mode for module, config in hs.config.password_providers
self.ldap_uri = hs.config.ldap_uri ]
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base
self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@ -148,13 +133,30 @@ class AuthHandler(BaseHandler):
creds = session['creds'] creds = session['creds']
# check auth type currently being presented # check auth type currently being presented
errordict = {}
if 'type' in authdict: if 'type' in authdict:
if authdict['type'] not in self.checkers: login_type = authdict['type']
if login_type not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED) raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip) try:
if result: result = yield self.checkers[login_type](authdict, clientip)
creds[authdict['type']] = result if result:
self._save_session(session) creds[login_type] = result
self._save_session(session)
except LoginError, e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
# got a 401 with a 'flows' field.
# (https://github.com/vector-im/vector-web/issues/2447).
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
raise e
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
@ -163,6 +165,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id'])) defer.returnValue((False, ret, clientdict, session['id']))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -430,37 +433,40 @@ class AuthHandler(BaseHandler):
defer.Deferred: (str) canonical_user_id, or None if zero or defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches multiple matches
""" """
try: res = yield self._find_user_id_and_pwd_hash(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id) if res is not None:
defer.returnValue(res[0]) defer.returnValue(res[0])
except LoginError: defer.returnValue(None)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will throw if there are multiple inexact matches. insensitively, but will return None if there are multiple inexact
matches.
Returns: Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)` tuple: A 2-tuple of `(canonical_user_id, password_hash)`
None: if there is not exactly one match
""" """
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos: if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) elif len(user_infos) == 1:
# a single match (possibly not exact)
if len(user_infos) > 1: result = user_infos.popitem()
if user_id not in user_infos: elif user_id in user_infos:
logger.warn( # multiple matches, but one is exact
"Attempted to login as %s but it matches more than one user " result = (user_id, user_infos[user_id])
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else: else:
defer.returnValue(user_infos.popitem()) # multiple matches, none of them exact
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
@ -474,201 +480,48 @@ class AuthHandler(BaseHandler):
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id
Raises: Raises:
LoginError if the password was incorrect LoginError if login fails
""" """
valid_ldap = yield self._check_ldap_password(user_id, password) for provider in self.password_providers:
if valid_ldap: is_valid = yield provider.check_password(user_id, password)
defer.returnValue(user_id) if is_valid:
defer.returnValue(user_id)
result = yield self._check_local_password(user_id, password) canonical_user_id = yield self._check_local_password(user_id, password)
defer.returnValue(result)
if canonical_user_id:
defer.returnValue(canonical_user_id)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
# into a 401 anyway.
raise LoginError(
403, "Invalid password",
errcode=Codes.FORBIDDEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
user_id is checked case insensitively, but will throw if there are user_id is checked case insensitively, but will return None if there are
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (str): complete @user:id user_id (str): complete @user:id
Returns: Returns:
(str) the canonical_user_id (str) the canonical_user_id, or None if unknown user / bad password
Raises:
LoginError if the password was incorrect
""" """
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash) result = self.validate_hash(password, password_hash)
if not result: if not result:
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(None)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Returns:
True if authentication against LDAP was successful
"""
if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
if self.ldap_mode not in LDAPMode.LIST:
raise RuntimeError(
'Invalid ldap mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting ldap connection with %s",
self.ldap_uri
)
localpart = UserID.from_string(user_id).localpart
if self.ldap_mode == LDAPMode.SIMPLE:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established ldap connection in simple mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in simple mode through StartTLS: %s",
conn
)
conn.bind()
elif self.ldap_mode == LDAPMode.SEARCH:
# connect with preconfigured credentials and search for local user
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established ldap connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in search mode through StartTLS: %s",
conn
)
conn.bind()
# find matching dn
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug("ldap search filter: %s", query)
result = conn.search(self.ldap_base, query)
if result and len(conn.response) == 1:
# found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('ldap search found dn: %s', user_dn)
# unbind and reconnect, rebind with found dn
conn.unbind()
conn = ldap3.Connection(
server,
user_dn,
password,
auto_bind=True
)
else:
# found 0 or > 1 results, abort!
logger.warn(
"ldap search returned unexpected (%d!=1) amount of results",
len(conn.response)
)
defer.returnValue(False)
logger.info(
"User authenticated against ldap server: %s",
conn
)
# check for existing account, if none exists, create one
if not (yield self.check_user_exists(user_id)):
# query user metadata for account creation
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug("ldap registration filter: %s", query)
result = conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
registration_handler = self.hs.get_handlers().registration_handler
user_id, access_token = (
yield registration_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"ldap registration successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(True)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None): def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id) access_token = self.generate_access_token(user_id)
@ -806,3 +659,30 @@ class AuthHandler(BaseHandler):
stored_hash.encode('utf-8')) == stored_hash stored_hash.encode('utf-8')) == stored_hash
else: else:
return False return False
class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, check_user_exists):
self.hs = hs
self._check_user_exists = check_user_exists
def check_user_exists(self, user_id):
"""Check if user exissts.
Returns:
Deferred(bool)
"""
return self._check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)

View file

@ -58,7 +58,7 @@ class DeviceHandler(BaseHandler):
attempts = 0 attempts = 0
while attempts < 5: while attempts < 5:
try: try:
device_id = stringutils.random_string_with_symbols(16) device_id = stringutils.random_string(10).upper()
yield self.store.store_device( yield self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,

View file

@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler):
result = yield as_handler.query_room_alias_exists(room_alias) result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def can_modify_alias(self, alias, user_id=None): def can_modify_alias(self, alias, user_id=None):
# Any application service "interested" in an alias they are regexing on # Any application service "interested" in an alias they are regexing on
# can modify the alias. # can modify the alias.
# Users can only modify the alias if ALL the interested services have # Users can only modify the alias if ALL the interested services have
# non-exclusive locks on the alias (or there are no interested services) # non-exclusive locks on the alias (or there are no interested services)
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string()) s for s in services if s.is_interested_in_alias(alias.to_string())
] ]
@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler):
for service in interested_services: for service in interested_services:
if user_id == service.sender: if user_id == service.sender:
# this user IS the app service so they can do whatever they like # this user IS the app service so they can do whatever they like
defer.returnValue(True) return defer.succeed(True)
return
elif service.is_exclusive_alias(alias.to_string()): elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias. # another service has an exclusive lock on this alias.
defer.returnValue(False) return defer.succeed(False)
return
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
defer.returnValue(True) return defer.succeed(True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id): def _user_can_delete_alias(self, alias, user_id):

View file

@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import ujson as json
import logging import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +31,9 @@ class E2eKeysHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
@ -85,27 +89,37 @@ class E2eKeysHandler(object):
def do_remote_query(destination): def do_remote_query(destination):
destination_query = remote_queries[destination] destination_query = remote_queries[destination]
try: try:
remote_result = yield self.federation.query_client_keys( limiter = yield get_retry_limiter(
destination, destination, self.clock, self.store
{"device_keys": destination_query},
timeout=timeout
) )
with limiter:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items(): for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query: if user_id in destination_query:
results[user_id] = keys results[user_id] = keys
except CodeMessageException as e: except CodeMessageException as e:
failures[destination] = { failures[destination] = {
"status": e.code, "message": e.message "status": e.code, "message": e.message
} }
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
yield preserve_context_over_deferred(defer.gatherResults([ yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination) preserve_fn(do_remote_query)(destination)
for destination in remote_queries for destination in remote_queries
])) ]))
defer.returnValue((200, { defer.returnValue({
"device_keys": results, "failures": failures, "device_keys": results, "failures": failures,
})) })
@defer.inlineCallbacks @defer.inlineCallbacks
def query_local_devices(self, query): def query_local_devices(self, query):
@ -159,3 +173,107 @@ class E2eKeysHandler(object):
device_keys_query = query_body.get("device_keys", {}) device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query) res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res}) defer.returnValue({"device_keys": res})
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue({
"one_time_keys": json_result,
"failures": failures
})
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})

View file

@ -1585,10 +1585,12 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update({ context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
if k != event_key if k != event_key
}) })
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({ context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
}) })
@ -1670,10 +1672,12 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update({ context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
if k != event_key if k != event_key
}) })
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({ context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
}) })
@ -1918,15 +1922,18 @@ class FederationHandler(BaseHandler):
original_invite = yield self.store.get_event( original_invite = yield self.store.get_event(
original_invite_id, allow_none=True original_invite_id, allow_none=True
) )
if not original_invite: if original_invite:
display_name = original_invite.content["display_name"]
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info( logger.info(
"Could not find invite event for third_party_invite - " "Could not find invite event for third_party_invite: %r",
"discarding: %s" % (event_dict,) event_dict
) )
return # We don't discard here as this is not the appropriate place to do
# auth checks. If we need the invite and don't have it then the
# auth check code will explode appropriately.
display_name = original_invite.content["display_name"]
event_dict["content"]["third_party_invite"]["display_name"] = display_name
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler message_handler = self.hs.get_handlers().message_handler

View file

@ -0,0 +1,443 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, StreamToken,
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler):
def __init__(self, hs):
super(InitialSyncHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
key = (
user_id,
pagin_config.from_token,
pagin_config.to_token,
pagin_config.direction,
pagin_config.limit,
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
),
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = {
"rooms": rooms_ret,
"presence": presence,
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
}
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
)
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
[
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_peeking:
ret["membership"] = membership
defer.returnValue(ret)
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
return
except AuthError:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)

View file

@ -21,14 +21,11 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.streams.config import PaginationConfig
from synapse.types import ( from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id UserID, RoomAlias, RoomStreamToken, get_domain_from_id
) )
from synapse.util import unwrapFirstError from synapse.util.async import run_on_reactor, ReadWriteLock
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock from synapse.util.logcontext import preserve_fn
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -49,7 +46,6 @@ class MessageHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
@ -392,377 +388,6 @@ class MessageHandler(BaseHandler):
[serialize_event(c, now) for c in room_state.values()] [serialize_event(c, now) for c in room_state.values()]
) )
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
key = (
user_id,
pagin_config.from_token,
pagin_config.to_token,
pagin_config.direction,
pagin_config.limit,
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
),
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = {
"rooms": rooms_ret,
"presence": presence,
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
}
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
)
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
[
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_peeking:
ret["membership"] = membership
defer.returnValue(ret)
@measure_func("_create_new_client_event") @measure_func("_create_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None): def _create_new_client_event(self, builder, prev_event_ids=None):

View file

@ -217,7 +217,7 @@ class PresenceHandler(object):
is some spurious presence changes that will self-correct. is some spurious presence changes that will self-correct.
""" """
logger.info( logger.info(
"Performing _on_shutdown. Persiting %d unpersisted changes", "Performing _on_shutdown. Persisting %d unpersisted changes",
len(self.user_to_current_state) len(self.user_to_current_state)
) )
@ -234,7 +234,7 @@ class PresenceHandler(object):
may stack up and slow down shutdown times. may stack up and slow down shutdown times.
""" """
logger.info( logger.info(
"Performing _persist_unpersisted_changes. Persiting %d unpersisted changes", "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
len(self.unpersisted_users_changes) len(self.unpersisted_users_changes)
) )

View file

@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["displayname"]) defer.returnValue(result["displayname"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname): def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
"""target_user is the user whose displayname is to be changed; """target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != requester.user: if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '': if new_displayname == '':
@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["avatar_url"]) defer.returnValue(result["avatar_url"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url): def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != requester.user: if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(

View file

@ -19,7 +19,6 @@ import urllib
from twisted.internet import defer from twisted.internet import defer
import synapse.types
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
@ -194,7 +193,7 @@ class RegistrationHandler(BaseHandler):
def appservice_register(self, user_localpart, as_token): def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname) user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
service = yield self.store.get_app_service_by_token(as_token) service = self.store.get_app_service_by_token(as_token)
if not service: if not service:
raise AuthError(403, "Invalid application service token.") raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id): if not service.is_interested_in_user(user_id):
@ -305,11 +304,10 @@ class RegistrationHandler(BaseHandler):
# XXX: This should be a deferred list, shouldn't it? # XXX: This should be a deferred list, shouldn't it?
yield identity_handler.bind_threepid(c, user_id) yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
# valid user IDs must not clash with any user ID namespaces claimed by # valid user IDs must not clash with any user ID namespaces claimed by
# application services. # application services.
services = yield self.store.get_app_services() services = self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services s for s in services
if s.is_interested_in_user(user_id) if s.is_interested_in_user(user_id)
@ -371,7 +369,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_in_ms, def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
password_hash=None): password_hash=None):
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -418,9 +416,8 @@ class RegistrationHandler(BaseHandler):
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname( yield profile_handler.set_displayname(
user, requester, displayname user, requester, displayname, by_admin=True,
) )
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))

View file

@ -20,12 +20,10 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, JoinRules, RoomCreationPreset, Membership, EventTypes, JoinRules, RoomCreationPreset
) )
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from collections import OrderedDict from collections import OrderedDict
@ -36,8 +34,6 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
id_server_scheme = "https://" id_server_scheme = "https://"
@ -348,159 +344,6 @@ class RoomCreationHandler(BaseHandler):
) )
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs)
self.remote_list_request_cache = ResponseCache(hs)
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
)
self.fetch_all_remote_lists()
def get_local_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
return result
@defer.inlineCallbacks
def _get_public_room_list(self):
room_ids = yield self.store.get_public_room_ids()
results = []
@defer.inlineCallbacks
def handle_room(room_id):
current_state = yield self.state_handler.get_current_state(room_id)
# Double check that this is actually a public room.
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
result = {"room_id": room_id}
num_joined_users = len([
1 for _, event in current_state.items()
if event.type == EventTypes.Member
and event.membership == Membership.JOIN
])
if num_joined_users == 0:
return
result["num_joined_members"] = num_joined_users
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
results.append(result)
yield concurrently_execute(handle_room, room_ids, 10)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": results})
@defer.inlineCallbacks
def fetch_all_remote_lists(self):
deferred = self.hs.get_replication_layer().get_public_rooms(
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name):
res = yield self.hs.get_replication_layer().get_public_rooms(
[server_name]
)
if server_name not in res:
raise SynapseError(404, "Server not found")
defer.returnValue(res[server_name])
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
specified in the secondary_directory_servers config option.
XXX: Pagination...
"""
# We return the results from out cache which is updated by a looping call,
# unless we're missing a cache entry, in which case wait for the result
# of the fetch if there's one in progress. If not, omit that server.
wait = False
for s in self.hs.config.secondary_directory_servers:
if s not in self.remote_list_cache:
logger.warn("No cached room list from %s: waiting for fetch", s)
wait = True
break
if wait and self.remote_list_request_cache.get(()):
yield self.remote_list_request_cache.get(())
public_rooms = yield self.get_local_public_room_list()
# keep track of which room IDs we've seen so we can de-dup
room_ids = set()
# tag all the ones in our list with our server name.
# Also add the them to the de-deping set
for room in public_rooms['chunk']:
room["server_name"] = self.hs.hostname
room_ids.add(room["room_id"])
# Now add the results from federation
for server_name, server_result in self.remote_list_cache.items():
for room in server_result["chunk"]:
if room["room_id"] not in room_ids:
room["server_name"] = server_name
public_rooms["chunk"].append(room)
room_ids.add(room["room_id"])
defer.returnValue(public_rooms)
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, is_guest): def get_event_context(self, user, room_id, event_id, limit, is_guest):
@ -594,7 +437,7 @@ class RoomEventSource(object):
logger.warn("Stream has topological part!!!! %r", from_key) logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,) from_key = "s%s" % (from_token.stream,)
app_service = yield self.store.get_app_service_by_user_id( app_service = self.store.get_app_service_by_user_id(
user.to_string() user.to_string()
) )
if app_service: if app_service:

View file

@ -0,0 +1,403 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from collections import namedtuple
from unpaddedbase64 import encode_base64, decode_base64
import logging
import msgpack
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs)
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We explicitly don't bother caching searches.
return self._get_public_room_list(limit, since_token, search_filter)
result = self.response_cache.get((limit, since_token))
if not result:
result = self.response_cache.set(
(limit, since_token),
self._get_public_room_list(limit, since_token)
)
return result
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
since_token = None
rooms_to_order_value = {}
rooms_to_num_joined = {}
rooms_to_latest_event_ids = {}
newly_visible = []
newly_unpublished = []
if since_token:
stream_token = since_token.stream_ordering
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id
)
# We want to return rooms in a particular order: the number of joined
# users. We then arbitrarily use the room_id as a tie breaker.
@defer.inlineCallbacks
def get_order_for_room(room_id):
latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
if not latest_event_ids:
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
rooms_to_latest_event_ids[room_id] = latest_event_ids
if not latest_event_ids:
return
joined_users = yield self.state_handler.get_current_user_in_room(
room_id, latest_event_ids,
)
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
if num_joined_users == 0:
return
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]
# `sorted_rooms` should now be a list of all public room ids that is
# stable across pagination. Therefore, we can use indices into this
# list as our pagination tokens.
# Filter out rooms that we don't want to return
rooms_to_scan = [
r for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
]
total_room_count = len(rooms_to_scan)
if since_token:
# Filter out rooms we've already returned previously
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it.
if since_token.direction_is_forward:
rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
else:
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
# Actually generate the entries. _generate_room_entry will append to
# chunk but will stop if len(chunk) > limit
chunk = []
if limit and not search_filter:
step = limit + 1
for i in xrange(0, len(rooms_to_scan), step):
# We iterate here because the vast majority of cases we'll stop
# at first iteration, but occaisonally _generate_room_entry
# won't append to the chunk and so we need to loop again.
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan[i:i + step], 10
)
if len(chunk) >= limit + 1:
break
else:
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan, 5
)
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
# Work out the new limit of the batch for pagination, or None if we
# know there are no more results that would be returned.
# i.e., [since_token.current_limit..new_limit] is the batch of rooms
# we've returned (or the reverse if we paginated backwards)
# We tried to pull out limit + 1 rooms above, so if we have <= limit
# then we know there are no more results to return
new_limit = None
if chunk and (not limit or len(chunk) > limit):
if not since_token or since_token.direction_is_forward:
if limit:
chunk = chunk[:limit]
last_room_id = chunk[-1]["room_id"]
else:
if limit:
chunk = chunk[-limit:]
last_room_id = chunk[0]["room_id"]
new_limit = sorted_rooms.index(last_room_id)
results = {
"chunk": chunk,
"total_room_count_estimate": total_room_count,
}
if since_token:
results["new_rooms"] = bool(newly_visible)
if not since_token or since_token.direction_is_forward:
if new_limit is not None:
results["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=True,
).to_token()
if since_token:
results["prev_batch"] = since_token.copy_and_replace(
direction_is_forward=False,
current_limit=since_token.current_limit + 1,
).to_token()
else:
if new_limit is not None:
results["prev_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=False,
).to_token()
if since_token:
results["next_batch"] = since_token.copy_and_replace(
direction_is_forward=True,
current_limit=since_token.current_limit - 1,
).to_token()
defer.returnValue(results)
@defer.inlineCallbacks
def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
search_filter):
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
result = {
"room_id": room_id,
"num_joined_members": num_joined_users,
}
current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
event_map = yield self.store.get_events([
event_id for key, event_id in current_state_ids.items()
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
EventTypes.Topic,
EventTypes.CanonicalAlias,
EventTypes.RoomHistoryVisibility,
EventTypes.GuestAccess,
"m.room.avatar",
)
])
current_state = {
(ev.type, ev.state_key): ev
for ev in event_map.values()
}
# Double check that this is actually a public room.
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
if _matches_room_entry(result, search_filter):
chunk.append(result)
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We currently don't support searching across federation, so we have
# to do it manually without pagination
limit = None
since_token = None
res = yield self._get_remote_list_cached(
server_name, limit=limit, since_token=since_token,
)
if search_filter:
res = {"chunk": [
entry
for entry in list(res.get("chunk", []))
if _matches_room_entry(entry, search_filter)
]}
defer.returnValue(res)
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None):
repl_layer = self.hs.get_replication_layer()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
result = self.remote_response_cache.get((server_name, limit, since_token))
if not result:
result = self.remote_response_cache.set(
(server_name, limit, since_token),
repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
)
return result
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
"stream_ordering", # stream_ordering of the first public room list
"public_room_stream_id", # public room stream id for first public room list
"current_limit", # The number of previous rooms returned
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
))):
KEY_DICT = {
"stream_ordering": "s",
"public_room_stream_id": "p",
"current_limit": "n",
"direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
def from_token(cls, token):
return RoomListNextBatch(**{
cls.REVERSE_KEY_DICT[key]: val
for key, val in msgpack.loads(decode_base64(token)).items()
})
def to_token(self):
return encode_base64(msgpack.dumps({
self.KEY_DICT[key]: val
for key, val in self._asdict().items()
}))
def copy_and_replace(self, **kwds):
return self._replace(
**kwds
)
def _matches_room_entry(room_entry, search_filter):
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
return True
elif generic_search_term in room_entry.get("topic", "").upper():
return True
elif generic_search_term in room_entry.get("canonical_alias", "").upper():
return True
else:
return True
return False

View file

@ -788,7 +788,7 @@ class SyncHandler(object):
assert since_token assert since_token
app_service = yield self.store.get_app_service_by_user_id(user_id) app_service = self.store.get_app_service_by_user_id(user_id)
if app_service: if app_service:
rooms = yield self.store.get_app_service_rooms(app_service) rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms) joined_room_ids = set(r.room_id for r in rooms)

View file

@ -16,10 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import ( from synapse.util.logcontext import preserve_fn
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
import logging import logging
@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
# How often we expect remote servers to resend us presence.
FEDERATION_TIMEOUT = 60 * 1000
# How often to resend typing across federation.
FEDERATION_PING_INTERVAL = 40 * 1000
class TypingHandler(object): class TypingHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -44,7 +50,10 @@ class TypingHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.wheel_timer = WheelTimer(bucket_size=5000)
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
@ -53,7 +62,7 @@ class TypingHandler(object):
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop self._member_typing_until = {} # clock time we expect to stop
self._member_typing_timer = {} # deferreds to manage theabove self._member_last_federation_poke = {}
# map room IDs to serial numbers # map room IDs to serial numbers
self._room_serials = {} self._room_serials = {}
@ -61,12 +70,41 @@ class TypingHandler(object):
# map room IDs to sets of users currently typing # map room IDs to sets of users currently typing
self._room_typing = {} self._room_typing = {}
def tearDown(self): self.clock.looping_call(
"""Cancels all the pending timers. self._handle_timeouts,
Normally this shouldn't be needed, but it's required from unit tests 5000,
to avoid a "Reactor was unclean" warning.""" )
for t in self._member_typing_timer.values():
self.clock.cancel_call_later(t) def _handle_timeouts(self):
logger.info("Checking for typing timeouts")
now = self.clock.time_msec()
members = set(self.wheel_timer.fetch(now))
for member in members:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
continue
until = self._member_typing_until.get(member, None)
if not until or until < now:
logger.info("Timing out typing for: %s", member.user_id)
preserve_fn(self._stopped_typing)(member)
continue
# Check if we need to resend a keep alive over federation for this
# user.
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
preserve_fn(self._push_remote)(
member=member,
typing=True
)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
@defer.inlineCallbacks @defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout): def started_typing(self, target_user, auth_user, room_id, timeout):
@ -85,23 +123,17 @@ class TypingHandler(object):
"%s has started typing in %s", target_user_id, room_id "%s has started typing in %s", target_user_id, room_id
) )
until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user_id=target_user_id) member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until was_present = member.user_id in self._room_typing.get(room_id, set())
if member in self._member_typing_timer: now = self.clock.time_msec()
self.clock.cancel_call_later(self._member_typing_timer[member]) self._member_typing_until[member] = now + timeout
def _cb(): self.wheel_timer.insert(
logger.debug( now=now,
"%s has timed out in %s", target_user.to_string(), room_id obj=member,
) then=now + timeout,
self._stopped_typing(member)
self._member_typing_until[member] = until
self._member_typing_timer[member] = self.clock.call_later(
timeout / 1000.0, _cb
) )
if was_present: if was_present:
@ -109,8 +141,7 @@ class TypingHandler(object):
defer.returnValue(None) defer.returnValue(None)
yield self._push_update( yield self._push_update(
room_id=room_id, member=member,
user_id=target_user_id,
typing=True, typing=True,
) )
@ -133,10 +164,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id) member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
del self._member_typing_timer[member]
yield self._stopped_typing(member) yield self._stopped_typing(member)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -148,57 +175,61 @@ class TypingHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _stopped_typing(self, member): def _stopped_typing(self, member):
if member not in self._member_typing_until: if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point # No point
defer.returnValue(None) defer.returnValue(None)
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
yield self._push_update( yield self._push_update(
room_id=member.room_id, member=member,
user_id=member.user_id,
typing=False, typing=False,
) )
del self._member_typing_until[member] @defer.inlineCallbacks
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
yield self._push_remote(member, typing)
if member in self._member_typing_timer: self._push_update_local(
# Don't cancel it - either it already expired, or the real member=member,
# stopped_typing() will cancel it typing=typing
del self._member_typing_timer[member] )
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing): def _push_remote(self, member, typing):
users = yield self.state.get_current_user_in_room(room_id) users = yield self.state.get_current_user_in_room(member.room_id)
domains = set(get_domain_from_id(u) for u in users) self._member_last_federation_poke[member] = self.clock.time_msec()
deferreds = [] now = self.clock.time_msec()
for domain in domains: self.wheel_timer.insert(
if domain == self.server_name: now=now,
preserve_fn(self._push_update_local)( obj=member,
room_id=room_id, then=now + FEDERATION_PING_INTERVAL,
user_id=user_id, )
typing=typing
) for domain in set(get_domain_from_id(u) for u in users):
else: if domain != self.server_name:
deferreds.append(preserve_fn(self.federation.send_edu)( self.federation.send_edu(
destination=domain, destination=domain,
edu_type="m.typing", edu_type="m.typing",
content={ content={
"room_id": room_id, "room_id": member.room_id,
"user_id": user_id, "user_id": member.user_id,
"typing": typing, "typing": typing,
}, },
key=(room_id, user_id), key=member,
)) )
yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _recv_edu(self, origin, content): def _recv_edu(self, origin, content):
room_id = content["room_id"] room_id = content["room_id"]
user_id = content["user_id"] user_id = content["user_id"]
member = RoomMember(user_id=user_id, room_id=room_id)
# Check that the string is a valid user id # Check that the string is a valid user id
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -213,26 +244,32 @@ class TypingHandler(object):
domains = set(get_domain_from_id(u) for u in users) domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains: if self.server_name in domains:
logger.info("Got typing update from %s: %r", user_id, content)
now = self.clock.time_msec()
self._member_typing_until[member] = now + FEDERATION_TIMEOUT
self.wheel_timer.insert(
now=now,
obj=member,
then=now + FEDERATION_TIMEOUT,
)
self._push_update_local( self._push_update_local(
room_id=room_id, member=member,
user_id=user_id,
typing=content["typing"] typing=content["typing"]
) )
def _push_update_local(self, room_id, user_id, typing): def _push_update_local(self, member, typing):
room_set = self._room_typing.setdefault(room_id, set()) room_set = self._room_typing.setdefault(member.room_id, set())
if typing: if typing:
room_set.add(user_id) room_set.add(member.user_id)
else: else:
room_set.discard(user_id) room_set.discard(member.user_id)
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[member.room_id] = self._latest_room_serial
with PreserveLoggingContext(): self.notifier.on_new_event(
self.notifier.on_new_event( "typing_key", self._latest_room_serial, rooms=[member.room_id]
"typing_key", self._latest_room_serial, rooms=[room_id] )
)
def get_all_typing_updates(self, last_id, current_id): def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state. # TODO: Work out a way to do this without scanning the entire state.

View file

@ -41,9 +41,13 @@ def parse_integer(request, name, default=None, required=False):
SynapseError: if the parameter is absent and required, or if the SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer. parameter is present and not an integer.
""" """
if name in request.args: return parse_integer_from_args(request.args, name, default, required)
def parse_integer_from_args(args, name, default=None, required=False):
if name in args:
try: try:
return int(request.args[name][0]) return int(args[name][0])
except: except:
message = "Query parameter %r must be an integer" % (name,) message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message)
@ -116,9 +120,15 @@ def parse_string(request, name, default=None, required=False,
parameter is present, must be one of a list of allowed values and parameter is present, must be one of a list of allowed values and
is not one of those allowed values. is not one of those allowed values.
""" """
return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type,
)
if name in request.args:
value = request.args[name][0] def parse_string_from_args(args, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in args:
value = args[name][0]
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)

View file

@ -263,6 +263,8 @@ BASE_APPEND_UNDERRIDE_RULES = [
} }
] ]
}, },
# XXX: once m.direct is standardised everywhere, we should use it to detect
# a DM from the user's perspective rather than this heuristic.
{ {
'rule_id': 'global/underride/.m.rule.room_one_to_one', 'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [
@ -289,6 +291,34 @@ BASE_APPEND_UNDERRIDE_RULES = [
} }
] ]
}, },
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
{
'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one',
'conditions': [
{
'kind': 'room_member_count',
'is': '2',
'_id': 'member_count',
},
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.encrypted',
'_id': '_encrypted',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
{ {
'rule_id': 'global/underride/.m.rule.message', 'rule_id': 'global/underride/.m.rule.message',
'conditions': [ 'conditions': [
@ -305,6 +335,25 @@ BASE_APPEND_UNDERRIDE_RULES = [
'value': False 'value': False
} }
] ]
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
{
'rule_id': 'global/underride/.m.rule.encrypted',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.encrypted',
'_id': '_encrypted',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': False
}
]
} }
] ]

View file

@ -26,15 +26,6 @@ from synapse.visibility import filter_events_for_clients_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, context): def evaluator_for_event(event, hs, store, context):
rules_by_user = yield store.bulk_get_push_rules_for_room( rules_by_user = yield store.bulk_get_push_rules_for_room(
@ -48,6 +39,7 @@ def evaluator_for_event(event, hs, store, context):
if invited_user and hs.is_mine_id(invited_user): if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user) has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher: if has_pusher:
rules_by_user = dict(rules_by_user)
rules_by_user[invited_user] = yield store.get_push_rules_for_user( rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user invited_user
) )

View file

@ -328,7 +328,7 @@ class Mailer(object):
return messagevars return messagevars
@defer.inlineCallbacks @defer.inlineCallbacks
def make_summary_text(self, notifs_by_room, state_by_room, def make_summary_text(self, notifs_by_room, room_state_ids,
notif_events, user_id, reason): notif_events, user_id, reason):
if len(notifs_by_room) == 1: if len(notifs_by_room) == 1:
# Only one room has new stuff # Only one room has new stuff
@ -338,14 +338,18 @@ class Mailer(object):
# want the generated-from-names one here otherwise we'll # want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room" # end up with, "new message from Bob in the Bob room"
room_name = yield calculate_room_name( room_name = yield calculate_room_name(
self.store, state_by_room[room_id], user_id, fallback_to_members=False self.store, room_state_ids[room_id], user_id, fallback_to_members=False
) )
my_member_event = state_by_room[room_id][("m.room.member", user_id)] my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)]
my_member_event = yield self.store.get_event(my_member_event_id)
if my_member_event.content["membership"] == "invite": if my_member_event.content["membership"] == "invite":
inviter_member_event = state_by_room[room_id][ inviter_member_event_id = room_state_ids[room_id][
("m.room.member", my_member_event.sender) ("m.room.member", my_member_event.sender)
] ]
inviter_member_event = yield self.store.get_event(
inviter_member_event_id
)
inviter_name = name_from_member_event(inviter_member_event) inviter_name = name_from_member_event(inviter_member_event)
if room_name is None: if room_name is None:
@ -364,8 +368,11 @@ class Mailer(object):
if len(notifs_by_room[room_id]) == 1: if len(notifs_by_room[room_id]) == 1:
# There is just the one notification, so give some detail # There is just the one notification, so give some detail
event = notif_events[notifs_by_room[room_id][0]["event_id"]] event = notif_events[notifs_by_room[room_id][0]["event_id"]]
if ("m.room.member", event.sender) in state_by_room[room_id]: if ("m.room.member", event.sender) in room_state_ids[room_id]:
state_event = state_by_room[room_id][("m.room.member", event.sender)] state_event_id = room_state_ids[room_id][
("m.room.member", event.sender)
]
state_event = yield self.get_event(state_event_id)
sender_name = name_from_member_event(state_event) sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None: if sender_name is not None and room_name is not None:
@ -395,11 +402,13 @@ class Mailer(object):
for n in notifs_by_room[room_id] for n in notifs_by_room[room_id]
])) ]))
member_events = yield self.store.get_events([
room_state_ids[room_id][("m.room.member", s)]
for s in sender_ids
])
defer.returnValue(MESSAGES_FROM_PERSON % { defer.returnValue(MESSAGES_FROM_PERSON % {
"person": descriptor_from_member_events([ "person": descriptor_from_member_events(member_events.values()),
state_by_room[room_id][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name, "app": self.app_name,
}) })
else: else:
@ -419,11 +428,13 @@ class Mailer(object):
for n in notifs_by_room[reason['room_id']] for n in notifs_by_room[reason['room_id']]
])) ]))
member_events = yield self.store.get_events([
room_state_ids[room_id][("m.room.member", s)]
for s in sender_ids
])
defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events([ "person": descriptor_from_member_events(member_events.values()),
state_by_room[reason['room_id']][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name, "app": self.app_name,
}) })

View file

@ -36,6 +36,7 @@ REQUIREMENTS = {
"blist": ["blist"], "blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"], "pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {

View file

@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource from synapse.replication.presence_resource import PresenceResource
from synapse.api.errors import SynapseError
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -42,6 +43,7 @@ STREAM_NAMES = (
("pushers",), ("pushers",),
("caches",), ("caches",),
("to_device",), ("to_device",),
("public_rooms",),
) )
@ -131,6 +133,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -144,6 +147,7 @@ class ReplicationResource(Resource):
0, # State stream is no longer a thing 0, # State stream is no longer a thing
caches_token, caches_token,
int(stream_token.to_device_key), int(stream_token.to_device_key),
int(public_rooms_token),
)) ))
@request_handler() @request_handler()
@ -163,7 +167,8 @@ class ReplicationResource(Resource):
def replicate(): def replicate():
return self.replicate(request_streams, limit) return self.replicate(request_streams, limit)
result = yield self.notifier.wait_for_replication(replicate, timeout) writer = yield self.notifier.wait_for_replication(replicate, timeout)
result = writer.finish()
for stream_name, stream_content in result.items(): for stream_name, stream_content in result.items():
logger.info( logger.info(
@ -183,6 +188,9 @@ class ReplicationResource(Resource):
current_token = yield self.current_replication_token() current_token = yield self.current_replication_token()
logger.debug("Replicating up to %r", current_token) logger.debug("Replicating up to %r", current_token)
if limit == 0:
raise SynapseError(400, "Limit cannot be 0")
yield self.account_data(writer, current_token, limit, request_streams) yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams) yield self.events(writer, current_token, limit, request_streams)
# TODO: implement limit # TODO: implement limit
@ -193,10 +201,11 @@ class ReplicationResource(Resource):
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.debug("Replicated %d rows", writer.total) logger.debug("Replicated %d rows", writer.total)
defer.returnValue(writer.finish()) defer.returnValue(writer)
def streams(self, writer, current_token, request_streams): def streams(self, writer, current_token, request_streams):
request_token = request_streams.get("streams") request_token = request_streams.get("streams")
@ -233,27 +242,48 @@ class ReplicationResource(Resource):
request_events = current_token.events request_events = current_token.events
if request_backfill is None: if request_backfill is None:
request_backfill = current_token.backfill request_backfill = current_token.backfill
no_new_tokens = (
request_events == current_token.events
and request_backfill == current_token.backfill
)
if no_new_tokens:
return
res = yield self.store.get_all_new_events( res = yield self.store.get_all_new_events(
request_backfill, request_events, request_backfill, request_events,
current_token.backfill, current_token.events, current_token.backfill, current_token.events,
limit limit
) )
writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group" upto_events_token = _position_from_rows(
)) res.new_forward_events, current_token.events
writer.write_header_and_rows("backfill", res.new_backfill_events, ( )
"position", "internal", "json", "state_group"
)) upto_backfill_token = _position_from_rows(
res.new_backfill_events, current_token.backfill
)
if request_events != upto_events_token:
writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group"
), position=upto_events_token)
if request_backfill != upto_backfill_token:
writer.write_header_and_rows("backfill", res.new_backfill_events, (
"position", "internal", "json", "state_group",
), position=upto_backfill_token)
writer.write_header_and_rows( writer.write_header_and_rows(
"forward_ex_outliers", res.forward_ex_outliers, "forward_ex_outliers", res.forward_ex_outliers,
("position", "event_id", "state_group") ("position", "event_id", "state_group"),
) )
writer.write_header_and_rows( writer.write_header_and_rows(
"backward_ex_outliers", res.backward_ex_outliers, "backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group") ("position", "event_id", "state_group"),
) )
writer.write_header_and_rows( writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",) "state_resets", res.state_resets, ("position",),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -262,15 +292,16 @@ class ReplicationResource(Resource):
request_presence = request_streams.get("presence") request_presence = request_streams.get("presence")
if request_presence is not None: if request_presence is not None and request_presence != current_position:
presence_rows = yield self.presence_handler.get_all_presence_updates( presence_rows = yield self.presence_handler.get_all_presence_updates(
request_presence, current_position request_presence, current_position
) )
upto_token = _position_from_rows(presence_rows, current_position)
writer.write_header_and_rows("presence", presence_rows, ( writer.write_header_and_rows("presence", presence_rows, (
"position", "user_id", "state", "last_active_ts", "position", "user_id", "state", "last_active_ts",
"last_federation_update_ts", "last_user_sync_ts", "last_federation_update_ts", "last_user_sync_ts",
"status_msg", "currently_active", "status_msg", "currently_active",
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def typing(self, writer, current_token, request_streams): def typing(self, writer, current_token, request_streams):
@ -278,7 +309,7 @@ class ReplicationResource(Resource):
request_typing = request_streams.get("typing") request_typing = request_streams.get("typing")
if request_typing is not None: if request_typing is not None and request_typing != current_position:
# If they have a higher token than current max, we can assume that # If they have a higher token than current max, we can assume that
# they had been talking to a previous instance of the master. Since # they had been talking to a previous instance of the master. Since
# we reset the token on restart, the best (but hacky) thing we can # we reset the token on restart, the best (but hacky) thing we can
@ -289,9 +320,10 @@ class ReplicationResource(Resource):
typing_rows = yield self.typing_handler.get_all_typing_updates( typing_rows = yield self.typing_handler.get_all_typing_updates(
request_typing, current_position request_typing, current_position
) )
upto_token = _position_from_rows(typing_rows, current_position)
writer.write_header_and_rows("typing", typing_rows, ( writer.write_header_and_rows("typing", typing_rows, (
"position", "room_id", "typing" "position", "room_id", "typing"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def receipts(self, writer, current_token, limit, request_streams): def receipts(self, writer, current_token, limit, request_streams):
@ -299,13 +331,14 @@ class ReplicationResource(Resource):
request_receipts = request_streams.get("receipts") request_receipts = request_streams.get("receipts")
if request_receipts is not None: if request_receipts is not None and request_receipts != current_position:
receipts_rows = yield self.store.get_all_updated_receipts( receipts_rows = yield self.store.get_all_updated_receipts(
request_receipts, current_position, limit request_receipts, current_position, limit
) )
upto_token = _position_from_rows(receipts_rows, current_position)
writer.write_header_and_rows("receipts", receipts_rows, ( writer.write_header_and_rows("receipts", receipts_rows, (
"position", "room_id", "receipt_type", "user_id", "event_id", "data" "position", "room_id", "receipt_type", "user_id", "event_id", "data"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def account_data(self, writer, current_token, limit, request_streams): def account_data(self, writer, current_token, limit, request_streams):
@ -320,23 +353,36 @@ class ReplicationResource(Resource):
user_account_data = current_position user_account_data = current_position
if room_account_data is None: if room_account_data is None:
room_account_data = current_position room_account_data = current_position
no_new_tokens = (
user_account_data == current_position
and room_account_data == current_position
)
if no_new_tokens:
return
user_rows, room_rows = yield self.store.get_all_updated_account_data( user_rows, room_rows = yield self.store.get_all_updated_account_data(
user_account_data, room_account_data, current_position, limit user_account_data, room_account_data, current_position, limit
) )
upto_users_token = _position_from_rows(user_rows, current_position)
upto_rooms_token = _position_from_rows(room_rows, current_position)
writer.write_header_and_rows("user_account_data", user_rows, ( writer.write_header_and_rows("user_account_data", user_rows, (
"position", "user_id", "type", "content" "position", "user_id", "type", "content"
)) ), position=upto_users_token)
writer.write_header_and_rows("room_account_data", room_rows, ( writer.write_header_and_rows("room_account_data", room_rows, (
"position", "user_id", "room_id", "type", "content" "position", "user_id", "room_id", "type", "content"
)) ), position=upto_rooms_token)
if tag_account_data is not None: if tag_account_data is not None:
tag_rows = yield self.store.get_all_updated_tags( tag_rows = yield self.store.get_all_updated_tags(
tag_account_data, current_position, limit tag_account_data, current_position, limit
) )
upto_tag_token = _position_from_rows(tag_rows, current_position)
writer.write_header_and_rows("tag_account_data", tag_rows, ( writer.write_header_and_rows("tag_account_data", tag_rows, (
"position", "user_id", "room_id", "tags" "position", "user_id", "room_id", "tags"
)) ), position=upto_tag_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_rules(self, writer, current_token, limit, request_streams): def push_rules(self, writer, current_token, limit, request_streams):
@ -344,14 +390,15 @@ class ReplicationResource(Resource):
push_rules = request_streams.get("push_rules") push_rules = request_streams.get("push_rules")
if push_rules is not None: if push_rules is not None and push_rules != current_position:
rows = yield self.store.get_all_push_rule_updates( rows = yield self.store.get_all_push_rule_updates(
push_rules, current_position, limit push_rules, current_position, limit
) )
upto_token = _position_from_rows(rows, current_position)
writer.write_header_and_rows("push_rules", rows, ( writer.write_header_and_rows("push_rules", rows, (
"position", "event_stream_ordering", "user_id", "rule_id", "op", "position", "event_stream_ordering", "user_id", "rule_id", "op",
"priority_class", "priority", "conditions", "actions" "priority_class", "priority", "conditions", "actions"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def pushers(self, writer, current_token, limit, request_streams): def pushers(self, writer, current_token, limit, request_streams):
@ -359,18 +406,19 @@ class ReplicationResource(Resource):
pushers = request_streams.get("pushers") pushers = request_streams.get("pushers")
if pushers is not None: if pushers is not None and pushers != current_position:
updated, deleted = yield self.store.get_all_updated_pushers( updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit pushers, current_position, limit
) )
upto_token = _position_from_rows(updated, current_position)
writer.write_header_and_rows("pushers", updated, ( writer.write_header_and_rows("pushers", updated, (
"position", "user_id", "access_token", "profile_tag", "kind", "position", "user_id", "access_token", "profile_tag", "kind",
"app_id", "app_display_name", "device_display_name", "pushkey", "app_id", "app_display_name", "device_display_name", "pushkey",
"ts", "lang", "data" "ts", "lang", "data"
)) ), position=upto_token)
writer.write_header_and_rows("deleted_pushers", deleted, ( writer.write_header_and_rows("deleted_pushers", deleted, (
"position", "user_id", "app_id", "pushkey" "position", "user_id", "app_id", "pushkey"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams): def caches(self, writer, current_token, limit, request_streams):
@ -378,13 +426,14 @@ class ReplicationResource(Resource):
caches = request_streams.get("caches") caches = request_streams.get("caches")
if caches is not None: if caches is not None and caches != current_position:
updated_caches = yield self.store.get_all_updated_caches( updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit caches, current_position, limit
) )
upto_token = _position_from_rows(updated_caches, current_position)
writer.write_header_and_rows("caches", updated_caches, ( writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts" "position", "cache_func", "keys", "invalidation_ts"
)) ), position=upto_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def to_device(self, writer, current_token, limit, request_streams): def to_device(self, writer, current_token, limit, request_streams):
@ -392,13 +441,29 @@ class ReplicationResource(Resource):
to_device = request_streams.get("to_device") to_device = request_streams.get("to_device")
if to_device is not None: if to_device is not None and to_device != current_position:
to_device_rows = yield self.store.get_all_new_device_messages( to_device_rows = yield self.store.get_all_new_device_messages(
to_device, current_position, limit to_device, current_position, limit
) )
upto_token = _position_from_rows(to_device_rows, current_position)
writer.write_header_and_rows("to_device", to_device_rows, ( writer.write_header_and_rows("to_device", to_device_rows, (
"position", "user_id", "device_id", "message_json" "position", "user_id", "device_id", "message_json"
)) ), position=upto_token)
@defer.inlineCallbacks
def public_rooms(self, writer, current_token, limit, request_streams):
current_position = current_token.public_rooms
public_rooms = request_streams.get("public_rooms")
if public_rooms is not None and public_rooms != current_position:
public_rooms_rows = yield self.store.get_all_new_public_rooms(
public_rooms, current_position, limit
)
upto_token = _position_from_rows(public_rooms_rows, current_position)
writer.write_header_and_rows("public_rooms", public_rooms_rows, (
"position", "room_id", "visibility"
), position=upto_token)
class _Writer(object): class _Writer(object):
@ -408,11 +473,11 @@ class _Writer(object):
self.total = 0 self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None): def write_header_and_rows(self, name, rows, fields, position=None):
if not rows:
return
if position is None: if position is None:
position = rows[-1][0] if rows:
position = rows[-1][0]
else:
return
self.streams[name] = { self.streams[name] = {
"position": position if type(position) is int else str(position), "position": position if type(position) is int else str(position),
@ -422,13 +487,16 @@ class _Writer(object):
self.total += len(rows) self.total += len(rows)
def __nonzero__(self):
return bool(self.total)
def finish(self): def finish(self):
return self.streams return self.streams
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
))): ))):
__slots__ = [] __slots__ = []
@ -443,3 +511,20 @@ class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
def __str__(self): def __str__(self):
return "_".join(str(value) for value in self) return "_".join(str(value) for value in self)
def _position_from_rows(rows, current_position):
"""Calculates a position to return for a stream. Ideally we want to return the
position of the last row, as that will be the most correct. However, if there
are no rows we fall back to using the current position to stop us from
repeatedly hitting the storage layer unncessarily thinking there are updates.
(Not all advances of the token correspond to an actual update)
We can't just always return the current position, as we often limit the
number of rows we replicate, and so the stream may lag. The assumption is
that if the storage layer returns no new rows then we are not lagging and
we are at the `current_position`.
"""
if rows:
return rows[-1][0]
return current_position

View file

@ -61,6 +61,9 @@ class SlavedEventStore(BaseSlavedStore):
"MembershipStreamChangeCache", events_max, "MembershipStreamChangeCache", events_max,
) )
self.stream_ordering_month_ago = 0
self._stream_order_on_start = self.get_room_max_stream_ordering()
# Cached functions can't be accessed through a class instance so we need # Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them. # to reach inside the __dict__ to extract them.
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
@ -168,6 +171,15 @@ class SlavedEventStore(BaseSlavedStore):
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
get_forward_extremeties_for_room = (
DataStore.get_forward_extremeties_for_room.__func__
)
_get_forward_extremeties_for_room = (
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
)
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token() result["events"] = self._stream_id_gen.get_current_token()

View file

@ -15,7 +15,39 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage import DataStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(BaseSlavedStore): class RoomStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(RoomStore, self).__init__(db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
get_public_room_ids = DataStore.get_public_room_ids.__func__ get_public_room_ids = DataStore.get_public_room_ids.__func__
get_current_public_room_stream_id = (
DataStore.get_current_public_room_stream_id.__func__
)
get_public_room_ids_at_stream_id = (
DataStore.get_public_room_ids_at_stream_id.__func__
)
get_public_room_ids_at_stream_id_txn = (
DataStore.get_public_room_ids_at_stream_id_txn.__func__
)
get_published_at_stream_id_txn = (
DataStore.get_published_at_stream_id_txn.__func__
)
get_public_room_changes = DataStore.get_public_room_changes.__func__
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("public_rooms")
if stream:
self._public_room_id_gen.advance(int(stream["position"]))
return super(RoomStore, self).process_replication(result)

View file

@ -25,16 +25,15 @@ class InitialSyncRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs) super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.initial_sync_handler = hs.get_initial_sync_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"] include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms( content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event, as_client_event=as_client_event,

View file

@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.types import create_requester
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -391,15 +392,16 @@ class CreateUserRestServlet(ClientV1RestServlet):
user_json = parse_json_object_from_request(request) user_json = parse_json_object_from_request(request)
access_token = get_access_token_from_request(request) access_token = get_access_token_from_request(request)
app_service = yield self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
access_token access_token
) )
if not app_service: if not app_service:
raise SynapseError(403, "Invalid application service token.") raise SynapseError(403, "Invalid application service token.")
logger.debug("creating user: %s", user_json) requester = create_requester(app_service.sender)
response = yield self._do_create(user_json) logger.debug("creating user: %s", user_json)
response = yield self._do_create(requester, user_json)
defer.returnValue((200, response)) defer.returnValue((200, response))
@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
return 403, {} return 403, {}
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_create(self, user_json): def _do_create(self, requester, user_json):
yield run_on_reactor() yield run_on_reactor()
if "localpart" not in user_json: if "localpart" not in user_json:
@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user( user_id, token = yield handler.get_or_create_user(
requester=requester,
localpart=localpart, localpart=localpart,
displayname=displayname, displayname=displayname,
duration_in_ms=(duration_seconds * 1000), duration_in_ms=(duration_seconds * 1000),

View file

@ -23,7 +23,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event, format_event_for_client_v2 from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import parse_json_object_from_request, parse_string from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
)
import logging import logging
import urllib import urllib
@ -305,7 +307,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
server = parse_string(request, "server", default=None) server = parse_string(request, "server", default=None)
try: try:
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request, allow_guest=True)
except AuthError as e: except AuthError as e:
# We allow people to not be authed if they're just looking at our # We allow people to not be authed if they're just looking at our
# room list, but require auth when we proxy the request. # room list, but require auth when we proxy the request.
@ -317,11 +319,49 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
else: else:
pass pass
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list(server) data = yield handler.get_remote_public_room_list(
server,
limit=limit,
since_token=since_token,
)
else: else:
data = yield handler.get_aggregated_public_room_list() data = yield handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
)
defer.returnValue((200, data))
@defer.inlineCallbacks
def on_POST(self, request):
yield self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
limit = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
server,
limit=limit,
since_token=since_token,
search_filter=search_filter,
)
else:
data = yield handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
)
defer.returnValue((200, data)) defer.returnValue((200, data))
@ -416,13 +456,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs) super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.initial_sync_handler = hs.get_initial_sync_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.initial_sync_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
requester=requester, requester=requester,
pagin_config=pagination_config, pagin_config=pagination_config,
@ -665,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet):
yield self.presence_handler.bump_presence_active_time(requester.user) yield self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]: if content["typing"]:
yield self.typing_handler.started_typing( yield self.typing_handler.started_typing(
target_user=target_user, target_user=target_user,
auth_user=requester.user, auth_user=requester.user,
room_id=room_id, room_id=room_id,
timeout=content.get("timeout", 30000), timeout=timeout,
) )
else: else:
yield self.typing_handler.stopped_typing( yield self.typing_handler.stopped_typing(

View file

@ -77,8 +77,10 @@ SUCCESS_TEMPLATE = """
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> <link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script> <script>
if (window.onAuthDone != undefined) { if (window.onAuthDone) {
window.onAuthDone(); window.onAuthDone();
} else if (window.opener && window.opener.postMessage) {
window.opener.postMessage("authDone", "*");
} }
</script> </script>
</head> </head>

View file

@ -17,6 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import constants, errors
from synapse.http import servlet from synapse.http import servlet
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -58,6 +59,7 @@ class DeviceRestServlet(servlet.RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
@ -70,11 +72,24 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, device_id): def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint. try:
# It allows the client to delete access tokens, which feels like a body = servlet.parse_json_object_from_request(request)
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one except errors.SynapseError as e:
# device at a time. if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise
authed, result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device( yield self.device_handler.delete_device(
requester.user.to_string(), requester.user.to_string(),

View file

@ -15,16 +15,12 @@
import logging import logging
import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer RestServlet, parse_json_object_from_request, parse_integer
) )
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,17 +60,13 @@ class KeyUploadServlet(RestServlet):
hs (synapse.server.HomeServer): server hs (synapse.server.HomeServer): server
""" """
super(KeyUploadServlet, self).__init__() super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
if device_id is not None: if device_id is not None:
@ -94,47 +86,10 @@ class KeyUploadServlet(RestServlet):
"To upload keys, you must pass device_id when authenticating" "To upload keys, you must pass device_id when authenticating"
) )
time_now = self.clock.time_msec() result = yield self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
# TODO: Validate the JSON to make sure it has the right keys. )
device_keys = body.get("device_keys", None) defer.returnValue((200, result))
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = body.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
class KeyQueryServlet(RestServlet): class KeyQueryServlet(RestServlet):
@ -199,7 +154,7 @@ class KeyQueryServlet(RestServlet):
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout) result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue(result) defer.returnValue((200, result))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
@ -212,7 +167,7 @@ class KeyQueryServlet(RestServlet):
{"device_keys": {user_id: device_ids}}, {"device_keys": {user_id: device_ids}},
timeout, timeout,
) )
defer.returnValue(result) defer.returnValue((200, result))
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):
@ -244,80 +199,29 @@ class OneTimeKeyServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__() super(OneTimeKeyServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm): def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.handle_request( result = yield self.e2e_keys_handler.claim_one_time_keys(
{"one_time_keys": {user_id: {device_id: algorithm}}}, {"one_time_keys": {user_id: {device_id: algorithm}}},
timeout, timeout,
) )
defer.returnValue(result) defer.returnValue((200, result))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm): def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.handle_request(body, timeout) result = yield self.e2e_keys_handler.claim_one_time_keys(
defer.returnValue(result) body,
timeout,
@defer.inlineCallbacks )
def handle_request(self, body, timeout): defer.returnValue((200, result))
local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue((200, {
"one_time_keys": json_result,
"failures": failures
}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View file

@ -19,8 +19,6 @@ from synapse.http.server import respond_with_json_bytes
from signedjson.sign import sign_json from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from hashlib import sha256
from OpenSSL import crypto
import logging import logging
@ -48,8 +46,12 @@ class LocalKey(Resource):
"expired_ts": # integer posix timestamp when the key expired. "expired_ts": # integer posix timestamp when the key expired.
"key": # base64 encoded NACL verification key. "key": # base64 encoded NACL verification key.
} }
} },
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert. "tls_fingerprints": [ # Fingerprints of the TLS certs this server uses.
{
"sha256": # base64 encoded sha256 fingerprint of the X509 cert
},
],
"signatures": { "signatures": {
"this.server.example.com": { "this.server.example.com": {
"algorithm:version": # NACL signature for this server "algorithm:version": # NACL signature for this server
@ -90,21 +92,14 @@ class LocalKey(Resource):
u"expired_ts": key.expired, u"expired_ts": key.expired,
} }
x509_certificate_bytes = crypto.dump_certificate( tls_fingerprints = self.config.tls_fingerprints
crypto.FILETYPE_ASN1,
self.config.tls_certificate
)
sha256_fingerprint = sha256(x509_certificate_bytes).digest()
json_object = { json_object = {
u"valid_until_ts": self.valid_until_ts, u"valid_until_ts": self.valid_until_ts,
u"server_name": self.config.server_name, u"server_name": self.config.server_name,
u"verify_keys": verify_keys, u"verify_keys": verify_keys,
u"old_verify_keys": old_verify_keys, u"old_verify_keys": old_verify_keys,
u"tls_fingerprints": [{ u"tls_fingerprints": tls_fingerprints,
u"sha256": encode_base64(sha256_fingerprint),
}]
} }
for key in self.config.signing_key: for key in self.config.signing_key:
json_object = sign_json( json_object = sign_json(

View file

@ -39,10 +39,11 @@ from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
@ -98,6 +99,7 @@ class HomeServer(object):
'e2e_keys_handler', 'e2e_keys_handler',
'event_handler', 'event_handler',
'event_stream_handler', 'event_stream_handler',
'initial_sync_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -228,6 +230,9 @@ class HomeServer(object):
def build_event_stream_handler(self): def build_event_stream_handler(self):
return EventStreamHandler(self) return EventStreamHandler(self)
def build_initial_sync_handler(self):
return InitialSyncHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)

View file

@ -26,6 +26,7 @@ from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from collections import namedtuple from collections import namedtuple
from frozendict import frozendict
import logging import logging
import hashlib import hashlib
@ -58,11 +59,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None): def __init__(self, state, state_group, prev_group=None, delta_ids=None):
self.state = state self.state = frozendict(state)
self.state_group = state_group self.state_group = state_group
self.prev_group = prev_group self.prev_group = prev_group
self.delta_ids = delta_ids self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
# The `state_id` is a unique ID we generate that can be used as ID for # The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the # this collection of state. Usually this would be the same as the
@ -156,8 +157,9 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_user_in_room(self, room_id): def get_current_user_in_room(self, room_id, latest_event_ids=None):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
entry = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state( joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state room_id, entry.state_id, entry.state
@ -237,13 +239,7 @@ class StateHandler(object):
context.prev_state_ids = curr_state context.prev_state_ids = curr_state
if event.is_state(): if event.is_state():
context.state_group = self.store.get_next_state_group() context.state_group = self.store.get_next_state_group()
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.prev_state_ids: if key in context.prev_state_ids:
replaces = context.prev_state_ids[key] replaces = context.prev_state_ids[key]
@ -255,10 +251,15 @@ class StateHandler(object):
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids context.delta_ids = entry.delta_ids
if context.delta_ids is not None: if context.delta_ids is not None:
context.delta_ids = dict(context.delta_ids)
context.delta_ids[key] = event.event_id context.delta_ids[key] = event.event_id
else: else:
context.current_state_ids = context.prev_state_ids if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
context.current_state_ids = context.prev_state_ids
context.prev_group = entry.prev_group context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids context.delta_ids = entry.delta_ids

View file

@ -113,6 +113,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._device_inbox_id_gen = StreamIdGenerator( self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_max_stream_id", "stream_id" db_conn, "device_max_stream_id", "stream_id"
) )
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@ -219,6 +222,8 @@ class DataStore(RoomMemberStore, RoomStore,
self._find_stream_orderings_for_times, 60 * 60 * 1000 self._find_stream_orderings_for_times, 60 * 60 * 1000
) )
self._stream_order_on_start = self.get_room_max_stream_ordering()
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
def take_presence_startup_info(self): def take_presence_startup_info(self):

View file

@ -37,7 +37,7 @@ class ApplicationServiceStore(SQLBaseStore):
) )
def get_app_services(self): def get_app_services(self):
return defer.succeed(self.services_cache) return self.services_cache
def get_app_service_by_user_id(self, user_id): def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID. """Retrieve an application service from their user ID.
@ -54,8 +54,8 @@ class ApplicationServiceStore(SQLBaseStore):
""" """
for service in self.services_cache: for service in self.services_cache:
if service.sender == user_id: if service.sender == user_id:
return defer.succeed(service) return service
return defer.succeed(None) return None
def get_app_service_by_token(self, token): def get_app_service_by_token(self, token):
"""Get the application service with the given appservice token. """Get the application service with the given appservice token.
@ -67,8 +67,8 @@ class ApplicationServiceStore(SQLBaseStore):
""" """
for service in self.services_cache: for service in self.services_cache:
if service.token == token: if service.token == token:
return defer.succeed(service) return service
return defer.succeed(None) return None
def get_app_service_rooms(self, service): def get_app_service_rooms(self, service):
"""Get a list of RoomsForUser for this application service. """Get a list of RoomsForUser for this application service.
@ -163,7 +163,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
["as_id"] ["as_id"]
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
as_list = yield self.get_app_services() as_list = self.get_app_services()
services = [] services = []
for res in results: for res in results:

View file

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
@ -36,6 +37,13 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively. and backfilling from another server respectively.
""" """
def __init__(self, hs):
super(EventFederationStore, self).__init__(hs)
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
def get_auth_chain(self, event_ids): def get_auth_chain(self, event_ids):
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
@ -270,6 +278,37 @@ class EventFederationStore(SQLBaseStore):
] ]
) )
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
max_stream_ord = max(
ev.internal_metadata.stream_ordering for ev in events
)
new_extrem = {}
for room_id in events_by_room:
event_ids = self._simple_select_onecol_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
new_extrem[room_id] = event_ids
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_ord,
}
for room_id, extrem_evs in new_extrem.items()
for event_id in extrem_evs
]
)
query = ( query = (
"INSERT INTO event_backward_extremities (event_id, room_id)" "INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS (" " SELECT ?, ? WHERE NOT EXISTS ("
@ -305,6 +344,75 @@ class EventFederationStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate, (room_id,) self.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
# try and pin to a stream_ordering from before a restart
last_change = max(self._stream_order_on_start, last_change)
if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
return self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
Throws a StoreError if we have since purged the index for
stream_orderings from that point.
"""
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old")
sql = ("""
SELECT event_id FROM stream_ordering_to_exterm
INNER JOIN (
SELECT room_id, MAX(stream_ordering) AS stream_ordering
FROM stream_ordering_to_exterm
WHERE stream_ordering <= ? GROUP BY room_id
) AS rms USING (room_id, stream_ordering)
WHERE room_id = ?
""")
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
rows = txn.fetchall()
return [event_id for event_id, in rows]
return self.runInteraction(
"get_forward_extremeties_for_room",
get_forward_extremeties_for_room_txn
)
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = ("""
DELETE FROM stream_ordering_to_exterm
WHERE
room_id IN (
SELECT room_id
FROM stream_ordering_to_exterm
WHERE stream_ordering > ?
) AND stream_ordering < ?
""")
txn.execute(
sql,
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
)
return self.runInteraction(
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn
)
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and """Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit` including) the events in event_list. Return a list of max size `limit`

View file

@ -1354,39 +1354,53 @@ class EventsStore(SQLBaseStore):
min_stream_id = rows[-1][0] min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows] event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids) rows_to_update = []
rows = [] chunks = [
for event in events: event_ids[i:i + 100]
try: for i in xrange(0, len(event_ids), 100)
event_id = event.event_id ]
origin_server_ts = event.origin_server_ts for chunk in chunks:
except (KeyError, AttributeError): ev_rows = self._simple_select_many_txn(
# If the event is missing a necessary field then txn,
# skip over it. table="event_json",
continue column="event_id",
iterable=chunk,
retcols=["event_id", "json"],
keyvalues={},
)
rows.append((origin_server_ts, event_id)) for row in ev_rows:
event_id = row["event_id"]
event_json = json.loads(row["json"])
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
rows_to_update.append((origin_server_ts, event_id))
sql = ( sql = (
"UPDATE events SET origin_server_ts = ? WHERE event_id = ?" "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
) )
for index in range(0, len(rows), INSERT_CLUMP_SIZE): for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
clump = rows[index:index + INSERT_CLUMP_SIZE] clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump) txn.executemany(sql, clump)
progress = { progress = {
"target_min_stream_id_inclusive": target_min_stream_id, "target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id, "max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows) "rows_inserted": rows_inserted + len(rows_to_update)
} }
self._background_update_progress_txn( self._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
) )
return len(rows) return len(rows_to_update)
result = yield self.runInteraction( result = yield self.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 35 SCHEMA_VERSION = 36
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -48,15 +48,31 @@ class RoomStore(SQLBaseStore):
StoreError if the room could not be stored. StoreError if the room could not be stored.
""" """
try: try:
yield self._simple_insert( def store_room_txn(txn, next_id):
"rooms", self._simple_insert_txn(
{ txn,
"room_id": room_id, "rooms",
"creator": room_creator_user_id, {
"is_public": is_public, "room_id": room_id,
}, "creator": room_creator_user_id,
desc="store_room", "is_public": is_public,
) },
)
if is_public:
self._simple_insert_txn(
txn,
table="public_room_list_stream",
values={
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
}
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"store_room_txn",
store_room_txn, next_id,
)
except Exception as e: except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
@ -77,13 +93,45 @@ class RoomStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public): def set_room_is_public(self, room_id, is_public):
return self._simple_update_one( def set_room_is_public_txn(txn, next_id):
table="rooms", self._simple_update_one_txn(
keyvalues={"room_id": room_id}, txn,
updatevalues={"is_public": is_public}, table="rooms",
desc="set_room_is_public", keyvalues={"room_id": room_id},
) updatevalues={"is_public": is_public},
)
entries = self._simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={"room_id": room_id},
retcols=("stream_id", "visibility"),
)
entries.sort(key=lambda r: r["stream_id"])
add_to_stream = True
if entries:
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
self._simple_insert_txn(
txn,
table="public_room_list_stream",
values={
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
}
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public",
set_room_is_public_txn, next_id,
)
def get_public_room_ids(self): def get_public_room_ids(self):
return self._simple_select_onecol( return self._simple_select_onecol(
@ -207,3 +255,74 @@ class RoomStore(SQLBaseStore):
}, },
desc="add_event_report" desc="add_event_report"
) )
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
def get_public_room_ids_at_stream_id(self, stream_id):
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn, stream_id
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id):
return {
rm
for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items()
if vis
}
def get_published_at_stream_id_txn(self, txn, stream_id):
sql = ("""
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream
WHERE stream_id <= ?
GROUP BY room_id
) grouped USING (room_id, stream_id)
""")
txn.execute(sql, (stream_id,))
return dict(txn.fetchall())
def get_public_room_changes(self, prev_stream_id, new_stream_id):
def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id)
now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id)
now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis
)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
newly_visible = now_rooms_visible - then_rooms
newly_unpublished = now_rooms_not_visible & then_rooms
return newly_visible, newly_unpublished
return self.runInteraction(
"get_public_room_changes", get_public_room_changes_txn
)
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = ("""
SELECT stream_id, room_id, visibility FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
""")
txn.execute(sql, (prev_id, current_id, limit,))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)

View file

@ -0,0 +1,33 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE public_room_list_stream (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
visibility BOOLEAN NOT NULL
);
INSERT INTO public_room_list_stream (stream_id, room_id, visibility)
SELECT 1, room_id, is_public FROM rooms
WHERE is_public = CAST(1 AS BOOLEAN);
CREATE INDEX public_room_list_stream_idx on public_room_list_stream(
stream_id
);
CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream(
room_id, stream_id
);

View file

@ -0,0 +1,37 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE stream_ordering_to_exterm (
stream_ordering BIGINT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL
);
INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id)
SELECT stream_ordering, room_id, event_id FROM event_forward_extremities
INNER JOIN (
SELECT room_id, max(stream_ordering) as stream_ordering FROM events
INNER JOIN event_forward_extremities USING (room_id, event_id)
GROUP BY room_id
) AS rms USING (room_id);
CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm(
stream_ordering
);
CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm(
room_id, stream_ordering
);

View file

@ -0,0 +1,26 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted
INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id)
SELECT
(SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering,
room_id,
event_id
FROM event_forward_extremities AS e
WHERE NOT EXISTS (
SELECT room_id FROM stream_ordering_to_exterm AS s
WHERE s.room_id = e.room_id
);

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from synapse.storage.prepare_database import get_statements from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine
import logging import logging

View file

@ -307,6 +307,9 @@ class StateStore(SQLBaseStore):
def _get_state_groups_from_groups_txn(self, txn, groups, types=None): def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
results = {group: {} for group in groups} results = {group: {} for group in groups}
if types is not None:
types = list(set(types)) # deduplicate types list
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is # Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in # a temporary hack until we can add the right indices in
@ -375,10 +378,35 @@ class StateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions # We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy) # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups: for group in groups:
group_tree = [group]
next_group = group next_group = group
while next_group: while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indicies (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
if types:
args.extend(i for typ in types for i in typ)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? %s" % (where_clause,),
args
)
rows = txn.fetchall()
results[group].update({
(typ, state_key): event_id
for typ, state_key, event_id in rows
if (typ, state_key) not in results[group]
})
# If the lengths match then we must have all the types,
# so no need to go walk further down the tree.
if types is not None and len(results[group]) == len(types):
break
next_group = self._simple_select_one_onecol_txn( next_group = self._simple_select_one_onecol_txn(
txn, txn,
table="state_group_edges", table="state_group_edges",
@ -386,28 +414,6 @@ class StateStore(SQLBaseStore):
retcol="prev_state_group", retcol="prev_state_group",
allow_none=True, allow_none=True,
) )
if next_group:
group_tree.append(next_group)
sql = ("""
SELECT type, state_key, event_id FROM state_groups_state
INNER JOIN (
SELECT type, state_key, max(state_group) as state_group
FROM state_groups_state
WHERE state_group IN (%s) %s
GROUP BY type, state_key
) USING (type, state_key, state_group);
""") % (",".join("?" for _ in group_tree), where_clause,)
args = list(group_tree)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
return results return results
@ -817,16 +823,24 @@ class StateStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def _background_index_state(self, progress, batch_size): def _background_index_state(self, progress, batch_size):
def reindex_txn(txn): def reindex_txn(conn):
conn.rollback()
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
txn.execute( # postgres insists on autocommit for the index
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx" conn.set_session(autocommit=True)
" ON state_groups_state(state_group, type, state_key)" try:
) txn = conn.cursor()
txn.execute( txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id" "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
) " ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
finally:
conn.set_session(autocommit=False)
else: else:
txn = conn.cursor()
txn.execute( txn.execute(
"CREATE INDEX state_groups_state_type_idx" "CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)" " ON state_groups_state(state_group, type, state_key)"
@ -835,9 +849,7 @@ class StateStore(SQLBaseStore):
"DROP INDEX IF EXISTS state_groups_state_id" "DROP INDEX IF EXISTS state_groups_state_id"
) )
yield self.runInteraction( yield self.runWithConnection(reindex_txn)
self.STATE_GROUP_INDEX_UPDATE_NAME, reindex_txn
)
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)

View file

@ -531,6 +531,9 @@ class StreamStore(SQLBaseStore):
) )
defer.returnValue("t%d-%d" % (topo, token)) defer.returnValue("t%d-%d" % (topo, token))
def get_room_max_stream_ordering(self):
return self._stream_id_gen.get_current_token()
def get_stream_token_for_event(self, event_id): def get_stream_token_for_event(self, event_id):
"""The stream token for an event """The stream token for an event
Args: Args:

View file

@ -56,7 +56,7 @@ def get_domain_from_id(string):
try: try:
return string.split(":", 1)[1] return string.split(":", 1)[1]
except IndexError: except IndexError:
raise SynapseError(400, "Invalid ID: %r", string) raise SynapseError(400, "Invalid ID: %r" % (string,))
class DomainSpecificString( class DomainSpecificString(

View file

@ -121,3 +121,9 @@ class StreamChangeCache(object):
k, r = self._cache.popitem() k, r = self._cache.popitem()
self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
self._entity_to_key.pop(r, None) self._entity_to_key.pop(r, None)
def get_max_pos_of_last_change(self, entity):
"""Returns an upper bound of the stream id of the last change to an
entity.
"""
return self._entity_to_key.get(entity, self._earliest_known_stream_pos)

View file

@ -0,0 +1,368 @@
from twisted.internet import defer
from synapse.config._base import ConfigError
from synapse.types import UserID
import ldap3
import ldap3.core.exceptions
import logging
try:
import ldap3
import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
logger = logging.getLogger(__name__)
class LDAPMode(object):
SIMPLE = "simple",
SEARCH = "search",
LIST = (SIMPLE, SEARCH)
class LdapAuthProvider(object):
__version__ = "0.1"
def __init__(self, config, account_handler):
self.account_handler = account_handler
if not ldap3:
raise RuntimeError(
'Missing ldap3 library. This is required for LDAP Authentication.'
)
self.ldap_mode = config.mode
self.ldap_uri = config.uri
self.ldap_start_tls = config.start_tls
self.ldap_base = config.base
self.ldap_attributes = config.attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = config.bind_dn
self.ldap_bind_password = config.bind_password
self.ldap_filter = config.filter
@defer.inlineCallbacks
def check_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Returns:
True if authentication against LDAP was successful
"""
localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting LDAP connection with %s",
self.ldap_uri
)
if self.ldap_mode == LDAPMode.SIMPLE:
result, conn = self._ldap_simple_bind(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP authentication method simple bind returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
result, conn = self._ldap_authenticated_search(
server=server, localpart=localpart, password=password
)
logger.debug(
'LDAP auth method authenticated search returned: %s (conn: %s)',
result,
conn
)
if not result:
defer.returnValue(False)
else:
raise RuntimeError(
'Invalid LDAP mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
try:
logger.info(
"User authenticated against LDAP server: %s",
conn
)
except NameError:
logger.warn(
"Authentication method yielded no LDAP connection, aborting!"
)
defer.returnValue(False)
# check if user with user_id exists
if (yield self.account_handler.check_user_exists(user_id)):
# exists, authentication complete
conn.unbind()
defer.returnValue(True)
else:
# does not exist, fetch metadata for account creation from
# existing ldap connection
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug(
"ldap registration filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
user_id, access_token = (
yield self.account_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
defer.returnValue(True)
else:
if len(conn.response) == 0:
logger.warn("LDAP registration failed, no result.")
else:
logger.warn(
"LDAP registration failed, too many results (%s)",
len(conn.response)
)
defer.returnValue(False)
defer.returnValue(False)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@staticmethod
def parse_config(config):
class _LdapConfig(object):
pass
ldap_config = _LdapConfig()
ldap_config.enabled = config.get("enabled", False)
ldap_config.mode = LDAPMode.SIMPLE
# verify config sanity
_require_keys(config, [
"uri",
"base",
"attributes",
])
ldap_config.uri = config["uri"]
ldap_config.start_tls = config.get("start_tls", False)
ldap_config.base = config["base"]
ldap_config.attributes = config["attributes"]
if "bind_dn" in config:
ldap_config.mode = LDAPMode.SEARCH
_require_keys(config, [
"bind_dn",
"bind_password",
])
ldap_config.bind_dn = config["bind_dn"]
ldap_config.bind_password = config["bind_password"]
ldap_config.filter = config.get("filter", None)
# verify attribute lookup
_require_keys(config['attributes'], [
"uid",
"name",
"mail",
])
return ldap_config
def _ldap_simple_bind(self, server, localpart, password):
""" Attempt a simple bind with the credentials
given by the user against the LDAP server.
Returns True, LDAP3Connection
if the bind was successful
Returns False, None
if an error occured
"""
try:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established LDAP connection in simple bind mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in simple bind mode through StartTLS: %s",
conn
)
if conn.bind():
# GOOD: bind okay
logger.debug("LDAP Bind successful in simple bind mode.")
return True, conn
# BAD: bind failed
logger.info(
"Binding against LDAP failed for '%s' failed: %s",
localpart, conn.result['description']
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _ldap_authenticated_search(self, server, localpart, password):
""" Attempt to login with the preconfigured bind_dn
and then continue searching and filtering within
the base_dn
Returns (True, LDAP3Connection)
if a single matching DN within the base was found
that matched the filter expression, and with which
a successful bind was achieved
The LDAP3Connection returned is the instance that was used to
verify the password not the one using the configured bind_dn.
Returns (False, None)
if an error occured
"""
try:
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established LDAP connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in search mode through StartTLS: %s",
conn
)
if not conn.bind():
logger.warn(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
conn.unbind()
return False, None
# construct search_filter like (uid=localpart)
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
# combine with the AND expression
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug(
"LDAP search filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query
)
if len(conn.response) == 1:
# GOOD: found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('LDAP search found dn: %s', user_dn)
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
conn.unbind()
return self._ldap_simple_bind(server, localpart, password)
else:
# BAD: found 0 or > 1 results, abort!
if len(conn.response) == 0:
logger.info(
"LDAP search returned no results for '%s'",
localpart
)
else:
logger.info(
"LDAP search returned too many (%s) results for '%s'",
len(conn.response), localpart
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
def _require_keys(config, required):
missing = [key for key in required if key not in config]
if missing:
raise ConfigError(
"LDAP enabled but missing required config values: {}".format(
", ".join(missing)
)
)

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from .. import unittest from .. import unittest
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID from synapse.types import UserID, create_requester
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "someone" local_part = "someone"
display_name = "someone" display_name = "someone"
user_id = "@someone:test" user_id = "@someone:test"
requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms) requester, local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "frank" local_part = "frank"
display_name = "Frank" display_name = "Frank"
user_id = "@frank:test" user_id = "@frank:test"
requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms) requester, local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')

View file

@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
from synapse.handlers.typing import RoomMember from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple.to_string()) member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000 self.handler._member_typing_until[member] = 1002000
self.handler._member_typing_timer[member] = ( self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
self.clock.call_later(1002, lambda: 0)
)
self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
}, },
}]) }])
self.clock.advance_time(11) self.clock.advance_time(16)
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]), call('typing_key', 2, rooms=[self.room_id]),

View file

@ -42,7 +42,8 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def replicate(self): def replicate(self):
streams = self.slaved_store.stream_positions() streams = self.slaved_store.stream_positions()
result = yield self.replication.replicate(streams, 100) writer = yield self.replication.replicate(streams, 100)
result = writer.finish()
yield self.slaved_store.process_replication(result) yield self.slaved_store.process_replication(result)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -120,7 +120,7 @@ class ReplicationResourceCase(unittest.TestCase):
self.hs.clock.advance_time_msec(1) self.hs.clock.advance_time_msec(1)
code, body = yield get code, body = yield get
self.assertEquals(code, 200) self.assertEquals(code, 200)
self.assertEquals(body, {}) self.assertEquals(body.get("rows", []), [])
test_timeout.__name__ = "test_timeout_%s" % (stream) test_timeout.__name__ = "test_timeout_%s" % (stream)
return test_timeout return test_timeout
@ -195,7 +195,6 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertIn("field_names", stream) self.assertIn("field_names", stream)
field_names = stream["field_names"] field_names = stream["field_names"]
self.assertIn("rows", stream) self.assertIn("rows", stream)
self.assertTrue(stream["rows"])
for row in stream["rows"]: for row in stream["rows"]:
self.assertEquals( self.assertEquals(
len(row), len(field_names), len(row), len(field_names),

View file

@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase):
) )
self.request.args = {} self.request.args = {}
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
)
self.registration_handler = Mock() self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
# do the dance to hook it up to the hs global self.appservice = Mock(sender="@as:test")
self.handlers = Mock( self.datastore = Mock(
auth_handler=self.auth_handler, get_app_service_by_token=Mock(return_value=self.appservice)
)
# do the dance to hook things up to the hs global
handlers = Mock(
registration_handler=self.registration_handler, registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
) )
self.hs = Mock() self.hs = Mock()
self.hs.hostname = "supergbig~testing~thing.com" self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=handlers)
self.hs.config.enable_registration = True
# init the thing we're testing
self.servlet = CreateUserRestServlet(self.hs) self.servlet = CreateUserRestServlet(self.hs)
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
# Need another user to make notifications actually work # Need another user to make notifications actually work
yield self.join(self.room_id, user="@jim:red") yield self.join(self.room_id, user="@jim:red")
def tearDown(self):
self.hs.get_typing_handler().tearDown()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_typing(self): def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger( (code, _) = yield self.mock_resource.trigger(
@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
self.clock.advance_time(31) self.clock.advance_time(36)
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)

View file

@ -19,7 +19,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.appservice = None self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock( self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice)) side_effect=lambda x: self.appservice)
) )
self.auth_result = (False, None, None, None) self.auth_result = (False, None, None, None)

View file

@ -37,6 +37,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
config = Mock( config = Mock(
app_service_config_files=self.as_yaml_files, app_service_config_files=self.as_yaml_files,
event_cache_size=1, event_cache_size=1,
password_providers=[],
) )
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config)
@ -71,14 +72,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
outfile.write(yaml.dump(as_yaml)) outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token) self.as_yaml_files.append(as_token)
@defer.inlineCallbacks
def test_retrieve_unknown_service_token(self): def test_retrieve_unknown_service_token(self):
service = yield self.store.get_app_service_by_token("invalid_token") service = self.store.get_app_service_by_token("invalid_token")
self.assertEquals(service, None) self.assertEquals(service, None)
@defer.inlineCallbacks
def test_retrieval_of_service(self): def test_retrieval_of_service(self):
stored_service = yield self.store.get_app_service_by_token( stored_service = self.store.get_app_service_by_token(
self.as_token self.as_token
) )
self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.token, self.as_token)
@ -97,9 +96,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
[] []
) )
@defer.inlineCallbacks
def test_retrieval_of_all_services(self): def test_retrieval_of_all_services(self):
services = yield self.store.get_app_services() services = self.store.get_app_services()
self.assertEquals(len(services), 3) self.assertEquals(len(services), 3)
@ -112,6 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
config = Mock( config = Mock(
app_service_config_files=self.as_yaml_files, app_service_config_files=self.as_yaml_files,
event_cache_size=1, event_cache_size=1,
password_providers=[],
) )
hs = yield setup_test_homeserver(config=config) hs = yield setup_test_homeserver(config=config)
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
@ -440,7 +439,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1") f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2") f2 = self._write_config(suffix="2")
config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[]
)
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(config=config, datastore=Mock())
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -450,7 +452,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(id="id", suffix="1") f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2") f2 = self._write_config(id="id", suffix="2")
config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[]
)
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
@ -466,7 +471,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1") f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2") f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1,
password_providers=[]
)
hs = yield setup_test_homeserver(config=config, datastore=Mock()) hs = yield setup_test_homeserver(config=config, datastore=Mock())
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:

View file

@ -52,6 +52,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.server_name = name config.server_name = name
config.trusted_third_party_id_servers = [] config.trusted_third_party_id_servers = []
config.room_invite_state_types = [] config.room_invite_state_types = []
config.password_providers = []
config.use_frozen_dicts = True config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"} config.database_config = {"name": "sqlite3"}
@ -220,6 +221,7 @@ class MockClock(object):
# list of lists of [absolute_time, callback, expired] in no particular # list of lists of [absolute_time, callback, expired] in no particular
# order # order
self.timers = [] self.timers = []
self.loopers = []
def time(self): def time(self):
return self.now return self.now
@ -240,7 +242,7 @@ class MockClock(object):
return t return t
def looping_call(self, function, interval): def looping_call(self, function, interval):
pass self.loopers.append([function, interval / 1000., self.now])
def cancel_call_later(self, timer, ignore_errs=False): def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]: if timer[2]:
@ -269,6 +271,12 @@ class MockClock(object):
else: else:
self.timers.append(t) self.timers.append(t)
for looped in self.loopers:
func, interval, last = looped
if last + interval < self.now:
func()
looped[2] = self.now
def advance_time_msec(self, ms): def advance_time_msec(self, ms):
self.advance_time(ms / 1000.) self.advance_time(ms / 1000.)