Merge branch 'develop' of github.com:matrix-org/synapse into erikj/unfederatable

This commit is contained in:
Erik Johnston 2015-10-02 10:33:49 +01:00
commit d5e081c7ae
62 changed files with 1425 additions and 949 deletions

View file

@ -1,3 +1,17 @@
Changes in synapse v0.10.0-r2 (2015-09-16)
==========================================
* Fix bug where we always fetched remote server signing keys instead of using
ones in our cache.
* Fix adding threepids to an existing account.
* Fix bug with invinting over federation where remote server was already in
the room. (PR #281, SYN-392)
Changes in synapse v0.10.0-r1 (2015-09-08)
==========================================
* Fix bug with python packaging
Changes in synapse v0.10.0 (2015-09-03)
=======================================

View file

@ -25,6 +25,7 @@ for port in 8080 8081 8082; do
--generate-config \
-H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \
--report-stats no
# Check script parameters
if [ $# -eq 1 ]; then

142
scripts-dev/definitions.py Executable file
View file

@ -0,0 +1,142 @@
#! /usr/bin/python
import ast
import yaml
class DefinitionVisitor(ast.NodeVisitor):
def __init__(self):
super(DefinitionVisitor, self).__init__()
self.functions = {}
self.classes = {}
self.names = {}
self.attrs = set()
self.definitions = {
'def': self.functions,
'class': self.classes,
'names': self.names,
'attrs': self.attrs,
}
def visit_Name(self, node):
self.names.setdefault(type(node.ctx).__name__, set()).add(node.id)
def visit_Attribute(self, node):
self.attrs.add(node.attr)
for child in ast.iter_child_nodes(node):
self.visit(child)
def visit_ClassDef(self, node):
visitor = DefinitionVisitor()
self.classes[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def visit_FunctionDef(self, node):
visitor = DefinitionVisitor()
self.functions[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def non_empty(defs):
functions = {name: non_empty(f) for name, f in defs['def'].items()}
classes = {name: non_empty(f) for name, f in defs['class'].items()}
result = {}
if functions: result['def'] = functions
if classes: result['class'] = classes
names = defs['names']
uses = []
for name in names.get('Load', ()):
if name not in names.get('Param', ()) and name not in names.get('Store', ()):
uses.append(name)
uses.extend(defs['attrs'])
if uses: result['uses'] = uses
result['names'] = names
result['attrs'] = defs['attrs']
return result
def definitions_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = DefinitionVisitor()
visitor.visit(input_ast)
definitions = non_empty(visitor.definitions)
return definitions
def definitions_in_file(filepath):
with open(filepath) as f:
return definitions_in_code(f.read())
def defined_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
def used_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
used_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
used_names(prefix + name + ".", funcs, names)
for used in defs.get('uses', ()):
if used in names:
names[used].setdefault('used', []).append(prefix.rstrip('.'))
if __name__ == '__main__':
import sys, os, argparse, re
parser = argparse.ArgumentParser(description='Find definitions.')
parser.add_argument(
"--unused", action="store_true", help="Only list unused definitions"
)
parser.add_argument(
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern"
)
parser.add_argument(
"--pattern", action="append", metavar="REGEXP",
help="Search for a pattern"
)
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
definitions = {}
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
definitions[filepath] = definitions_in_file(filepath)
names = {}
for filepath, defs in definitions.items():
defined_names(filepath + ":", defs, names)
for filepath, defs in definitions.items():
used_names(filepath + ":", defs, names)
patterns = [re.compile(pattern) for pattern in args.pattern or ()]
ignore = [re.compile(pattern) for pattern in args.ignore or ()]
result = {}
for name, definition in names.items():
if patterns and not any(pattern.match(name) for pattern in patterns):
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
if args.unused and definition.get('used'):
continue
result[name] = definition
yaml.dump(result, sys.stdout, default_flow_style=False)

View file

@ -95,8 +95,6 @@ class Store(object):
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
_execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]
def runInteraction(self, desc, func, *args, **kwargs):
def r(conn):
try:

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.10.0"
__version__ = "0.10.0-r2"

View file

@ -23,6 +23,7 @@ from synapse.util.logutils import log_function
from synapse.types import RoomID, UserID, EventID
import logging
import pymacaroons
logger = logging.getLogger(__name__)
@ -40,6 +41,12 @@ class Auth(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"type = ",
"time < ",
"user_id = ",
])
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
@ -121,6 +128,20 @@ class Auth(object):
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
"""Check if the user is currently joined in the room
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user is not in the room.
Returns:
A deferred membership event for the user if the user is in
the room.
"""
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
@ -136,6 +157,43 @@ class Auth(object):
self._check_joined_room(member, user_id, room_id)
defer.returnValue(member)
@defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id, current_state=None):
"""Check if the user was in the room at some point.
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user was never in the room.
Returns:
A deferred membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE):
raise AuthError(403, "User %s not in room %s" % (
user_id, room_id
))
defer.returnValue(member)
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id)
@ -390,7 +448,7 @@ class Auth(object):
except KeyError:
pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_access_token(access_token)
user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"]
token_id = user_info["token_id"]
@ -417,7 +475,7 @@ class Auth(object):
)
@defer.inlineCallbacks
def get_user_by_access_token(self, token):
def _get_user_by_access_token(self, token):
""" Get a registered user's ID.
Args:
@ -427,6 +485,86 @@ class Auth(object):
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
try:
ret = yield self._get_user_from_macaroon(token)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret)
@defer.inlineCallbacks
def _get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self._validate_macaroon(macaroon)
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(ret)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
)
def _validate_macaroon(self, macaroon):
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
# TODO(daniel): Enable expiry check when clients actually know how to
# refresh tokens. (And remember to enable the tests)
return True
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
def _verify_recognizes_caveats(self, caveat):
first_space = caveat.find(" ")
if first_space < 0:
return False
second_space = caveat.find(" ", first_space + 1)
if second_space < 0:
return False
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
raise AuthError(
@ -437,7 +575,6 @@ class Auth(object):
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
defer.returnValue(user_info)
@defer.inlineCallbacks

View file

@ -27,16 +27,6 @@ class Membership(object):
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class Feedback(object):
"""Represents the types of feedback a user can send in response to a
message."""
DELIVERED = u"delivered"
READ = u"read"
LIST = (DELIVERED, READ)
class PresenceState(object):
"""Represents the presence state of a user."""
OFFLINE = u"offline"
@ -73,7 +63,6 @@ class EventTypes(object):
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"

View file

@ -77,11 +77,6 @@ class SynapseError(CodeMessageException):
)
class RoomError(SynapseError):
"""An error raised when a room event fails."""
pass
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass

