Merge branch 'develop' into application-services-txn-reliability

Conflicts:
	synapse/storage/__init__.py
This commit is contained in:
Kegan Dougal 2015-03-26 10:07:59 +00:00
commit 4edcbcee3b
45 changed files with 1251 additions and 731 deletions

View file

@ -1,3 +1,12 @@
Changes in synapse v0.8.1 (2015-03-18)
======================================
* Disable registration by default. New users can be added using the command
``register_new_matrix_user`` or by enabling registration in the config.
* Add metrics to synapse. To enable metrics use config options
``enable_metrics`` and ``metrics_port``.
* Fix bug where banning only kicked the user.
Changes in synapse v0.8.0 (2015-03-06)
======================================

View file

@ -128,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
Substituting your host and domain name as appropriate.
By default, registration of new users is disabled. You can either enable
registration in the config (it is then recommended to also set up CAPTCHA), or
you can use the command line to register new users::
$ source ~/.synapse/bin/activate
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448
New user localpart: erikj
Password:
Confirm password:
Success!
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details.

149
register_new_matrix_user Executable file
View file

@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import getpass
import hashlib
import hmac
import json
import sys
import urllib2
import yaml
def request_registration(user, password, server_location, shared_secret):
mac = hmac.new(
key=shared_secret,
msg=user,
digestmod=hashlib.sha1,
).hexdigest()
data = {
"user": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
}
server_location = server_location.rstrip("/")
print "Sending registration request..."
req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)
try:
f = urllib2.urlopen(req)
f.read()
f.close()
print "Success."
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
sys.exit(1)
def register_new_user(user, password, server_location, shared_secret):
if not user:
try:
default_user = getpass.getuser()
except:
default_user = None
if default_user:
user = raw_input("New user localpart [%s]: " % (default_user,))
if not user:
user = default_user
else:
user = raw_input("New user localpart: ")
if not user:
print "Invalid user name"
sys.exit(1)
if not password:
password = getpass.getpass("Password: ")
if not password:
print "Password cannot be blank."
sys.exit(1)
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
print "Passwords do not match"
sys.exit(1)
request_registration(user, password, server_location, shared_secret)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Used to register new users with a given home server when"
" registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option"
" set.",
)
parser.add_argument(
"-u", "--user",
default=None,
help="Local part of the new user. Will prompt if omitted.",
)
parser.add_argument(
"-p", "--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-c", "--config",
type=argparse.FileType('r'),
help="Path to server config file. Used to read in shared secret.",
)
group.add_argument(
"-k", "--shared-secret",
help="Shared secret as defined in server config file.",
)
parser.add_argument(
"server_url",
default="https://localhost:8448",
nargs='?',
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
)
args = parser.parse_args()
if "config" in args and args.config:
config = yaml.safe_load(args.config)
secret = config.get("registration_shared_secret", None)
if not secret:
print "No 'registration_shared_secret' defined in config."
sys.exit(1)
else:
secret = args.shared_secret
register_new_user(args.user, args.password, args.server_url, secret)

View file