View file

@ -16,10 +16,23 @@
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS
from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError
)
if __name__ == '__main__':
check_requirements()
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import (
@ -29,7 +42,7 @@ from synapse.storage import (
from synapse.server import HomeServer
from twisted.internet import reactor
from twisted.internet import reactor, task, defer
from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
@ -72,12 +85,6 @@ import time
logger = logging.getLogger("synapse.app.homeserver")
class GzipFile(File):
def getChild(self, path, request):
child = File.getChild(self, path, request)
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@ -121,6 +128,7 @@ class SynapseHomeServer(HomeServer):
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable?
@ -221,7 +229,7 @@ class SynapseHomeServer(HomeServer):
listener_config,
root_resource,
),
self.tls_context_factory,
self.tls_server_context_factory,
interface=bind_address
)
else:
@ -365,7 +373,6 @@ def setup(config_options):
Args:
config_options_options: The options passed to Synapse. Usually
`sys.argv[1:]`.
should_run (bool): Whether to start the reactor.
Returns:
HomeServer
@ -388,7 +395,7 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_context_factory = context_factory.ServerContextFactory(config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"])
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
@ -396,7 +403,7 @@ def setup(config_options):
hs = SynapseHomeServer(
config.server_name,
db_config=config.database_config,
tls_context_factory=tls_context_factory,
tls_server_context_factory=tls_server_context_factory,
config=config,
content_addr=config.content_addr,
version_string=version_string,
@ -665,6 +672,42 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
start_time = hs.get_clock().time()
@defer.inlineCallbacks
def phone_stats_home():
now = int(hs.get_clock().time())
uptime = int(now - start_time)
if uptime < 0:
uptime = 0
stats = {}
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home)
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread():
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)

View file

@ -16,57 +16,67 @@
import sys
import os
import os.path
import subprocess
import signal
import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml"
GREEN = "\x1b[1;32m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m"
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
CONFIG = yaml.load(open(CONFIGFILE))
PIDFILE = CONFIG["pid_file"]
def start():
def start(configfile):
print "Starting ...",
args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE])
subprocess.check_call(args)
print GREEN + "started" + NORMAL
args.extend(["--daemonize", "-c", configfile])
try:
subprocess.check_call(args)
print GREEN + "started" + NORMAL
except subprocess.CalledProcessError as e:
print (
RED +
"error starting (exit code: %d); see above for logs" % e.returncode +
NORMAL
)
def stop():
if os.path.exists(PIDFILE):
pid = int(open(PIDFILE).read())
def stop(pidfile):
if os.path.exists(pidfile):
pid = int(open(pidfile).read())
os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL
def main():
configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
if not os.path.exists(configfile):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), configfile
)
)
sys.exit(1)
config = yaml.load(open(configfile))
pidfile = config["pid_file"]
action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start":
start()
start(configfile)
elif action == "stop":
stop()
stop(pidfile)
elif action == "restart":
stop()
start()
stop(pidfile)
start(configfile)
else:
sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],))
sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
sys.exit(1)

View file

@ -26,6 +26,16 @@ class ConfigError(Exception):
class Config(object):
stats_reporting_begging_spiel = (
"We would really appreciate it if you could help our project out by"
" reporting anonymized usage statistics from your homeserver. Only very"
" basic aggregate data (e.g. number of users) will be reported, but it"
" helps us to track the growth of the Matrix community, and helps us to"
" make Matrix a success, as well as to convince other networks that they"
" should peer with us."
"\nThank you."
)
@staticmethod
def parse_size(value):
if isinstance(value, int) or isinstance(value, long):
@ -111,11 +121,14 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs))
return results
def generate_config(self, config_dir_path, server_name):
def generate_config(self, config_dir_path, server_name, report_stats=None):
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", config_dir_path, server_name
"default_config",
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=report_stats,
))
config = yaml.load(default_config)
@ -139,6 +152,12 @@ class Config(object):
action="store_true",
help="Generate a config file for the server name"
)
config_parser.add_argument(
"--report-stats",
action="store",
help="Stuff",
choices=["yes", "no"]
)
config_parser.add_argument(
"--generate-keys",
action="store_true",
@ -189,6 +208,11 @@ class Config(object):
config_files.append(config_path)
if config_args.generate_config:
if config_args.report_stats is None:
config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" +
cls.stats_reporting_begging_spiel
)
if not config_files:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
@ -211,7 +235,9 @@ class Config(object):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
config_dir_path, server_name
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
@ -261,9 +287,20 @@ class Config(object):
specified_config.update(yaml_config)
server_name = specified_config["server_name"]
_, config = obj.generate_config(config_dir_path, server_name)
_, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name
)
config.pop("log_config")
config.update(specified_config)
if "report_stats" not in config:
sys.stderr.write(
"Please opt in or out of reporting anonymized homeserver usage "
"statistics, by setting the report_stats key in your config file "
" ( " + config_path + " ) " +
"to either True or False.\n\n" +
Config.stats_reporting_begging_spiel + "\n")
sys.exit(1)
if generate_keys:
obj.invoke_all("generate_files", config)

View file

@ -20,7 +20,7 @@ class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
def default_config(cls, config_dir_path, server_name):
def default_config(cls, **kwargs):
return """\
# A list of application service config file to use
app_service_config_files: []

View file

@ -24,7 +24,7 @@ class CaptchaConfig(Config):
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def default_config(self, config_dir_path, server_name):
def default_config(self, **kwargs):
return """\
## Captcha ##

View file

@ -45,7 +45,7 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path"))
def default_config(self, config, config_dir_path):
def default_config(self, **kwargs):
database_path = self.abspath("homeserver.db")
return """\
# Database configuration

View file

@ -40,7 +40,7 @@ class KeyConfig(Config):
config["perspectives"]
)
def default_config(self, config_dir_path, server_name):
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
return """\
## Signing Keys ##

View file

@ -21,6 +21,7 @@ import logging.config
import yaml
from string import Template
import os
import signal
DEFAULT_LOG_CONFIG = Template("""
@ -69,7 +70,7 @@ class LoggingConfig(Config):
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name):
def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
@ -142,6 +143,19 @@ class LoggingConfig(Config):
handler = logging.handlers.RotatingFileHandler(
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
)
def sighup(signum, stack):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
else:
handler = logging.StreamHandler()
handler.setFormatter(formatter)

View file

@ -19,13 +19,15 @@ from ._base import Config
class MetricsConfig(Config):
def read_config(self, config):
self.enable_metrics = config["enable_metrics"]
self.report_stats = config.get("report_stats", None)
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, config_dir_path, server_name):
return """\
def default_config(self, report_stats=None, **kwargs):
suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n"
return ("""\
## Metrics ###
# Enable collection and rendering of performance metrics
enable_metrics: False
"""
""" + suffix) % locals()

View file

@ -27,7 +27,7 @@ class RatelimitConfig(Config):
self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = config["federation_rc_concurrent"]
def default_config(self, config_dir_path, server_name):
def default_config(self, **kwargs):
return """\
## Ratelimiting ##

View file

@ -34,7 +34,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
def default_config(self, config_dir, server_name):
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50)
return """\

View file

@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config):
config["thumbnail_sizes"]
)
def default_config(self, config_dir_path, server_name):
def default_config(self, **kwargs):
media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads")
return """

View file

@ -41,7 +41,7 @@ class SAML2Config(Config):
self.saml2_config_path = None
self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name):
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file

View file

@ -117,7 +117,7 @@ class ServerConfig(Config):
self.content_addr = content_addr
def default_config(self, config_dir_path, server_name):
def default_config(self, server_name, **kwargs):
if ":" in server_name:
bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400

View file

@ -42,7 +42,15 @@ class TlsConfig(Config):
config.get("tls_dh_params_path"), "tls_dh_params"
)
def default_config(self, config_dir_path, server_name):
# This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for
# use only when running tests.
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"

View file

@ -22,7 +22,7 @@ class VoipConfig(Config):
self.turn_shared_secret = config["turn_shared_secret"]
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, config_dir_path, server_name):
def default_config(self, **kwargs):
return """\
## Turn ##

View file

@ -228,10 +228,9 @@ class Keyring(object):
def do_iterations():
merged_results = {}
missing_keys = {
group.server_name: set(group.key_ids)
for group in group_id_to_group.values()
}
missing_keys = {}
for group in group_id_to_group.values():
missing_keys.setdefault(group.server_name, set()).union(group.key_ids)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
@ -470,7 +469,7 @@ class Keyring(object):
continue
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory,
server_name, self.hs.tls_server_context_factory,
path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id),
)).encode("ascii"),
@ -604,7 +603,7 @@ class Keyring(object):
# Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory
server_name, self.hs.tls_server_context_factory
)
# Check the response.

View file

@ -19,7 +19,6 @@ from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
@ -187,7 +186,7 @@ class AuthHandler(BaseHandler):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
client = SimpleHttpClient(self.hs)
client = self.hs.get_simple_http_client()
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={

View file

@ -125,60 +125,72 @@ class FederationHandler(BaseHandler):
)
if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
current_state = state
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
auth_chain, state, event
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
seen_ids = set(
(yield self.store.have_events(event_ids)).keys()
)
else:
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
seen_ids = set(
(yield self.store.have_events(event_ids)).keys()
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
event,
state=state,
backfilled=backfilled,
current_state=current_state,
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
event,
state=state,
backfilled=backfilled,
current_state=current_state,
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
# if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state
@ -649,35 +661,8 @@ class FederationHandler(BaseHandler):
# FIXME
pass
ev_infos = []
for e in itertools.chain(state, auth_chain):
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
ev_infos.append({
"event": e,
"auth_events": {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
})
yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
new_event,
state=state,
current_state=state,
auth_events=auth_events,
event_stream_id, max_stream_id = yield self._persist_auth_tree(
auth_chain, state, event
)
with PreserveLoggingContext():
@ -1026,6 +1011,76 @@ class FederationHandler(BaseHandler):
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
def _persist_auth_tree(self, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
"""
events_to_context = {}
for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True
event_map = {
e.event_id: e
for e in auth_events
}
create_event = None
for e in auth_events:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
try:
self.auth.check(e, auth_events=auth_for_e)
except AuthError as err:
logger.warn(
"Rejecting %s because %s",
e.event_id, err.msg
)
if e == event:
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
yield self.store.persist_events(
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
],
is_new_state=False,
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False,
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
backfilled=False,
is_new_state=True,
current_state=state,
)
defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
@ -1456,52 +1511,3 @@ class FederationHandler(BaseHandler):
},
"missing": [e.event_id for e in missing_locals],
})
@defer.inlineCallbacks
def _handle_auth_events(self, origin, auth_events):
auth_ids_to_deferred = {}
def process_auth_ev(ev):
auth_ids = [e_id for e_id, _ in ev.auth_events]
prev_ds = [
auth_ids_to_deferred[i]
for i in auth_ids
if i in auth_ids_to_deferred
]
d = defer.Deferred()
auth_ids_to_deferred[ev.event_id] = d
@defer.inlineCallbacks
def f(*_):
ev.internal_metadata.outlier = True
try:
auth = {
(e.type, e.state_key): e for e in auth_events
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, ev, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
ev.event_id,
)
d.callback(None)
if prev_ds:
dx = defer.DeferredList(prev_ds)
dx.addBoth(f)
else:
f()
for e in auth_events:
process_auth_ev(e)
yield defer.DeferredList(auth_ids_to_deferred.values())

View file