@ -45,7 +45,7 @@ setup(
version=version,
packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server",
install_requires=dependencies["REQUIREMENTS"].keys(),
install_requires=dependencies['requirements'](include_conditional=True).keys(),
setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial",
@ -55,5 +55,5 @@ setup(
include_package_data=True,
zip_safe=False,
long_description=long_description,
scripts=["synctl"],
scripts=["synctl", "register_new_matrix_user"],
)

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.8.0"
__version__ = "0.8.1-r2"

View file

@ -28,6 +28,12 @@ import logging
logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
class Auth(object):
def __init__(self, hs):
@ -166,6 +172,7 @@ class Auth(object):
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
@ -194,6 +201,7 @@ class Auth(object):
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
@ -202,6 +210,11 @@ class Auth(object):
}
)
if ban_level:
ban_level = int(ban_level)
else:
ban_level = 50 # FIXME (erikj): What should we do here?
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@ -212,6 +225,10 @@ class Auth(object):
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
elif target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
@ -221,6 +238,8 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
@ -238,6 +257,10 @@ class Auth(object):
403,
"%s not in room %s." % (target_user_id, event.room_id,)
)
elif target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
if kick_level:
kick_level = int(kick_level)
@ -249,11 +272,6 @@ class Auth(object):
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if ban_level:
ban_level = int(ban_level)
else:
ban_level = 50 # FIXME (erikj): What should we do here?
if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban")
else:
@ -370,7 +388,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
ret = yield self.store.get_user_by_token(token=token)
ret = yield self.store.get_user_by_token(token)
if not ret:
raise StoreError(400, "Unknown token")
user_info = {
@ -412,12 +430,6 @@ class Auth(object):
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
return []

View file

@ -60,6 +60,7 @@ class LoginType(object):
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret"
class EventTypes(object):

View file

@ -60,7 +60,6 @@ import re
import resource
import subprocess
import sqlite3
import syweb
logger = logging.getLogger(__name__)
@ -83,6 +82,7 @@ class SynapseHomeServer(HomeServer):
return AppServiceRestResource(self)
def build_resource_for_web_client(self):
import syweb
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable?
@ -130,7 +130,7 @@ class SynapseHomeServer(HomeServer):
True.
"""
config = self.get_config()
web_client = config.webclient
web_client = config.web_client
# list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
@ -343,7 +343,8 @@ def setup(config_options):
config.setup_logging()
check_requirements()
# check any extra requirements we have now we have a config
check_requirements(config)
version_string = get_version_string()
@ -450,6 +451,7 @@ def run(hs):
def main():
with LoggingContext("main"):
# check base requirements
check_requirements()
hs = setup(sys.argv[1:])
run(hs)

View file

@ -15,19 +15,46 @@
from ._base import Config
from synapse.util.stringutils import random_string_with_symbols
import distutils.util
class RegistrationConfig(Config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
self.disable_registration = args.disable_registration
# `args.disable_registration` may either be a bool or a string depending
# on if the option was given a value (e.g. --disable-registration=false
# would set `args.disable_registration` to "false" not False.)
self.disable_registration = bool(
distutils.util.strtobool(str(args.disable_registration))
)
self.registration_shared_secret = args.registration_shared_secret
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--disable-registration",
action='store_true',
help="Disable registration of new users."
const=True,
default=True,
nargs='?',
help="Disable registration of new users.",
)
reg_group.add_argument(
"--registration-shared-secret", type=str,
help="If set, allows registration by anyone who also has the shared"
" secret, even if registration is otherwise disabled.",
)
@classmethod
def generate_config(cls, args, config_dir_path):
if args.disable_registration is None:
args.disable_registration = True
if args.registration_shared_secret is None:
args.registration_shared_secret = random_string_with_symbols(50)

View file

@ -28,7 +28,7 @@ class ServerConfig(Config):
self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file)
self.webclient = True
self.web_client = args.web_client
self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit
@ -68,6 +68,8 @@ class ServerConfig(Config):
server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to"
" store the pid in")
server_group.add_argument('--web_client', default=True, type=bool,
help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"

View file

@ -16,8 +16,7 @@
class EventContext(object):
def __init__(self, current_state=None, auth_events=None):
def __init__(self, current_state=None):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None
self.rejected = False

View file

@ -361,4 +361,5 @@ SERVLET_CLASSES = (
FederationInviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
)

View file

@ -90,8 +90,8 @@ class BaseHandler(object):
event = builder.build()
logger.debug(
"Created event %s with auth_events: %s, current state: %s",
event.event_id, context.auth_events, context.current_state,
"Created event %s with current state: %s",
event.event_id, context.current_state,
)
defer.returnValue(
@ -106,7 +106,7 @@ class BaseHandler(object):
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
self.auth.check(event, auth_events=context.auth_events)
self.auth.check(event, auth_events=context.current_state)
yield self.store.persist_event(event, context=context)
@ -142,7 +142,16 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id
)
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
# Don't block waiting on waking up all the listeners.
d = self.notifier.on_new_room_event(event, extra_users=extra_users)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
yield federation_handler.handle_new_event(
event, destinations=destinations,

View file

@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
yield self.store.clean_room_for_join(room_id)
origin, pdu = yield self.replication_layer.make_join(
target_hosts,
room_id,
@ -464,11 +466,9 @@ class FederationHandler(BaseHandler):
builder=builder,
)
self.auth.check(event, auth_events=context.auth_events)
self.auth.check(event, auth_events=context.current_state)
pdu = event
defer.returnValue(pdu)
defer.returnValue(event)
@defer.inlineCallbacks
@log_function
@ -705,7 +705,7 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
auth_events = context.auth_events
auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",

View file

@ -33,6 +33,10 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds
LAST_ACTIVE_GRANULARITY = 60*1000
# TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func):
"""Partition the list by the result of func applied to each element."""
@ -282,6 +286,10 @@ class PresenceHandler(BaseHandler):
if now is None:
now = self.clock.time_msec()
prev_state = self._get_or_make_usercache(user)
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return
self.changed_presencelike_data(user, {"last_active": now})
def changed_presencelike_data(self, user, state):

View file

@ -31,6 +31,7 @@ import base64
import bcrypt
import json
import logging
import urllib
logger = logging.getLogger(__name__)
@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart:
if localpart and urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()

View file

@ -46,6 +46,11 @@ outgoing_responses_counter = metrics.register_counter(
labels=["method", "code"],
)
response_timer = metrics.register_distribution(
"response_time",
labels=["method", "servlet"]
)
class HttpServer(object):
""" Interface for registering callbacks on a HTTP server
@ -169,6 +174,10 @@ class JsonResource(HttpServer, resource.Resource):
code, response = yield callback(request, *args)
self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
)
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.

View file

@ -51,8 +51,8 @@ class RestServlet(object):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)):
method_handler = getattr(self, "on_%s" % (method))
if hasattr(self, "on_%s" % (method,)):
method_handler = getattr(self, "on_%s" % (method,))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")

View file

@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"],
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -18,6 +17,19 @@ REQUIREMENTS = {
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
}
}
def requirements(config=None, include_conditional=False):
reqs = REQUIREMENTS.copy()
for key, req in CONDITIONAL_REQUIREMENTS.items():
if (config and getattr(config, key)) or include_conditional:
reqs.update(req)
return reqs
def github_link(project, version, egg):
@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
pass
def check_requirements():
def check_requirements(config=None):
"""Checks that all the modules needed by synapse have been correctly
installed and are at the correct version"""
for dependency, module_requirements in REQUIREMENTS.items():
for dependency, module_requirements in (
requirements(config, include_conditional=False).items()):
for module_requirement in module_requirements:
if ">=" in module_requirement:
module_name, required_version = module_requirement.split(">=")
@ -110,7 +123,7 @@ def list_requirements():
egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0])
result.append(link)
for requirement in REQUIREMENTS:
for requirement in requirements(include_conditional=True):
is_linked = False
for link in linked:
if requirement.replace('-', '_').startswith(link):

View file

@ -27,7 +27,6 @@ from hashlib import sha1
import hmac
import simplejson as json
import logging
import urllib
logger = logging.getLogger(__name__)
@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE
if self.disable_registration and not is_application_server:
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = (
not self.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
stages = {
LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service
LoginType.APPLICATION_SERVICE: self._do_app_service,
LoginType.SHARED_SECRET: self._do_shared_secret,
}
session_info = self._get_session_info(request, session)
@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
)
password = register_json["password"].encode("utf-8")
desired_user_id = (register_json["user"].encode("utf-8")
if "user" in register_json else None)
if (desired_user_id
and urllib.quote(desired_user_id) != desired_user_id):
raise SynapseError(
400,
"User ID must only contain characters which do not " +
"require URL encoding.")
desired_user_id = (
register_json["user"].encode("utf-8")
if "user" in register_json else None
)
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
})
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
yield run_on_reactor()
if not isinstance(register_json.get("mac", None), basestring):
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), basestring):
raise SynapseError(400, "Expected 'user' key.")
if not isinstance(register_json.get("password", None), basestring):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(register_json["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1,
).hexdigest()
password = register_json["password"].encode("utf-8")
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user,
password=password,
)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
else:
raise SynapseError(
403, "HMAC incorrect",
)
def _parse_json(request):
try:

View file

@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from collections import namedtuple
@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
SIZE_OF_CACHE = 1000
EVICTION_TIMEOUT_SECONDS = 20
@ -139,18 +134,6 @@ class StateHandler(object):
}
context.state_group = None
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
@ -187,18 +170,6 @@ class StateHandler(object):
replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
context.prev_state_events = prev_state
defer.returnValue(context)

View file

@ -14,15 +14,12 @@
# limitations under the License.
from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
from ._base import Cache
from .directory import DirectoryStore
from .feedback import FeedbackStore
from .events import EventsStore
from .presence import PresenceStore
from .profile import ProfileStore
from .registration import RegistrationStore
@ -41,11 +38,6 @@ from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import compute_event_reference_hash
import fnmatch
import imp
@ -63,16 +55,14 @@ SCHEMA_VERSION = 15
dir_path = os.path.abspath(os.path.dirname(__file__))
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
"""
pass
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes
LAST_SEEN_GRANULARITY = 120*1000
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
@ -83,6 +73,7 @@ class DataStore(RoomMemberStore, RoomStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
EventsStore,
):
def __init__(self, hs):
@ -92,424 +83,28 @@ class DataStore(RoomMemberStore, RoomStore,
self.min_token_deferred = self._get_min_token()
self.min_token = None
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False,
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
self.min_token -= 1
stream_ordering = self.min_token
try:
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
except _RollbackButIsFineException:
pass
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
Returns:
Deferred : A FrozenEvent.
"""
event = yield self.runInteraction(
"get_event", self._get_event_txn,
event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not event and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(event)
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True,
current_state=None):
# Remove the any existing cache entries for the event_id
self._get_event_cache.pop(event.event_id)
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
if event.is_state() and is_new_state:
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
outlier = event.internal_metadata.is_outlier()
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
have_persisted = self._simple_select_one_onecol_txn(
txn,
table="event_json",
keyvalues={"event_id": event.event_id},
retcol="event_id",
allow_none=True,
)
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
# If we have already persisted this event, we don't need to do any
# more processing.
# The processing above must be done on every call to persist event,
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
if not outlier:
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
)
txn.execute(
sql,
(metadata_json.decode("UTF-8"), event.event_id,)
)
sql = (
"UPDATE events SET outlier = 0"
" WHERE event_id = ?"
)
txn.execute(
sql,
(event.event_id,)
)
return
if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event)
elif event.type == EventTypes.Feedback:
self._store_feedback_txn(txn, event)
elif event.type == EventTypes.Name:
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
event_dict = {
k: v
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
]
}
self._simple_insert_txn(
txn,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json.decode("UTF-8"),
"json": encode_canonical_json(event_dict).decode("UTF-8"),
},
or_replace=True,
)
content = encode_canonical_json(
event.content
).decode("UTF-8")
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
"type": event.type,
"room_id": event.room_id,
"content": content,
"processed": True,
"outlier": outlier,
"depth": event.depth,
}
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
unrec = {
k: v
for k, v in event.get_dict().items()
if k not in vals.keys() and k not in [
"redacted",
"redacted_because",
"signatures",
"hashes",
"prev_events",
]
}
vals["unrecognized_keys"] = encode_canonical_json(
unrec
).decode("UTF-8")
try:
self._simple_insert_txn(
txn,
"events",
vals,
or_replace=(not outlier),
or_ignore=bool(outlier),
)
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
event.event_id,
exc_info=True,
)
raise _RollbackButIsFineException("_persist_event")
if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected)
if event.is_state():
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
self._simple_insert_txn(
txn,
"state_events",
vals,
or_replace=True,
)
if is_new_state and not context.rejected:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": 1,
},
or_ignore=True,
)
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
txn, event.event_id, hash_alg, hash_bytes,
)
for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg, hash_bytes
)
for auth_id, _ in event.auth_events:
self._simple_insert_txn(
txn,
table="event_auth",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
},
or_ignore=True,
)
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
self._store_event_reference_hash_txn(
txn, event.event_id, ref_alg, ref_hash_bytes
)
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
self._get_event_cache.pop(event.redacts)
txn.execute(
"INSERT OR IGNORE INTO redactions "
"(event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
del_sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
"LIMIT 1"
)
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
if event_type and state_key is not None:
sql += " AND s.type = ? AND s.state_key = ? "
args = (room_id, event_type, state_key)
elif event_type:
sql += " AND s.type = ?"
args = (room_id, event_type)
else:
args = (room_id, )
results = yield self._execute_and_decode("get_current_state", sql, *args)
events = yield self._parse_events(results)
defer.returnValue(events)
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id):
del_sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
"LIMIT 1"
)
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
sql += " OR s.type = 'm.room.aliases')"
args = (room_id,)
results = yield self._execute_and_decode("get_current_state", sql, *args)
events = yield self._parse_events(results)
name = None
aliases = []
for e in events:
if e.type == 'm.room.name':
if 'name' in e.content:
name = e.content['name']
elif e.type == 'm.room.aliases':
if 'aliases' in e.content:
aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases))
@defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
)
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
self.min_token = min(self.min_token, -1)
logger.debug("min_token is: %s", self.min_token)
defer.returnValue(self.min_token)
def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
return self._simple_insert(
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, device_id, ip)
try:
last_seen = self.client_ip_last_seen.get(*key)
except KeyError:
last_seen = None
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,))
yield self._simple_insert(
"user_ips",
{
"user": user.to_string(),
@ -517,8 +112,9 @@ class DataStore(RoomMemberStore, RoomStore,
"device_id": device_id,
"ip": ip,
"user_agent": user_agent,
"last_seen": int(self._clock.time_msec()),
}
"last_seen": now,
},
desc="insert_client_ip",
)
def get_user_ip_and_agents(self, user):
@ -528,38 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
retcols=[
"device_id", "access_token", "ip", "user_agent", "last_seen"
],
)
def have_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Returns:
dict: Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps to
None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction(
"have_events", f,
desc="get_user_ip_and_agents",
)