@ -16,13 +16,13 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError, SynapseError
from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken
from synapse.types import UserID, RoomStreamToken, StreamToken
from ._base import BaseHandler
@ -71,7 +71,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
feedback=False, as_client_event=True):
as_client_event=True):
"""Get messages in a room.
Args:
@ -79,26 +79,52 @@ class MessageHandler(BaseHandler):
room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
feedback (bool): True to get compressed feedback with the messages
as_client_event (bool): True to get events in client-server format.
Returns:
dict: Pagination API results
"""
yield self.auth.check_joined_room(room_id, user_id)
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"]
if not pagin_config.from_token:
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token(
direction='b'
)
)
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room
leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
)
@ -106,7 +132,7 @@ class MessageHandler(BaseHandler):
user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows(
user, pagin_config.get_source_config("room"), room_id
user, source_config, room_id
)
next_token = pagin_config.from_token.copy_and_replace(
@ -255,29 +281,26 @@ class MessageHandler(BaseHandler):
Raises:
SynapseError if something went wrong.
"""
have_joined = yield self.auth.check_joined_room(room_id, user_id)
if not have_joined:
raise RoomError(403, "User not in room.")
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
elif member_event.membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], [key]
)
data = room_state[member_event.event_id].get(key)
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data)
@defer.inlineCallbacks
def get_feedback(self, event_id):
# yield self.auth.check_joined_room(room_id, user_id)
# Pull out the feedback from the db
fb = yield self.store.get_feedback(event_id)
if fb:
defer.returnValue(fb)
defer.returnValue(None)
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id):
"""Retrieve all state events for a given room.
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left.
Args:
user_id(str): The user requesting state events.
@ -285,18 +308,23 @@ class MessageHandler(BaseHandler):
Returns:
A list of dicts representing state events. [{}, {}, {}]
"""
yield self.auth.check_joined_room(room_id, user_id)
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], None
)
room_state = room_state[member_event.event_id]
# TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id)
now = self.clock.time_msec()
defer.returnValue(
[serialize_event(c, now) for c in current_state.values()]
[serialize_event(c, now) for c in room_state.values()]
)
@defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
feedback=False, as_client_event=True):
def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
@ -306,7 +334,6 @@ class MessageHandler(BaseHandler):
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
feedback (bool): True to get feedback along with these messages.
as_client_event (bool): True to get events in client-server format.
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
@ -316,7 +343,9 @@ class MessageHandler(BaseHandler):
"""
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id,
membership_list=[Membership.INVITE, Membership.JOIN]
membership_list=[
Membership.INVITE, Membership.JOIN, Membership.LEAVE
]
)
user = UserID.from_string(user_id)
@ -358,19 +387,32 @@ class MessageHandler(BaseHandler):
rooms_ret.append(d)
if event.membership != Membership.JOIN:
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
event.room_id, [event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield defer.gatherResults(
[
self.store.get_recent_events_for_room(
event.room_id,
limit=limit,
end_token=now_token.room_key,
),
self.state_handler.get_current_state(
event.room_id
end_token=room_end_token,
),
deferred_room_state,
]
).addErrback(unwrapFirstError)
@ -417,15 +459,85 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None,
feedback=False):
current_state = yield self.state.get_current_state(
room_id=room_id,
def room_initial_sync(self, user_id, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
user_id(str): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, member_event
)
elif member_event.membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event
)
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
member_event):
room_state = yield self.store.get_state_for_events(
member_event.room_id, [member_event.event_id], None
)
yield self.auth.check_joined_room(
room_id, user_id,
current_state=current_state
room_state = room_state[member_event.event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event.event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = StreamToken(token[0], 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0)
time_now = self.clock.time_msec()
defer.returnValue({
"membership": member_event.membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
member_event):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO(paul): I wish I was called with user objects not user_id
@ -439,8 +551,6 @@ class MessageHandler(BaseHandler):
for x in current_state.values()
]
member_event = current_state.get((EventTypes.Member, user_id,))
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None

View file

@ -25,7 +25,6 @@ from synapse.api.constants import (
from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
from collections import OrderedDict
import logging
@ -39,7 +38,7 @@ class RoomCreationHandler(BaseHandler):
PRESETS_DICT = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "invited",
"history_visibility": "shared",
"original_invitees_have_ops": False,
},
RoomCreationPreset.PUBLIC_CHAT: {
@ -159,6 +158,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
)
msg_handler = self.hs.get_handlers().message_handler
@ -206,7 +206,8 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state, creation_content):
invite_list, initial_state, creation_content,
room_alias):
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.to_string()
@ -276,6 +277,14 @@ class RoomCreationHandler(BaseHandler):
returned_events.append(power_levels_event)
if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
room_alias_event = create(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
returned_events.append(room_alias_event)
if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create(
etype=EventTypes.JoinRules,
@ -346,41 +355,6 @@ class RoomMemberHandler(BaseHandler):
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None,
limit=0, start_tok=None,
end_tok=None):
"""Retrieve a list of room members in the room.
Args:
room_id (str): The room to get the member list for.
user_id (str): The ID of the user making the request.
limit (int): The max number of members to return.
start_tok (str): Optional. The start token if known.
end_tok (str): Optional. The end token if known.
Returns:
dict: A Pagination streamable dict.
Raises:
SynapseError if something goes wrong.
"""
yield self.auth.check_joined_room(room_id, user_id)
member_list = yield self.store.get_room_members(room_id=room_id)
time_now = self.clock.time_msec()
event_list = [
serialize_event(entry, time_now)
for entry in member_list
]
chunk_data = {
"start": "START", # FIXME (erikj): START is no longer valid
"end": "END",
"chunk": event_list
}
# TODO honor Pagination stream params
# TODO snapshot this list to return on subsequent requests when
# paginating
defer.returnValue(chunk_data)
@defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room.
@ -532,32 +506,6 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id
)
@defer.inlineCallbacks
def _should_invite_join(self, room_id, prev_state, do_auth):
logger.debug("_should_invite_join: room_id: %s", room_id)
# XXX: We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
# Only do an invite join dance if a) we were invited,
# b) the person inviting was from a differnt HS and c) we are
# not currently in the room
room_host = None
if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id)
inviter = UserID.from_string(
prev_state.sender
)
is_remote_invite_join = not self.hs.is_mine(inviter) and not room
room_host = inviter.domain
else:
is_remote_invite_join = False
defer.returnValue((is_remote_invite_join, room_host))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
@ -650,7 +598,6 @@ class RoomEventSource(object):
to_key=config.to_key,
direction=config.direction,
limit=config.limit,
with_feedback=True
)
defer.returnValue((events, next_key))