View file

@ -25,6 +25,7 @@ import synapse.metrics
from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
import simplejson as json
import sys
import time
@ -38,6 +39,8 @@ transaction_logger = logging.getLogger("synapse.storage.txn")
metrics = synapse.metrics.get_metrics_for("synapse.storage")
sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
@ -50,14 +53,57 @@ cache_counter = metrics.register_cache(
)
# TODO(paul):
# * more generic key management
# * consider other eviction strategies - LRU?
def cached(max_entries=1000):
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
caches_by_name[name] = self.cache
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if keyargs in self.cache:
cache_counter.inc_hits(self.name)
return self.cache[keyargs]
cache_counter.inc_misses(self.name)
raise KeyError()
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value
def invalidate(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
self.cache.pop(keyargs, None)
def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take one additional argument, which is used as
the key for the cache. Cache hits are served directly from the cache;
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
@ -68,33 +114,27 @@ def cached(max_entries=1000):
calling the calculation function.
"""
def wrap(orig):
cache = OrderedDict()
name = orig.__name__
caches_by_name[name] = cache
def prefill(key, value):
while len(cache) > max_entries:
cache.popitem(last=False)
cache[key] = value
cache = Cache(
name=orig.__name__,
max_entries=max_entries,
keylen=num_args,
lru=lru,
)
@functools.wraps(orig)
@defer.inlineCallbacks
def wrapped(self, key):
if key in cache:
cache_counter.inc_hits(name)
defer.returnValue(cache[key])
def wrapped(self, *keyargs):
try:
defer.returnValue(cache.get(*keyargs))
except KeyError:
ret = yield orig(self, *keyargs)
cache.prefill(*keyargs + (ret,))
cache_counter.inc_misses(name)
ret = yield orig(self, key)
prefill(key, ret)
defer.returnValue(ret)
def invalidate(key):
cache.pop(key, None)
wrapped.invalidate = invalidate
wrapped.prefill = prefill
wrapped.invalidate = cache.invalidate
wrapped.prefill = cache.prefill
return wrapped
return wrap
@ -240,6 +280,8 @@ class SQLBaseStore(object):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(txn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
current_context.copy_to(context)
@ -252,6 +294,7 @@ class SQLBaseStore(object):
name = "%s-%x" % (desc, txn_id, )
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
return func(LoggingTransaction(txn, name), *args, **kwargs)
@ -314,7 +357,8 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@ -323,7 +367,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
"_simple_insert",
desc,
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@ -347,7 +391,7 @@ class SQLBaseStore(object):
txn.execute(sql, values.values())
return txn.lastrowid
def _simple_upsert(self, table, keyvalues, values):
def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
"""
Args:
table (str): The table to upsert into
@ -356,7 +400,7 @@ class SQLBaseStore(object):
Returns: A deferred
"""
return self.runInteraction(
"_simple_upsert",
desc,
self._simple_upsert_txn, table, keyvalues, values
)
@ -392,7 +436,7 @@ class SQLBaseStore(object):
txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False):
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@ -404,12 +448,15 @@ class SQLBaseStore(object):
allow_none : If true, return None instead of failing if the SELECT
statement returns no rows
"""
return self._simple_selectupdate_one(
table, keyvalues, retcols=retcols, allow_none=allow_none
return self.runInteraction(
desc,
self._simple_select_one_txn,
table, keyvalues, retcols, allow_none,
)
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
allow_none=False,
desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it."
@ -419,7 +466,7 @@ class SQLBaseStore(object):
retcol : string giving the name of the column to return
"""
return self.runInteraction(
"_simple_select_one_onecol",
desc,
self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none,
)
@ -455,7 +502,8 @@ class SQLBaseStore(object):
return [r[0] for r in txn.fetchall()]
def _simple_select_onecol(self, table, keyvalues, retcol):
def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@ -468,12 +516,13 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
"_simple_select_onecol",
desc,
self._simple_select_onecol_txn,
table, keyvalues, retcol
)
def _simple_select_list(self, table, keyvalues, retcols):
def _simple_select_list(self, table, keyvalues, retcols,
desc="_simple_select_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -484,7 +533,7 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
"_simple_select_list",
desc,
self._simple_select_list_txn,
table, keyvalues, retcols
)
@ -516,7 +565,7 @@ class SQLBaseStore(object):
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@ -534,42 +583,19 @@ class SQLBaseStore(object):
get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well.
"""
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
retcols=retcols)
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
retcols=None, allow_none=False):
""" Combined SELECT then UPDATE."""
if retcols:
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
return self.runInteraction(
desc,
self._simple_update_one_txn,
table, keyvalues, updatevalues,
)
if updatevalues:
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
def func(txn):
ret = None
if retcols:
txn.execute(select_sql, keyvalues.values())
row = txn.fetchone()
if not row:
if allow_none:
return None
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
ret = dict(zip(retcols, row))
if updatevalues:
txn.execute(
update_sql,
updatevalues.values() + keyvalues.values()
@ -580,10 +606,53 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret
return self.runInteraction("_simple_selectupdate_one", func)
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
)
def _simple_delete_one(self, table, keyvalues):
txn.execute(select_sql, keyvalues.values())
row = txn.fetchone()
if not row:
if allow_none:
return None
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return dict(zip(retcols, row))
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
retcols=None, allow_none=False,
desc="_simple_selectupdate_one"):
""" Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
ret = self._simple_select_one_txn(
txn,
table=table,
keyvalues=keyvalues,
retcols=retcols,
allow_none=allow_none,
)
if updatevalues:
self._simple_update_one_txn(
txn,
table=table,
keyvalues=keyvalues,
updatevalues=updatevalues,
)
return ret
return self.runInteraction(desc, func)
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@ -602,9 +671,9 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
return self.runInteraction("_simple_delete_one", func)
return self.runInteraction(desc, func)
def _simple_delete(self, table, keyvalues):
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table.
Args:
@ -612,7 +681,7 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction("_simple_delete", self._simple_delete_txn)
return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
@ -782,6 +851,13 @@ class SQLBaseStore(object):
return result[0] if result else None
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
"""
pass
class Table(object):
""" A base class used to store information about a particular table.
"""

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from ._base import SQLBaseStore, cached
from synapse.api.errors import SynapseError
@ -48,6 +48,7 @@ class DirectoryStore(SQLBaseStore):
{"room_alias": room_alias.to_string()},
"room_id",
allow_none=True,
desc="get_association_from_room_alias",
)
if not room_id:
@ -58,6 +59,7 @@ class DirectoryStore(SQLBaseStore):
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
desc="get_association_from_room_alias",
)
if not servers:
@ -87,6 +89,7 @@ class DirectoryStore(SQLBaseStore):
"room_alias": room_alias.to_string(),
"room_id": room_id,
},
desc="create_room_alias_association",
)
except sqlite3.IntegrityError:
raise SynapseError(
@ -100,16 +103,22 @@ class DirectoryStore(SQLBaseStore):
{
"room_alias": room_alias.to_string(),
"server": server,
}
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate(room_id)
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
return self.runInteraction(
room_id = yield self.runInteraction(
"delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)
self.get_aliases_for_room.invalidate(room_id)
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):
cursor = txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
@ -134,9 +143,11 @@ class DirectoryStore(SQLBaseStore):
return room_id
@cached()
def get_aliases_for_room(self, room_id):
return self._simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
desc="get_aliases_for_room",
)

View file

@ -429,3 +429,15 @@ class EventFederationStore(SQLBaseStore):
)
return events[:limit]
def clean_room_for_join(self, room_id):
return self.runInteraction(
"clean_room_for_join",
self._clean_room_for_join_txn,
room_id,
)
def _clean_room_for_join_txn(self, txn, room_id):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))

395
synapse/storage/events.py Normal file
View file

@ -0,0 +1,395 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
import logging
logger = logging.getLogger(__name__)
class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False,
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
self.min_token -= 1
stream_ordering = self.min_token
try:
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
self.get_room_events_max_id.invalidate()
except _RollbackButIsFineException:
pass
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
Returns:
Deferred : A FrozenEvent.
"""
event = yield self.runInteraction(
"get_event", self._get_event_txn,
event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not event and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(event)
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True,
current_state=None):
# Remove the any existing cache entries for the event_id
self._get_event_cache.pop(event.event_id)
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
if event.is_state() and is_new_state:
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
outlier = event.internal_metadata.is_outlier()
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
have_persisted = self._simple_select_one_onecol_txn(
txn,
table="event_json",
keyvalues={"event_id": event.event_id},
retcol="event_id",
allow_none=True,
)
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
# If we have already persisted this event, we don't need to do any
# more processing.
# The processing above must be done on every call to persist event,
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
if not outlier:
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
)
txn.execute(
sql,
(metadata_json.decode("UTF-8"), event.event_id,)
)
sql = (
"UPDATE events SET outlier = 0"
" WHERE event_id = ?"
)
txn.execute(
sql,
(event.event_id,)
)
return
if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event)
elif event.type == EventTypes.Feedback:
self._store_feedback_txn(txn, event)
elif event.type == EventTypes.Name:
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
event_dict = {
k: v
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
]
}
self._simple_insert_txn(
txn,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json.decode("UTF-8"),
"json": encode_canonical_json(event_dict).decode("UTF-8"),
},
or_replace=True,
)
content = encode_canonical_json(
event.content
).decode("UTF-8")
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
"type": event.type,
"room_id": event.room_id,
"content": content,
"processed": True,
"outlier": outlier,
"depth": event.depth,
}
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
unrec = {
k: v
for k, v in event.get_dict().items()
if k not in vals.keys() and k not in [
"redacted",
"redacted_because",
"signatures",
"hashes",
"prev_events",
]
}
vals["unrecognized_keys"] = encode_canonical_json(
unrec
).decode("UTF-8")
try:
self._simple_insert_txn(
txn,
"events",
vals,
or_replace=(not outlier),
or_ignore=bool(outlier),
)
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
event.event_id,
exc_info=True,
)
raise _RollbackButIsFineException("_persist_event")
if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected)
if event.is_state():
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
self._simple_insert_txn(
txn,
"state_events",
vals,
)
if is_new_state and not context.rejected:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
)
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": 1,
},
)
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
txn, event.event_id, hash_alg, hash_bytes,
)
for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg, hash_bytes
)
for auth_id, _ in event.auth_events:
self._simple_insert_txn(
txn,
table="event_auth",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
},
)
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
self._store_event_reference_hash_txn(
txn, event.event_id, ref_alg, ref_hash_bytes
)
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
self._get_event_cache.pop(event.redacts)
txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts)
)
def have_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Returns:
dict: Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps to
None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction(
"have_events", f,
)

View file

@ -1,47 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore
class FeedbackStore(SQLBaseStore):
def _store_feedback_txn(self, txn, event):
self._simple_insert_txn(txn, "feedback", {
"event_id": event.event_id,
"feedback_type": event.content["type"],
"room_id": event.room_id,
"target_event_id": event.content["target_event_id"],
"sender": event.user_id,
})
@defer.inlineCallbacks
def get_feedback_for_event(self, event_id):
sql = (
"SELECT events.* FROM events INNER JOIN feedback "
"ON events.event_id = feedback.event_id "
"WHERE feedback.target_event_id = ? "
)
rows = yield self._execute_and_decode("get_feedback_for_event", sql, event_id)
defer.returnValue(
[
(yield self._parse_events(r))
for r in rows
]
)

View file

@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
)
defer.returnValue(json.loads(def_json))

View file

@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
{"media_id": media_id},
("media_type", "media_length", "upload_name", "created_ts"),
allow_none=True,
desc="get_local_media",
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
"upload_name": upload_name,
"media_length": media_length,
"user_id": user_id.to_string(),
}
},
desc="store_local_media",
)
def get_local_media_thumbnails(self, media_id):
@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length",
)
),
desc="get_local_media_thumbnails",
)
def store_local_thumbnail(self, media_id, thumbnail_width,
@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
}
},
desc="store_local_thumbnail",
)
def get_cached_remote_media(self, origin, media_id):
@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
"filesystem_id",
),
allow_none=True,
desc="get_cached_remote_media",
)
def store_cached_remote_media(self, origin, media_id, media_type,
@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
}
},
desc="store_cached_remote_media",
)
def get_remote_media_thumbnails(self, origin, media_id):
@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length", "filesystem_id",
)
),
desc="get_remote_media_thumbnails",
)
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
"filesystem_id": filesystem_id,
}
},
desc="store_remote_media_thumbnail",
)

View file

@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
return self._simple_insert(
table="presence",
values={"user_id": user_localpart},
desc="create_presence",
)
def has_presence_state(self, user_localpart):
@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": user_localpart},
retcols=["user_id"],
allow_none=True,
desc="has_presence_state",
)
def get_presence_state(self, user_localpart):
@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
desc="get_presence_state",
)
def set_presence_state(self, user_localpart, new_state):
@ -45,7 +48,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()},
retcols=["state"],
desc="set_presence_state",
)
def allow_presence_visible(self, observed_localpart, observer_userid):
@ -53,6 +56,7 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound",
values={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
desc="allow_presence_visible",
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
@ -60,6 +64,7 @@ class PresenceStore(SQLBaseStore):
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
desc="disallow_presence_visible",
)
def is_presence_visible(self, observed_localpart, observer_userid):
@ -69,6 +74,7 @@ class PresenceStore(SQLBaseStore):
"observer_user_id": observer_userid},
retcols=["observed_user_id"],
allow_none=True,
desc="is_presence_visible",
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
@ -77,6 +83,7 @@ class PresenceStore(SQLBaseStore):
values={"user_id": observer_localpart,
"observed_user_id": observed_userid,
"accepted": False},
desc="add_presence_list_pending",
)
def set_presence_list_accepted(self, observer_localpart, observed_userid):
@ -85,6 +92,7 @@ class PresenceStore(SQLBaseStore):
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
def get_presence_list(self, observer_localpart, accepted=None):
@ -96,6 +104,7 @@ class PresenceStore(SQLBaseStore):
table="presence_list",
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
desc="get_presence_list",
)
def del_presence_list(self, observer_localpart, observed_userid):
@ -103,4 +112,5 @@ class PresenceStore(SQLBaseStore):
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
desc="del_presence_list",
)

View file

@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
return self._simple_insert(
table="profiles",
values={"user_id": user_localpart},
desc="create_profile",
)
def get_profile_displayname(self, user_localpart):
@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
def set_profile_displayname(self, user_localpart, new_displayname):
@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
desc="set_profile_displayname",
)
def get_profile_avatar_url(self, user_localpart):
@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url",
)

View file

@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
results = yield self._simple_select_list(
PushRuleEnableTable.table_name,
{'user_name': user_name},
PushRuleEnableTable.fields
PushRuleEnableTable.fields,
desc="get_push_rules_enabled_for_user",
)
defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
"""
yield self._simple_delete_one(
PushRuleTable.table_name,
{'user_name': user_name, 'rule_id': rule_id}
{'user_name': user_name, 'rule_id': rule_id},
desc="delete_push_rule",
)
@defer.inlineCallbacks
@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
{'enabled': enabled}
{'enabled': enabled},
desc="set_push_rule_enabled",
)