View file

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn
@ -19,7 +21,7 @@ import synapse.metrics
from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor, ssl
from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool,
@ -59,7 +61,12 @@ class SimpleHttpClient(object):
# 'like a browser'
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = Agent(reactor, pool=pool)
self.agent = Agent(
reactor,
pool=pool,
connectTimeout=15,
contextFactory=hs.get_http_client_context_factory()
)
self.version_string = hs.version_string
def request(self, method, uri, *args, **kwargs):
@ -252,3 +259,18 @@ def _print_ex(e):
_print_ex(ex)
else:
logger.exception(e)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
Do not use this since it allows an attacker to intercept your communications.
"""
def __init__(self):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: None)
def getContext(self, hostname, port):
return self._context

View file

@ -57,14 +57,14 @@ incoming_responses_counter = metrics.register_counter(
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory
self.tls_server_context_factory = hs.tls_server_context_factory
def endpointForURI(self, uri):
destination = uri.netloc
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory
ssl_context_factory=self.tls_server_context_factory
)

View file

@ -18,18 +18,18 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
"pynacl>=0.3.0": ["nacl>=0.3.0"],
"daemonize": ["daemonize"],
"py-bcrypt": ["bcrypt"],
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
"ujson": ["ujson"],
@ -60,7 +60,10 @@ DEPENDENCY_LINKS = {
class MissingRequirementError(Exception):
pass
def __init__(self, message, module_name, dependency):
super(MissingRequirementError, self).__init__(message)
self.module_name = module_name
self.dependency = dependency
def check_requirements(config=None):
@ -88,7 +91,7 @@ def check_requirements(config=None):
)
raise MissingRequirementError(
"Can't import %r which is part of %r"
% (module_name, dependency)
% (module_name, dependency), module_name, dependency
)
version = getattr(module, "__version__", None)
file_path = getattr(module, "__file__", None)
@ -101,23 +104,25 @@ def check_requirements(config=None):
if version is None:
raise MissingRequirementError(
"Version of %r isn't set as __version__ of module %r"
% (dependency, module_name)
% (dependency, module_name), module_name, dependency
)
if LooseVersion(version) < LooseVersion(required_version):
raise MissingRequirementError(
"Version of %r in %r is too old. %r < %r"
% (dependency, file_path, version, required_version)
% (dependency, file_path, version, required_version),
module_name, dependency
)
elif version_test == "==":
if version is None:
raise MissingRequirementError(
"Version of %r isn't set as __version__ of module %r"
% (dependency, module_name)
% (dependency, module_name), module_name, dependency
)
if LooseVersion(version) != LooseVersion(required_version):
raise MissingRequirementError(
"Unexpected version of %r in %r. %r != %r"
% (dependency, file_path, version, required_version)
% (dependency, file_path, version, required_version),
module_name, dependency
)

View file

@ -26,14 +26,12 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
content = yield handler.snapshot_all_rooms(
user_id=user.to_string(),
pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event
)

View file

@ -290,12 +290,18 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
user, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk(
handler = self.handlers.message_handler
events = yield handler.get_state_events(
room_id=room_id,
user_id=user.to_string())
user_id=user.to_string(),
)
for event in members["chunk"]:
chunk = []
for event in events:
if event["type"] != EventTypes.Member:
continue
chunk.append(event)
# FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
@ -308,7 +314,9 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
except:
pass
defer.returnValue((200, members))
defer.returnValue((200, {
"chunk": chunk
}))
# TODO: Needs unit testing
@ -321,14 +329,12 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event
)

View file

@ -15,6 +15,7 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
@ -41,6 +42,9 @@ class ReceiptRestServlet(RestServlet):
def on_POST(self, request, room_id, receipt_type, event_id):
user, _ = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,

View file

@ -19,7 +19,9 @@
# partial one for unit test mocking.
# Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
@ -87,6 +89,8 @@ class BaseHomeServer(object):
'pusherpool',
'event_builder_factory',
'filtering',
'http_client_context_factory',
'simple_http_client',
]
def __init__(self, hostname, **kwargs):
@ -174,6 +178,17 @@ class HomeServer(BaseHomeServer):
def build_auth(self):
return Auth(self)
def build_http_client_context_factory(self):
config = self.get_config()
return (
InsecureInterceptableContextFactory()
if config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
def build_simple_http_client(self):
return SimpleHttpClient(self)
def build_v1auth(self):
orf = Auth(self)
# Matrix spec makes no reference to what HTTP status code is returned,

View file

@ -17,7 +17,6 @@
from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
@ -32,10 +31,6 @@ import hashlib
logger = logging.getLogger(__name__)
def _get_state_key_from_event(event):
return event.state_key
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
@ -119,8 +114,6 @@ class StateHandler(object):
Returns:
an EventContext
"""
yield run_on_reactor()
context = EventContext()
if outlier:

View file

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 23
SCHEMA_VERSION = 24
dir_path = os.path.abspath(os.path.dirname(__file__))
@ -126,6 +126,27 @@ class DataStore(RoomMemberStore, RoomStore,
lock=False,
)
@defer.inlineCallbacks
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
def _count_users(txn):
txn.execute(
"SELECT COUNT(DISTINCT user_id) AS users"
" FROM user_ips"
" WHERE last_seen > ?",
# This is close enough to a day for our purposes.
(int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),)
)
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def get_user_ip_and_agents(self, user):
return self._simple_select_list(
table="user_ips",

View file

@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
from collections import namedtuple
import sys
import time
import threading
@ -376,9 +374,6 @@ class SQLBaseStore(object):
return self.runInteraction(desc, interaction)
def _execute_and_decode(self, desc, query, *args):
return self._execute(desc, self.cursor_to_dict, query, *args)
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
@ -691,37 +686,6 @@ class SQLBaseStore(object):
return dict(zip(retcols, row))
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
retcols=None, allow_none=False,
desc="_simple_selectupdate_one"):
""" Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
ret = self._simple_select_one_txn(
txn,
table=table,
keyvalues=keyvalues,
retcols=retcols,
allow_none=allow_none,
)
if updatevalues:
self._simple_update_one_txn(
txn,
table=table,
keyvalues=keyvalues,
updatevalues=updatevalues,
)
# if txn.rowcount == 0:
# raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret
return self.runInteraction(desc, func)
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@ -743,16 +707,6 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
@ -761,24 +715,6 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
max value for the column "id".
Args:
table : string giving the table name
"""
sql = "SELECT MAX(id) AS id FROM %s" % table
def func(txn):
txn.execute(sql)
max_id = self.cursor_to_dict(txn)[0]["id"]
if max_id is None:
return 0
return max_id
return self.runInteraction("_simple_max_id", func)
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
@ -791,129 +727,3 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
class Table(object):
""" A base class used to store information about a particular table.
"""
table_name = None
""" str: The name of the table """
fields = None
""" list: The field names """
EntryType = None
""" Type: A tuple type used to decode the results """
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
_insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
"""
Args:
where_clause (str): The WHERE clause to use.
Returns:
str: An SQL statement to select rows from the table with the given
WHERE clause.
"""
if where_clause:
return cls._select_where_clause % (
", ".join(cls.fields),
cls.table_name,
where_clause
)
else:
return cls._select_clause % (
", ".join(cls.fields),
cls.table_name,
)
@classmethod
def insert_statement(cls):
return cls._insert_clause % (
cls.table_name,
", ".join(cls.fields),
", ".join(["?"] * len(cls.fields)),
)
@classmethod
def decode_single_result(cls, results):
""" Given an iterable of tuples, return a single instance of
`EntryType` or None if the iterable is empty
Args:
results (list): The results list to convert to `EntryType`
Returns:
EntryType: An instance of `EntryType`
"""
results = list(results)
if results:
return cls.EntryType(*results[0])
else:
return None
@classmethod
def decode_results(cls, results):
""" Given an iterable of tuples, return a list of `EntryType`
Args:
results (list): The results list to convert to `EntryType`
Returns:
list: A list of `EntryType`
"""
return [cls.EntryType(*row) for row in results]
@classmethod
def get_fields_string(cls, prefix=None):
if prefix:
to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
else:
to_join = cls.fields
return ", ".join(to_join)
class JoinHelper(object):
""" Used to help do joins on tables by looking at the tables' fields and
creating a list of unique fields to use with SELECTs and a namedtuple
to dump the results into.
Attributes:
tables (list): List of `Table` classes
EntryType (type)
"""
def __init__(self, *tables):
self.tables = tables
res = []
for table in self.tables:
res += [f for f in table.fields if f not in res]
self.EntryType = namedtuple("JoinHelperEntry", res)
def get_fields(self, **prefixes):
"""Get a string representing a list of fields for use in SELECT
statements with the given prefixes applied to each.
For example::
JoinHelper(PdusTable, StateTable).get_fields(
PdusTable="pdus",
StateTable="state"
)
"""
res = []
for field in self.EntryType._fields:
for table in self.tables:
if field in table.fields:
res.append("%s.%s" % (prefixes[table.__name__], field))
break
return ", ".join(res)
def decode_results(self, rows):
return [self.EntryType(*row) for row in rows]

View file

@ -154,98 +154,6 @@ class EventFederationStore(SQLBaseStore):
return results
def _get_latest_state_in_room(self, txn, room_id, type, state_key):
event_ids = self._simple_select_onecol_txn(
txn,
table="state_forward_extremities",
keyvalues={
"room_id": room_id,
"type": type,
"state_key": state_key,
},
retcol="event_id",
)
results = []
for event_id in event_ids:
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((event_id, prev_hashes))
return results
def _get_prev_events(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=0,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_state(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=True,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_events_and_state(self, txn, event_id, is_state=None):
keyvalues = {
"event_id": event_id,
}
if is_state is not None:
keyvalues["is_state"] = bool(is_state)
res = self._simple_select_list_txn(
txn,
table="event_edges",
keyvalues=keyvalues,
retcols=["prev_event_id", "is_state"],
)
hashes = self._get_prev_event_hashes_txn(txn, event_id)
results = []
for d in res:
edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
edge_hash.update(hashes.get(d["prev_event_id"], {}))
prev_hashes = {
k: encode_base64(v)
for k, v in edge_hash.items()
if k == "sha256"
}
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
return results
def _get_auth_events(self, txn, event_id):
auth_ids = self._simple_select_onecol_txn(
txn,
table="event_auth",
keyvalues={
"event_id": event_id,
},
retcol="auth_id",
)
results = []
for auth_id in auth_ids:
hashes = self._get_event_reference_hashes_txn(txn, auth_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((auth_id, prev_hashes))
return results
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
"""
@ -303,6 +211,15 @@ class EventFederationStore(SQLBaseStore):
],
)
self._update_extremeties(txn, events)
def _update_extremeties(self, txn, events):
"""Updates the event_*_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are are now being persisted as non-outliers.
"""
events_by_room = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor
@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json
from contextlib import contextmanager
import logging
import math
import ujson as json
logger = logging.getLogger(__name__)
@ -281,6 +281,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,)
)
self._update_extremeties(txn, [event])
events_and_contexts = filter(
lambda ec: ec[0] not in to_remove,
events_and_contexts
@ -888,18 +890,69 @@ class EventsStore(SQLBaseStore):
return ev
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None
@defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""
def _count_messages(txn):
now = self.hs.get_clock().time()
txn.execute(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
last_reported = self.cursor_to_dict(txn)
txn.execute(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
now_reporting = self.cursor_to_dict(txn)
if not now_reporting:
return None
now_reporting = now_reporting[0]["stream_ordering"]
txn.execute("DELETE FROM stats_reporting")
txn.execute(
"INSERT INTO stats_reporting"
" (reported_stream_token, reported_time)"
" VALUES (?, ?)",
(now_reporting, now,)
)
if not last_reported:
return None
# Close enough to correct for our purposes.
yesterday = (now - 24 * 60 * 60)
if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60:
return None
txn.execute(
"SELECT COUNT(*) as messages"
" FROM events NATURAL JOIN event_json"
" WHERE json like '%m.room.message%'"
" AND stream_ordering > ?"
" AND stream_ordering <= ?",
(
last_reported[0]["reported_stream_token"],
now_reporting,
)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
return rows[0]["messages"]
ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table
from ._base import SQLBaseStore
from twisted.internet import defer
from synapse.api.errors import StoreError
@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore):
)
class PushersTable(Table):
class PushersTable(object):
table_name = "pushers"

View file

@ -289,3 +289,16 @@ class RegistrationStore(SQLBaseStore):
if ret:
defer.returnValue(ret['user_id'])
defer.returnValue(None)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)

View file

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
RoomsForUser = namedtuple(
"RoomsForUser",
("room_id", "sender", "membership")
("room_id", "sender", "membership", "event_id", "stream_ordering")
)
@ -141,11 +141,13 @@ class RoomMemberStore(SQLBaseStore):
args.extend(membership_list)
sql = (
"SELECT m.room_id, m.sender, m.membership"
" FROM room_memberships as m"
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM current_state_events as c"
" INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id"
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
@ -176,12 +178,6 @@ class RoomMemberStore(SQLBaseStore):
return joined_domains
def _get_members_query(self, where_clause, where_values):
return self.runInteraction(
"get_members_query", self._get_members_events_txn,
where_clause, where_values
).addCallbacks(self._get_events)
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn(
txn,

View file

@ -0,0 +1,22 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Should only ever contain one row
CREATE TABLE IF NOT EXISTS stats_reporting(
-- The stream ordering token which was most recently reported as stats
reported_stream_token INTEGER,
-- The time (seconds since epoch) stats were most recently reported
reported_time BIGINT
);

View file

@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
def _get_event_content_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given Event.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM event_content_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_content_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
"""Store a hash for a Event
Args:
txn (cursor):
event_id (str): Id for the Event.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
self._simple_insert_txn(
txn,
"event_content_hashes",
{
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
)
def get_event_reference_hashes(self, event_ids):
def f(txn):
return [
@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore):
table="event_reference_hashes",
values=vals,
)
def _get_event_signatures_txn(self, txn, event_id):
"""Get all the signatures for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of sig name -> dict(key_id -> signature_bytes)
"""
query = (
"SELECT signature_name, key_id, signature"
" FROM event_signatures"
" WHERE event_id = ? "
)
txn.execute(query, (event_id, ))
rows = txn.fetchall()
res = {}
for name, key, sig in rows:
res.setdefault(name, {})[key] = sig
return res
def _store_event_signature_txn(self, txn, event_id, signature_name, key_id,
signature_bytes):
"""Store a signature from the origin server for a PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
origin (str): origin of the Event.
key_id (str): Id for the signing key.
signature (bytes): The signature.
"""
self._simple_insert_txn(
txn,
"event_signatures",
{
"event_id": event_id,
"signature_name": signature_name,
"key_id": key_id,
"signature": buffer(signature_bytes),
},
)
def _get_prev_event_hashes_txn(self, txn, event_id):
"""Get all the hashes for previous PDUs of a PDU
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
"""
query = (
"SELECT prev_event_id, algorithm, hash"
" FROM event_edge_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
results = {}
for prev_event_id, algorithm, hash_bytes in txn.fetchall():
hashes = results.setdefault(prev_event_id, {})
hashes[algorithm] = hash_bytes
return results
def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
algorithm, hash_bytes):
self._simple_insert_txn(
txn,
"event_edge_hashes",
{
"event_id": event_id,
"prev_event_id": prev_event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
)

View file

@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import (
from twisted.internet import defer
from synapse.util.stringutils import random_string
import logging
logger = logging.getLogger(__name__)
@ -428,7 +426,3 @@ class StateStore(SQLBaseStore):
}
defer.returnValue(results)
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)

View file

@ -159,9 +159,7 @@ class StreamStore(SQLBaseStore):
@log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0, with_feedback=False):
# TODO (erikj): Handle compressed feedback
limit=0):
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
@ -227,10 +225,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1,
with_feedback=False):
# TODO (erikj): Handle compressed feedback
direction='b', limit=-1):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
@ -302,7 +297,6 @@ class StreamStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=4)
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
# TODO (erikj): Handle compressed feedback
end_token = RoomStreamToken.parse_stream_token(end_token)
@ -379,6 +373,38 @@ class StreamStore(SQLBaseStore):
)
defer.returnValue("t%d-%d" % (topo, token))
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "s%d" stream token.
"""
return self._simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
).addCallback(lambda row: "s%d" % (row,))
def get_topological_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "t%d-%d" topological token.
"""
return self._simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"

View file

@ -34,6 +34,11 @@ class SourcePaginationConfig(object):
self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) if limit is not None else None
def __repr__(self):
return (
"StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)"
) % (self.from_key, self.to_key, self.direction, self.limit)
class PaginationConfig(object):
@ -94,10 +99,10 @@ class PaginationConfig(object):
logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.")
def __str__(self):
def __repr__(self):
return (
"<PaginationConfig from_tok=%s, to_tok=%s, "
"direction=%s, limit=%s>"
"PaginationConfig(from_tok=%r, to_tok=%r,"
" direction=%r, limit=%r)"
) % (self.from_token, self.to_token, self.direction, self.limit)
def get_source_config(self, source_name):

View file

@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class EventSources(object):
SOURCE_TYPES = {
"room": RoomEventSource,
@ -70,15 +54,3 @@ class EventSources(object):
),
)
defer.returnValue(token)
class StreamSource(object):
def get_new_events_for_user(self, user, from_key, limit):
"""from_key is the key within this event source."""
raise NotImplementedError("get_new_events_for_user")
def get_current_key(self):
raise NotImplementedError("get_current_key")
def get_pagination_rows(self, user, pagination_config, key):
raise NotImplementedError("get_rows")

View file

@ -29,34 +29,6 @@ def unwrapFirstError(failure):
return failure.value.subFailure
def unwrap_deferred(d):
"""Given a deferred that we know has completed, return its value or raise
the failure as an exception
"""
if not d.called:
raise RuntimeError("deferred has not finished")
res = []
def f(r):
res.append(r)
return r
d.addCallback(f)
if res:
return res[0]
def f(r):
res.append(r)
return r
d.addErrback(f)
if res:
res[0].raiseException()
else:
raise RuntimeError("deferred did not call callbacks")
class Clock(object):
"""A small utility that obtains current time-of-day so that time may be
mocked during unit-tests.

View file

@ -19,17 +19,21 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.types import UserID
from tests.utils import setup_test_homeserver
import pymacaroons
class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.state_handler = Mock()
self.store = Mock()
self.hs = Mock()
self.hs = yield setup_test_homeserver(handlers=None)
self.hs.get_datastore = Mock(return_value=self.store)
self.hs.get_state_handler = Mock(return_value=self.state_handler)
self.auth = Auth(self.hs)
self.test_user = "@foo:bar"
@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self):
self.store.get_user_by_access_token = Mock(
return_value={"name": "@percy:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("User mismatch", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_missing_caveat(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("No user caveat", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_wrong_key(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key + "wrong")
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_unknown_caveat(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("cunning > fox")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_expired(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("time < 1") # ms
self.hs.clock.now = 5000 # seconds
yield self.auth._get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start
# throwing).
# with self.assertRaises(AuthError) as cm:
# yield self.auth._get_user_from_macaroon(macaroon.serialize())
# self.assertEqual(401, cm.exception.code)
# self.assertIn("Invalid macaroon", cm.exception.msg)

View file

@ -41,6 +41,22 @@ myid = "@apple:test"
PATH_PREFIX = "/_matrix/client/api/v1"
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class JustPresenceHandlers(object):
def __init__(self, hs):
self.presence_handler = PresenceHandler(hs)
@ -76,7 +92,7 @@ class PresenceStateTestCase(unittest.TestCase):
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
@ -169,7 +185,7 @@ class PresenceListTestCase(unittest.TestCase):
]
)
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource)
@ -243,7 +259,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# HIDEOUS HACKERY
# TODO(paul): This should be injected in via the HomeServer DI system
from synapse.streams.events import (
PresenceEventSource, NullSource, EventSources
PresenceEventSource, EventSources
)
old_SOURCE_TYPES = EventSources.SOURCE_TYPES

View file

@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@ -239,7 +239,7 @@ class RoomPermissionsTestCase(RestTestCase):
"PUT", topic_path, topic_content)
self.assertEquals(403, code, msg=str(response))
(code, response) = yield self.mock_resource.trigger_get(topic_path)
self.assertEquals(403, code, msg=str(response))
self.assertEquals(200, code, msg=str(response))
# get topic in PUBLIC room, not joined, expect 403
(code, response) = yield self.mock_resource.trigger_get(
@ -301,11 +301,11 @@ class RoomPermissionsTestCase(RestTestCase):
room=room, expect_code=200)
# get membership of self, get membership of other, private room + left
# expect all 403s
# expect all 200s
yield self.leave(room=room, user=self.user_id)
yield self._test_get_membership(
members=[self.user_id, self.rmcreator_id],
room=room, expect_code=403)
room=room, expect_code=200)
@defer.inlineCallbacks
def test_membership_public_room_perms(self):
@ -326,11 +326,11 @@ class RoomPermissionsTestCase(RestTestCase):
room=room, expect_code=200)
# get membership of self, get membership of other, public room + left
# expect 403.
# expect 200.
yield self.leave(room=room, user=self.user_id)
yield self._test_get_membership(
members=[self.user_id, self.rmcreator_id],
room=room, expect_code=403)
room=room, expect_code=200)
@defer.inlineCallbacks
def test_invited_permissions(self):
@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@ -492,9 +492,9 @@ class RoomsMemberListTestCase(RestTestCase):
self.assertEquals(200, code, msg=str(response))
yield self.leave(room=room_id, user=self.user_id)
# can no longer see list, you've left.
# can see old list once left
(code, response) = yield self.mock_resource.trigger_get(room_path)
self.assertEquals(403, code, msg=str(response))
self.assertEquals(200, code, msg=str(response))
class RoomsCreateTestCase(RestTestCase):
@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)

View file

@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase):
"token_id": 1,
}
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)

View file

@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase):
self.mock_resource = None
self.auth_user_id = None
def mock_get_user_by_access_token(self, token=None):
return self.auth_user_id
@defer.inlineCallbacks
def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id

View file

@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
}
hs.get_auth().get_user_by_access_token = _get_user_by_access_token
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)

View file

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID
from tests.utils import setup_test_homeserver
from mock import Mock
class EventInjector:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.message_handler = hs.get_handlers().message_handler
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def create_room(self, room):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
"room_id": room.to_string(),
"content": {},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)

View file

@ -185,26 +185,6 @@ class SQLBaseStoreTestCase(unittest.TestCase):
[3, 4, 1, 2]
)
@defer.inlineCallbacks
def test_update_one_with_return(self):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Old Value",)
ret = yield self.datastore._simple_selectupdate_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columname": "New Value"},
retcols=["columname"]
)
self.assertEquals({"columname": "Old Value"}, ret)
self.mock_txn.execute.assert_has_calls([
call('SELECT columname FROM tablename WHERE keycol = ?',
['TheKey']),
call("UPDATE tablename SET columname = ? WHERE keycol = ?",
["New Value", "TheKey"])
])
@defer.inlineCallbacks
def test_delete_one(self):
self.mock_txn.rowcount = 1

View file

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from mock.mock import Mock
from synapse.types import RoomID, UserID
from tests import unittest
from twisted.internet import defer
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver
class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
resource_for_federation=Mock(),
http_client=None,
)
self.store = self.hs.get_datastore()
self.db_pool = self.hs.get_db_pool()
self.message_handler = self.hs.get_handlers().message_handler
self.event_injector = EventInjector(self.hs)
@defer.inlineCallbacks
def test_count_daily_messages(self):
self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100
# Never reported before, and nothing which could be reported
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting")
self.assertEqual([(0,)], count)
# Create something to report
room = RoomID.from_string("!abc123:test")
user = UserID.from_string("@raccoonlover:test")
yield self.event_injector.create_room(room)
self.base_event = yield self._get_last_stream_token()
yield self.event_injector.inject_message(room, user, "Raccoons are really cute")
# Never reported before, something could be reported, but isn't because
# it isn't old enough.
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!")
yield self.event_injector.inject_message(room, user, "Incredibly!")
self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday
self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
# A little over 24 hours is fine :)
self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages()
self.assertEqual(3, count)
self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks
def _get_last_stream_token(self):
rows = yield self.db_pool.runQuery(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _assert_stats_reporting(self, messages, time):
rows = yield self.db_pool.runQuery(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
self.assertEqual([(self.base_event + messages, time,)], rows)

View file

@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory();
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")

View file

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver
@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_injector = EventInjector(hs)
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
self.room2 = RoomID.from_string("!xyx987:test")
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def test_event_stream_get_other(self):
# Both bob and alice joins the room
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test")
yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_event_stream_get_own(self):
# Both bob and alice joins the room
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test")
yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_event_stream_join_leave(self):
# Both bob and alice joins the room
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
# Then bob leaves again.
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.LEAVE
)
# Initial stream key:
start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test")
yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id()
@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_event_stream_prev_content(self):
yield self.inject_room_member(
yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN
)
event1 = yield self.inject_room_member(
event1 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN
)
start = yield self.store.get_room_events_max_id()
event2 = yield self.inject_room_member(
event2 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN,
)