View file

@ -114,7 +114,9 @@ class PusherStore(SQLBaseStore):
ts=pushkey_ts,
lang=lang,
data=data
))
),
desc="add_pusher",
)
except Exception as e:
logger.error("create_pusher with failed: %s", e)
raise StoreError(500, "Problem creating pusher.")
@ -123,7 +125,8 @@ class PusherStore(SQLBaseStore):
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
yield self._simple_delete_one(
PushersTable.table_name,
dict(app_id=app_id, pushkey=pushkey)
{"app_id": app_id, "pushkey": pushkey},
desc="delete_pusher_by_app_id_pushkey",
)
@defer.inlineCallbacks
@ -131,7 +134,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one(
PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey},
{'last_token': last_token}
{'last_token': last_token},
desc="update_pusher_last_token",
)
@defer.inlineCallbacks
@ -140,7 +144,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one(
PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey},
{'last_token': last_token, 'last_success': last_success}
{'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success",
)
@defer.inlineCallbacks
@ -148,7 +153,8 @@ class PusherStore(SQLBaseStore):
yield self._simple_update_one(
PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey},
{'failing_since': failing_since}
{'failing_since': failing_since},
desc="update_pusher_failing_since",
)

View file

@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
from ._base import SQLBaseStore, cached
class RegistrationStore(SQLBaseStore):
@ -39,7 +39,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
row = yield self._simple_select_one(
"users", {"name": user_id}, ["id"],
desc="add_access_token_to_user",
)
if not row:
raise StoreError(400, "Bad user ID supplied.")
row_id = row["id"]
@ -48,7 +51,8 @@ class RegistrationStore(SQLBaseStore):
{
"user_id": row_id,
"token": token
}
},
desc="add_access_token_to_user",
)
@defer.inlineCallbacks
@ -91,6 +95,11 @@ class RegistrationStore(SQLBaseStore):
"get_user_by_id", self.cursor_to_dict, query, user_id
)
@cached()
# TODO(paul): Currently there's no code to invalidate this cache. That
# means if/when we ever add internal ways to invalidate access tokens or
# change whether a user is a server admin, those will need to invoke
# store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token):
"""Get a user from the given access token.
@ -115,6 +124,7 @@ class RegistrationStore(SQLBaseStore):
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
desc="is_server_admin",
)
defer.returnValue(res if res else False)

View file

@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id,
"reason": reason,
"last_check": self._clock.time_msec(),
}
},
)
def get_rejection_reason(self, event_id):
@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
"event_id": event_id,
},
allow_none=True,
desc="get_rejection_reason",
)

View file

@ -15,11 +15,9 @@
from twisted.internet import defer
from sqlite3 import IntegrityError
from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Table
from ._base import SQLBaseStore
import collections
import logging
@ -27,8 +25,9 @@ import logging
logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple("OpsLevel", (
"ban_level", "kick_level", "redact_level")
OpsLevel = collections.namedtuple(
"OpsLevel",
("ban_level", "kick_level", "redact_level",)
)
@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
StoreError if the room could not be stored.
"""
try:
yield self._simple_insert(RoomsTable.table_name, dict(
room_id=room_id,
creator=room_creator_user_id,
is_public=is_public
))
except IntegrityError:
raise StoreError(409, "Room ID in use.")
yield self._simple_insert(
RoomsTable.table_name,
{
"room_id": room_id,
"creator": room_creator_user_id,
"is_public": is_public,
},
desc="store_room",
)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@ -66,9 +67,11 @@ class RoomStore(SQLBaseStore):
Returns:
A namedtuple containing the room information, or an empty list.
"""
query = RoomsTable.select_statement("room_id=?")
return self._execute(
"get_room", RoomsTable.decode_single_result, query, room_id,
return self._simple_select_one(
table=RoomsTable.table_name,
keyvalues={"room_id": room_id},
retcols=RoomsTable.fields,
desc="get_room",
)
@defer.inlineCallbacks
@ -143,7 +146,7 @@ class RoomStore(SQLBaseStore):
"event_id": event.event_id,
"room_id": event.room_id,
"topic": event.content["topic"],
}
},
)
def _store_room_name_txn(self, txn, event):
@ -158,8 +161,45 @@ class RoomStore(SQLBaseStore):
}
)
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id):
del_sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
"LIMIT 1"
)
class RoomsTable(Table):
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
sql += " OR s.type = 'm.room.aliases')"
args = (room_id,)
results = yield self._execute_and_decode("get_current_state", sql, *args)
events = yield self._parse_events(results)
name = None
aliases = []
for e in events:
if e.type == 'm.room.name':
if 'name' in e.content:
name = e.content['name']
elif e.type == 'm.room.aliases':
if 'aliases' in e.content:
aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases))
class RoomsTable(object):
table_name = "rooms"
fields = [
@ -167,5 +207,3 @@ class RoomsTable(Table):
"is_public",
"creator"
]
EntryType = collections.namedtuple("RoomEntry", fields)

View file

@ -212,7 +212,8 @@ class RoomMemberStore(SQLBaseStore):
return self._simple_select_onecol(
"room_hosts",
{"room_id": room_id},
"host"
"host",
desc="get_joined_hosts_for_room",
)
def _get_members_by_dict(self, where_dict):

View file

@ -15,6 +15,8 @@
from ._base import SQLBaseStore
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
@ -82,7 +84,7 @@ class StateStore(SQLBaseStore):
if context.current_state is None:
return
state_events = context.current_state
state_events = dict(context.current_state)
if event.is_state():
state_events[(event.type, event.state_key)] = event
@ -122,3 +124,33 @@ class StateStore(SQLBaseStore):
},
or_replace=True,
)
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
del_sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
"LIMIT 1"
)
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
if event_type and state_key is not None:
sql += " AND s.type = ? AND s.state_key = ? "
args = (room_id, event_type, state_key)
elif event_type:
sql += " AND s.type = ?"
args = (room_id, event_type)
else:
args = (room_id, )
results = yield self._execute_and_decode("get_current_state", sql, *args)
events = yield self._parse_events(results)
defer.returnValue(events)

View file

@ -35,7 +35,7 @@ what sort order was used:
from twisted.internet import defer
from ._base import SQLBaseStore
from ._base import SQLBaseStore, cached
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function
@ -413,12 +413,32 @@ class StreamStore(SQLBaseStore):
"get_recent_events_for_room", get_recent_events_for_room_txn
)
@cached(num_args=0)
def get_room_events_max_id(self):
return self.runInteraction(
"get_room_events_max_id",
self._get_room_events_max_id_txn
)
@defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
)
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
self.min_token = min(self.min_token, -1)
logger.debug("min_token is: %s", self.min_token)
defer.returnValue(self.min_token)
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
self._next_stream_id += 1
return i
def _get_room_events_max_id_txn(self, txn):
txn.execute(
"SELECT MAX(stream_ordering) as m FROM events"

View file

@ -46,15 +46,19 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
where_clause = "transaction_id = ? AND origin = ?"
query = ReceivedTransactionsTable.select_statement(where_clause)
result = self._simple_select_one_txn(
txn,
table=ReceivedTransactionsTable.table_name,
keyvalues={
"transaction_id": transaction_id,
"origin": origin,
},
retcols=ReceivedTransactionsTable.fields,
allow_none=True,
)
txn.execute(query, (transaction_id, origin))
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
if results and results[0].response_code:
return (results[0].response_code, results[0].response_json)
if result and result.response_code:
return result["response_code"], result["response_json"]
else:
return None

View file

@ -90,12 +90,16 @@ class LruCache(object):
def cache_len():
return len(cache)
def cache_contains(key):
return key in cache
self.sentinel = object()
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
self.len = cache_len
self.contains = cache_contains
def __getitem__(self, key):
result = self.get(key, self.sentinel)
@ -114,3 +118,6 @@ class LruCache(object):
def __len__(self):
return self.len()
def __contains__(self, key):
return self.contains(key)

View file

@ -16,6 +16,10 @@
import random
import string
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
)
def origin_from_ucid(ucid):
return ucid.split("@", 1)[1]
@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
def random_string_with_symbols(length):
return ''.join(
random.choice(_string_with_symbols) for _ in xrange(length)
)

View file

@ -17,7 +17,79 @@
from tests import unittest
from twisted.internet import defer
from synapse.storage._base import cached
from synapse.storage._base import Cache, cached
class CacheTestCase(unittest.TestCase):
def setUp(self):
self.cache = Cache("test")
def test_empty(self):
failed = False
try:
self.cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
self.cache.prefill("foo", 123)
self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self):
self.cache.prefill("foo", 123)
self.cache.invalidate("foo")
failed = False
try:
self.cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_eviction(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = Cache("test", max_entries=2, lru=True)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)
class CacheDecoratorTestCase(unittest.TestCase):

View file

@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Old Value",)
ret = yield self.datastore._simple_update_one(
ret = yield self.datastore._simple_selectupdate_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columname": "New Value"},

View file

@ -44,7 +44,7 @@ class RoomStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_room(self):
self.assertObjectHasAttributes(
self.assertDictContainsSubset(
{"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True},