forked from MirrorHub/synapse
Merge branch 'develop' into application-services-txn-reliability
Conflicts: synapse/storage/__init__.py
This commit is contained in:
commit
4edcbcee3b
45 changed files with 1251 additions and 731 deletions
|
@ -1,3 +1,12 @@
|
|||
Changes in synapse v0.8.1 (2015-03-18)
|
||||
======================================
|
||||
|
||||
* Disable registration by default. New users can be added using the command
|
||||
``register_new_matrix_user`` or by enabling registration in the config.
|
||||
* Add metrics to synapse. To enable metrics use config options
|
||||
``enable_metrics`` and ``metrics_port``.
|
||||
* Fix bug where banning only kicked the user.
|
||||
|
||||
Changes in synapse v0.8.0 (2015-03-06)
|
||||
======================================
|
||||
|
||||
|
|
11
README.rst
11
README.rst
|
@ -128,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
|
|||
|
||||
Substituting your host and domain name as appropriate.
|
||||
|
||||
By default, registration of new users is disabled. You can either enable
|
||||
registration in the config (it is then recommended to also set up CAPTCHA), or
|
||||
you can use the command line to register new users::
|
||||
|
||||
$ source ~/.synapse/bin/activate
|
||||
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448
|
||||
New user localpart: erikj
|
||||
Password:
|
||||
Confirm password:
|
||||
Success!
|
||||
|
||||
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
||||
a TURN server. See docs/turn-howto.rst for details.
|
||||
|
||||
|
|
149
register_new_matrix_user
Executable file
149
register_new_matrix_user
Executable file
|
@ -0,0 +1,149 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import getpass
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import sys
|
||||
import urllib2
|
||||
import yaml
|
||||
|
||||
|
||||
def request_registration(user, password, server_location, shared_secret):
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
msg=user,
|
||||
digestmod=hashlib.sha1,
|
||||
).hexdigest()
|
||||
|
||||
data = {
|
||||
"user": user,
|
||||
"password": password,
|
||||
"mac": mac,
|
||||
"type": "org.matrix.login.shared_secret",
|
||||
}
|
||||
|
||||
server_location = server_location.rstrip("/")
|
||||
|
||||
print "Sending registration request..."
|
||||
|
||||
req = urllib2.Request(
|
||||
"%s/_matrix/client/api/v1/register" % (server_location,),
|
||||
data=json.dumps(data),
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
try:
|
||||
f = urllib2.urlopen(req)
|
||||
f.read()
|
||||
f.close()
|
||||
print "Success."
|
||||
except urllib2.HTTPError as e:
|
||||
print "ERROR! Received %d %s" % (e.code, e.reason,)
|
||||
if 400 <= e.code < 500:
|
||||
if e.info().type == "application/json":
|
||||
resp = json.load(e)
|
||||
if "error" in resp:
|
||||
print resp["error"]
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def register_new_user(user, password, server_location, shared_secret):
|
||||
if not user:
|
||||
try:
|
||||
default_user = getpass.getuser()
|
||||
except:
|
||||
default_user = None
|
||||
|
||||
if default_user:
|
||||
user = raw_input("New user localpart [%s]: " % (default_user,))
|
||||
if not user:
|
||||
user = default_user
|
||||
else:
|
||||
user = raw_input("New user localpart: ")
|
||||
|
||||
if not user:
|
||||
print "Invalid user name"
|
||||
sys.exit(1)
|
||||
|
||||
if not password:
|
||||
password = getpass.getpass("Password: ")
|
||||
|
||||
if not password:
|
||||
print "Password cannot be blank."
|
||||
sys.exit(1)
|
||||
|
||||
confirm_password = getpass.getpass("Confirm password: ")
|
||||
|
||||
if password != confirm_password:
|
||||
print "Passwords do not match"
|
||||
sys.exit(1)
|
||||
|
||||
request_registration(user, password, server_location, shared_secret)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Used to register new users with a given home server when"
|
||||
" registration has been disabled. The home server must be"
|
||||
" configured with the 'registration_shared_secret' option"
|
||||
" set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-u", "--user",
|
||||
default=None,
|
||||
help="Local part of the new user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--password",
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"-c", "--config",
|
||||
type=argparse.FileType('r'),
|
||||
help="Path to server config file. Used to read in shared secret.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"-k", "--shared-secret",
|
||||
help="Shared secret as defined in server config file.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"server_url",
|
||||
default="https://localhost:8448",
|
||||
nargs='?',
|
||||
help="URL to use to talk to the home server. Defaults to "
|
||||
" 'https://localhost:8448'.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if "config" in args and args.config:
|
||||
config = yaml.safe_load(args.config)
|
||||
secret = config.get("registration_shared_secret", None)
|
||||
if not secret:
|
||||
print "No 'registration_shared_secret' defined in config."
|
||||
sys.exit(1)
|
||||
else:
|
||||
secret = args.shared_secret
|
||||
|
||||
register_new_user(args.user, args.password, args.server_url, secret)
|
4
setup.py
4
setup.py
|
@ -45,7 +45,7 @@ setup(
|
|||
version=version,
|
||||
packages=find_packages(exclude=["tests", "tests.*"]),
|
||||
description="Reference Synapse Home Server",
|
||||
install_requires=dependencies["REQUIREMENTS"].keys(),
|
||||
install_requires=dependencies['requirements'](include_conditional=True).keys(),
|
||||
setup_requires=[
|
||||
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
||||
"setuptools_trial",
|
||||
|
@ -55,5 +55,5 @@ setup(
|
|||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
long_description=long_description,
|
||||
scripts=["synctl"],
|
||||
scripts=["synctl", "register_new_matrix_user"],
|
||||
)
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.8.0"
|
||||
__version__ = "0.8.1-r2"
|
||||
|
|
|
@ -28,6 +28,12 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
AuthEventTypes = (
|
||||
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
||||
EventTypes.JoinRules,
|
||||
)
|
||||
|
||||
|
||||
class Auth(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -166,6 +172,7 @@ class Auth(object):
|
|||
target = auth_events.get(key)
|
||||
|
||||
target_in_room = target and target.membership == Membership.JOIN
|
||||
target_banned = target and target.membership == Membership.BAN
|
||||
|
||||
key = (EventTypes.JoinRules, "", )
|
||||
join_rule_event = auth_events.get(key)
|
||||
|
@ -194,6 +201,7 @@ class Auth(object):
|
|||
{
|
||||
"caller_in_room": caller_in_room,
|
||||
"caller_invited": caller_invited,
|
||||
"target_banned": target_banned,
|
||||
"target_in_room": target_in_room,
|
||||
"membership": membership,
|
||||
"join_rule": join_rule,
|
||||
|
@ -202,6 +210,11 @@ class Auth(object):
|
|||
}
|
||||
)
|
||||
|
||||
if ban_level:
|
||||
ban_level = int(ban_level)
|
||||
else:
|
||||
ban_level = 50 # FIXME (erikj): What should we do here?
|
||||
|
||||
if Membership.INVITE == membership:
|
||||
# TODO (erikj): We should probably handle this more intelligently
|
||||
# PRIVATE join rules.
|
||||
|
@ -212,6 +225,10 @@ class Auth(object):
|
|||
403,
|
||||
"%s not in room %s." % (event.user_id, event.room_id,)
|
||||
)
|
||||
elif target_banned:
|
||||
raise AuthError(
|
||||
403, "%s is banned from the room" % (target_user_id,)
|
||||
)
|
||||
elif target_in_room: # the target is already in the room.
|
||||
raise AuthError(403, "%s is already in the room." %
|
||||
target_user_id)
|
||||
|
@ -221,6 +238,8 @@ class Auth(object):
|
|||
# joined: It's a NOOP
|
||||
if event.user_id != target_user_id:
|
||||
raise AuthError(403, "Cannot force another user to join.")
|
||||
elif target_banned:
|
||||
raise AuthError(403, "You are banned from this room")
|
||||
elif join_rule == JoinRules.PUBLIC:
|
||||
pass
|
||||
elif join_rule == JoinRules.INVITE:
|
||||
|
@ -238,6 +257,10 @@ class Auth(object):
|
|||
403,
|
||||
"%s not in room %s." % (target_user_id, event.room_id,)
|
||||
)
|
||||
elif target_banned and user_level < ban_level:
|
||||
raise AuthError(
|
||||
403, "You cannot unban user &s." % (target_user_id,)
|
||||
)
|
||||
elif target_user_id != event.user_id:
|
||||
if kick_level:
|
||||
kick_level = int(kick_level)
|
||||
|
@ -249,11 +272,6 @@ class Auth(object):
|
|||
403, "You cannot kick user %s." % target_user_id
|
||||
)
|
||||
elif Membership.BAN == membership:
|
||||
if ban_level:
|
||||
ban_level = int(ban_level)
|
||||
else:
|
||||
ban_level = 50 # FIXME (erikj): What should we do here?
|
||||
|
||||
if user_level < ban_level:
|
||||
raise AuthError(403, "You don't have permission to ban")
|
||||
else:
|
||||
|
@ -370,7 +388,7 @@ class Auth(object):
|
|||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
try:
|
||||
ret = yield self.store.get_user_by_token(token=token)
|
||||
ret = yield self.store.get_user_by_token(token)
|
||||
if not ret:
|
||||
raise StoreError(400, "Unknown token")
|
||||
user_info = {
|
||||
|
@ -412,12 +430,6 @@ class Auth(object):
|
|||
|
||||
builder.auth_events = auth_events_entries
|
||||
|
||||
context.auth_events = {
|
||||
k: v
|
||||
for k, v in context.current_state.items()
|
||||
if v.event_id in auth_ids
|
||||
}
|
||||
|
||||
def compute_auth_events(self, event, current_state):
|
||||
if event.type == EventTypes.Create:
|
||||
return []
|
||||
|
|
|
@ -60,6 +60,7 @@ class LoginType(object):
|
|||
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||
RECAPTCHA = u"m.login.recaptcha"
|
||||
APPLICATION_SERVICE = u"m.login.application_service"
|
||||
SHARED_SECRET = u"org.matrix.login.shared_secret"
|
||||
|
||||
|
||||
class EventTypes(object):
|
||||
|
|
|
@ -60,7 +60,6 @@ import re
|
|||
import resource
|
||||
import subprocess
|
||||
import sqlite3
|
||||
import syweb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -83,6 +82,7 @@ class SynapseHomeServer(HomeServer):
|
|||
return AppServiceRestResource(self)
|
||||
|
||||
def build_resource_for_web_client(self):
|
||||
import syweb
|
||||
syweb_path = os.path.dirname(syweb.__file__)
|
||||
webclient_path = os.path.join(syweb_path, "webclient")
|
||||
return File(webclient_path) # TODO configurable?
|
||||
|
@ -130,7 +130,7 @@ class SynapseHomeServer(HomeServer):
|
|||
True.
|
||||
"""
|
||||
config = self.get_config()
|
||||
web_client = config.webclient
|
||||
web_client = config.web_client
|
||||
|
||||
# list containing (path_str, Resource) e.g:
|
||||
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
|
||||
|
@ -343,7 +343,8 @@ def setup(config_options):
|
|||
|
||||
config.setup_logging()
|
||||
|
||||
check_requirements()
|
||||
# check any extra requirements we have now we have a config
|
||||
check_requirements(config)
|
||||
|
||||
version_string = get_version_string()
|
||||
|
||||
|
@ -450,6 +451,7 @@ def run(hs):
|
|||
|
||||
def main():
|
||||
with LoggingContext("main"):
|
||||
# check base requirements
|
||||
check_requirements()
|
||||
hs = setup(sys.argv[1:])
|
||||
run(hs)
|
||||
|
|
|
@ -15,19 +15,46 @@
|
|||
|
||||
from ._base import Config
|
||||
|
||||
from synapse.util.stringutils import random_string_with_symbols
|
||||
|
||||
import distutils.util
|
||||
|
||||
|
||||
class RegistrationConfig(Config):
|
||||
|
||||
def __init__(self, args):
|
||||
super(RegistrationConfig, self).__init__(args)
|
||||
self.disable_registration = args.disable_registration
|
||||
|
||||
# `args.disable_registration` may either be a bool or a string depending
|
||||
# on if the option was given a value (e.g. --disable-registration=false
|
||||
# would set `args.disable_registration` to "false" not False.)
|
||||
self.disable_registration = bool(
|
||||
distutils.util.strtobool(str(args.disable_registration))
|
||||
)
|
||||
self.registration_shared_secret = args.registration_shared_secret
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(RegistrationConfig, cls).add_arguments(parser)
|
||||
reg_group = parser.add_argument_group("registration")
|
||||
|
||||
reg_group.add_argument(
|
||||
"--disable-registration",
|
||||
action='store_true',
|
||||
help="Disable registration of new users."
|
||||
const=True,
|
||||
default=True,
|
||||
nargs='?',
|
||||
help="Disable registration of new users.",
|
||||
)
|
||||
reg_group.add_argument(
|
||||
"--registration-shared-secret", type=str,
|
||||
help="If set, allows registration by anyone who also has the shared"
|
||||
" secret, even if registration is otherwise disabled.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
if args.disable_registration is None:
|
||||
args.disable_registration = True
|
||||
|
||||
if args.registration_shared_secret is None:
|
||||
args.registration_shared_secret = random_string_with_symbols(50)
|
||||
|
|
|
@ -28,7 +28,7 @@ class ServerConfig(Config):
|
|||
self.unsecure_port = args.unsecure_port
|
||||
self.daemonize = args.daemonize
|
||||
self.pid_file = self.abspath(args.pid_file)
|
||||
self.webclient = True
|
||||
self.web_client = args.web_client
|
||||
self.manhole = args.manhole
|
||||
self.soft_file_limit = args.soft_file_limit
|
||||
|
||||
|
@ -68,6 +68,8 @@ class ServerConfig(Config):
|
|||
server_group.add_argument('--pid-file', default="homeserver.pid",
|
||||
help="When running as a daemon, the file to"
|
||||
" store the pid in")
|
||||
server_group.add_argument('--web_client', default=True, type=bool,
|
||||
help="Whether or not to serve a web client")
|
||||
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
||||
type=int,
|
||||
help="Turn on the twisted telnet manhole"
|
||||
|
|
|
@ -16,8 +16,7 @@
|
|||
|
||||
class EventContext(object):
|
||||
|
||||
def __init__(self, current_state=None, auth_events=None):
|
||||
def __init__(self, current_state=None):
|
||||
self.current_state = current_state
|
||||
self.auth_events = auth_events
|
||||
self.state_group = None
|
||||
self.rejected = False
|
||||
|
|
|
@ -361,4 +361,5 @@ SERVLET_CLASSES = (
|
|||
FederationInviteServlet,
|
||||
FederationQueryAuthServlet,
|
||||
FederationGetMissingEventsServlet,
|
||||
FederationEventAuthServlet,
|
||||
)
|
||||
|
|
|
@ -90,8 +90,8 @@ class BaseHandler(object):
|
|||
event = builder.build()
|
||||
|
||||
logger.debug(
|
||||
"Created event %s with auth_events: %s, current state: %s",
|
||||
event.event_id, context.auth_events, context.current_state,
|
||||
"Created event %s with current state: %s",
|
||||
event.event_id, context.current_state,
|
||||
)
|
||||
|
||||
defer.returnValue(
|
||||
|
@ -106,7 +106,7 @@ class BaseHandler(object):
|
|||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
if not suppress_auth:
|
||||
self.auth.check(event, auth_events=context.auth_events)
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
|
||||
yield self.store.persist_event(event, context=context)
|
||||
|
||||
|
@ -142,7 +142,16 @@ class BaseHandler(object):
|
|||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
|
||||
# Don't block waiting on waking up all the listeners.
|
||||
d = self.notifier.on_new_room_event(event, extra_users=extra_users)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
"Failed to notify about %s: %s",
|
||||
event.event_id, f.value
|
||||
)
|
||||
|
||||
d.addErrback(log_failure)
|
||||
|
||||
yield federation_handler.handle_new_event(
|
||||
event, destinations=destinations,
|
||||
|
|
|
@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
|
|||
"""
|
||||
logger.debug("Joining %s to %s", joinee, room_id)
|
||||
|
||||
yield self.store.clean_room_for_join(room_id)
|
||||
|
||||
origin, pdu = yield self.replication_layer.make_join(
|
||||
target_hosts,
|
||||
room_id,
|
||||
|
@ -464,11 +466,9 @@ class FederationHandler(BaseHandler):
|
|||
builder=builder,
|
||||
)
|
||||
|
||||
self.auth.check(event, auth_events=context.auth_events)
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
|
||||
pdu = event
|
||||
|
||||
defer.returnValue(pdu)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
|
@ -705,7 +705,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if not auth_events:
|
||||
auth_events = context.auth_events
|
||||
auth_events = context.current_state
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: %s, auth_events: %s",
|
||||
|
|
|
@ -33,6 +33,10 @@ logger = logging.getLogger(__name__)
|
|||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
|
||||
# Don't bother bumping "last active" time if it differs by less than 60 seconds
|
||||
LAST_ACTIVE_GRANULARITY = 60*1000
|
||||
|
||||
|
||||
# TODO(paul): Maybe there's one of these I can steal from somewhere
|
||||
def partition(l, func):
|
||||
"""Partition the list by the result of func applied to each element."""
|
||||
|
@ -282,6 +286,10 @@ class PresenceHandler(BaseHandler):
|
|||
if now is None:
|
||||
now = self.clock.time_msec()
|
||||
|
||||
prev_state = self._get_or_make_usercache(user)
|
||||
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
|
||||
return
|
||||
|
||||
self.changed_presencelike_data(user, {"last_active": now})
|
||||
|
||||
def changed_presencelike_data(self, user, state):
|
||||
|
|
|
@ -31,6 +31,7 @@ import base64
|
|||
import bcrypt
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
|
|||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||
|
||||
if localpart:
|
||||
if localpart and urllib.quote(localpart) != localpart:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not"
|
||||
" require URL encoding."
|
||||
)
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
|
|
|
@ -46,6 +46,11 @@ outgoing_responses_counter = metrics.register_counter(
|
|||
labels=["method", "code"],
|
||||
)
|
||||
|
||||
response_timer = metrics.register_distribution(
|
||||
"response_time",
|
||||
labels=["method", "servlet"]
|
||||
)
|
||||
|
||||
|
||||
class HttpServer(object):
|
||||
""" Interface for registering callbacks on a HTTP server
|
||||
|
@ -169,6 +174,10 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
code, response = yield callback(request, *args)
|
||||
|
||||
self._send_response(request, code, response)
|
||||
response_timer.inc_by(
|
||||
self.clock.time_msec() - start, request.method, servlet_classname
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
|
|
|
@ -51,8 +51,8 @@ class RestServlet(object):
|
|||
pattern = self.PATTERN
|
||||
|
||||
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
|
||||
if hasattr(self, "on_%s" % (method)):
|
||||
method_handler = getattr(self, "on_%s" % (method))
|
||||
if hasattr(self, "on_%s" % (method,)):
|
||||
method_handler = getattr(self, "on_%s" % (method,))
|
||||
http_server.register_path(method, pattern, method_handler)
|
||||
else:
|
||||
raise NotImplementedError("RestServlet must register something.")
|
||||
|
|
|
@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
REQUIREMENTS = {
|
||||
"syutil>=0.0.3": ["syutil"],
|
||||
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
||||
"Twisted==14.0.2": ["twisted==14.0.2"],
|
||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||
|
@ -18,6 +17,19 @@ REQUIREMENTS = {
|
|||
"pillow": ["PIL"],
|
||||
"pydenticon": ["pydenticon"],
|
||||
}
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
"web_client": {
|
||||
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def requirements(config=None, include_conditional=False):
|
||||
reqs = REQUIREMENTS.copy()
|
||||
for key, req in CONDITIONAL_REQUIREMENTS.items():
|
||||
if (config and getattr(config, key)) or include_conditional:
|
||||
reqs.update(req)
|
||||
return reqs
|
||||
|
||||
|
||||
def github_link(project, version, egg):
|
||||
|
@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def check_requirements():
|
||||
def check_requirements(config=None):
|
||||
"""Checks that all the modules needed by synapse have been correctly
|
||||
installed and are at the correct version"""
|
||||
for dependency, module_requirements in REQUIREMENTS.items():
|
||||
for dependency, module_requirements in (
|
||||
requirements(config, include_conditional=False).items()):
|
||||
for module_requirement in module_requirements:
|
||||
if ">=" in module_requirement:
|
||||
module_name, required_version = module_requirement.split(">=")
|
||||
|
@ -110,7 +123,7 @@ def list_requirements():
|
|||
egg = link.split("#egg=")[1]
|
||||
linked.append(egg.split('-')[0])
|
||||
result.append(link)
|
||||
for requirement in REQUIREMENTS:
|
||||
for requirement in requirements(include_conditional=True):
|
||||
is_linked = False
|
||||
for link in linked:
|
||||
if requirement.replace('-', '_').startswith(link):
|
||||
|
|
|
@ -27,7 +27,6 @@ from hashlib import sha1
|
|||
import hmac
|
||||
import simplejson as json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
login_type = register_json["type"]
|
||||
|
||||
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
||||
if self.disable_registration and not is_application_server:
|
||||
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
|
||||
|
||||
can_register = (
|
||||
not self.disable_registration
|
||||
or is_application_server
|
||||
or is_using_shared_secret
|
||||
)
|
||||
if not can_register:
|
||||
raise SynapseError(403, "Registration has been disabled")
|
||||
|
||||
stages = {
|
||||
LoginType.RECAPTCHA: self._do_recaptcha,
|
||||
LoginType.PASSWORD: self._do_password,
|
||||
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
||||
LoginType.APPLICATION_SERVICE: self._do_app_service
|
||||
LoginType.APPLICATION_SERVICE: self._do_app_service,
|
||||
LoginType.SHARED_SECRET: self._do_shared_secret,
|
||||
}
|
||||
|
||||
session_info = self._get_session_info(request, session)
|
||||
|
@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
)
|
||||
|
||||
password = register_json["password"].encode("utf-8")
|
||||
desired_user_id = (register_json["user"].encode("utf-8")
|
||||
if "user" in register_json else None)
|
||||
if (desired_user_id
|
||||
and urllib.quote(desired_user_id) != desired_user_id):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not " +
|
||||
"require URL encoding.")
|
||||
desired_user_id = (
|
||||
register_json["user"].encode("utf-8")
|
||||
if "user" in register_json else None
|
||||
)
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
(user_id, token) = yield handler.register(
|
||||
localpart=desired_user_id,
|
||||
|
@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
"home_server": self.hs.hostname,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_shared_secret(self, request, register_json, session):
|
||||
yield run_on_reactor()
|
||||
|
||||
if not isinstance(register_json.get("mac", None), basestring):
|
||||
raise SynapseError(400, "Expected mac.")
|
||||
if not isinstance(register_json.get("user", None), basestring):
|
||||
raise SynapseError(400, "Expected 'user' key.")
|
||||
if not isinstance(register_json.get("password", None), basestring):
|
||||
raise SynapseError(400, "Expected 'password' key.")
|
||||
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
user = register_json["user"].encode("utf-8")
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
got_mac = str(register_json["mac"])
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
msg=user,
|
||||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
password = register_json["password"].encode("utf-8")
|
||||
|
||||
if compare_digest(want_mac, got_mac):
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.register(
|
||||
localpart=user,
|
||||
password=password,
|
||||
)
|
||||
self._remove_session(session)
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
})
|
||||
else:
|
||||
raise SynapseError(
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
|
||||
|
||||
def _parse_json(request):
|
||||
try:
|
||||
|
|
|
@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
|
|||
from synapse.util.expiringcache import ExpiringCache
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.auth import AuthEventTypes
|
||||
from synapse.events.snapshot import EventContext
|
||||
|
||||
from collections import namedtuple
|
||||
|
@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
|
|||
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
|
||||
|
||||
|
||||
AuthEventTypes = (
|
||||
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
||||
EventTypes.JoinRules,
|
||||
)
|
||||
|
||||
|
||||
SIZE_OF_CACHE = 1000
|
||||
EVICTION_TIMEOUT_SECONDS = 20
|
||||
|
||||
|
@ -139,18 +134,6 @@ class StateHandler(object):
|
|||
}
|
||||
context.state_group = None
|
||||
|
||||
if hasattr(event, "auth_events") and event.auth_events:
|
||||
auth_ids = self.hs.get_auth().compute_auth_events(
|
||||
event, context.current_state
|
||||
)
|
||||
context.auth_events = {
|
||||
k: v
|
||||
for k, v in context.current_state.items()
|
||||
if v.event_id in auth_ids
|
||||
}
|
||||
else:
|
||||
context.auth_events = {}
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.current_state:
|
||||
|
@ -187,18 +170,6 @@ class StateHandler(object):
|
|||
replaces = context.current_state[key]
|
||||
event.unsigned["replaces_state"] = replaces.event_id
|
||||
|
||||
if hasattr(event, "auth_events") and event.auth_events:
|
||||
auth_ids = self.hs.get_auth().compute_auth_events(
|
||||
event, context.current_state
|
||||
)
|
||||
context.auth_events = {
|
||||
k: v
|
||||
for k, v in context.current_state.items()
|
||||
if v.event_id in auth_ids
|
||||
}
|
||||
else:
|
||||
context.auth_events = {}
|
||||
|
||||
context.prev_state_events = prev_state
|
||||
defer.returnValue(context)
|
||||
|
||||
|
|
|
@ -14,15 +14,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
from .appservice import (
|
||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||
)
|
||||
from ._base import Cache
|
||||
from .directory import DirectoryStore
|
||||
from .feedback import FeedbackStore
|
||||
from .events import EventsStore
|
||||
from .presence import PresenceStore
|
||||
from .profile import ProfileStore
|
||||
from .registration import RegistrationStore
|
||||
|
@ -41,11 +38,6 @@ from .state import StateStore
|
|||
from .signatures import SignatureStore
|
||||
from .filtering import FilteringStore
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
|
||||
import fnmatch
|
||||
import imp
|
||||
|
@ -63,16 +55,14 @@ SCHEMA_VERSION = 15
|
|||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
something went wrong.
|
||||
"""
|
||||
pass
|
||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||
# times give more inserts into the database even for readonly API hits
|
||||
# 120 seconds == 2 minutes
|
||||
LAST_SEEN_GRANULARITY = 120*1000
|
||||
|
||||
|
||||
class DataStore(RoomMemberStore, RoomStore,
|
||||
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
|
||||
RegistrationStore, StreamStore, ProfileStore,
|
||||
PresenceStore, TransactionStore,
|
||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||
ApplicationServiceStore,
|
||||
|
@ -83,6 +73,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
PusherStore,
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
EventsStore,
|
||||
):
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -92,424 +83,28 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
self.min_token_deferred = self._get_min_token()
|
||||
self.min_token = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, context, backfilled=False,
|
||||
is_new_state=True, current_state=None):
|
||||
stream_ordering = None
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
yield self.min_token_deferred
|
||||
self.min_token -= 1
|
||||
stream_ordering = self.min_token
|
||||
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
except _RollbackButIsFineException:
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False,
|
||||
allow_none=False):
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
event_id (str): The event_id of the event to fetch
|
||||
check_redacted (bool): If True, check if event has been redacted
|
||||
and redact it.
|
||||
get_prev_content (bool): If True and event is a state event,
|
||||
include the previous states content in the unsigned field.
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
allow_none (bool): If True, return None if no event found, if
|
||||
False throw an exception.
|
||||
|
||||
Returns:
|
||||
Deferred : A FrozenEvent.
|
||||
"""
|
||||
event = yield self.runInteraction(
|
||||
"get_event", self._get_event_txn,
|
||||
event_id,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
if not event and not allow_none:
|
||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||
stream_ordering=None, is_new_state=True,
|
||||
current_state=None):
|
||||
|
||||
# Remove the any existing cache entries for the event_id
|
||||
self._get_event_cache.pop(event.event_id)
|
||||
|
||||
# We purposefully do this first since if we include a `current_state`
|
||||
# key, we *want* to update the `current_state_events` table
|
||||
if current_state:
|
||||
txn.execute(
|
||||
"DELETE FROM current_state_events WHERE room_id = ?",
|
||||
(event.room_id,)
|
||||
)
|
||||
|
||||
for s in current_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": s.event_id,
|
||||
"room_id": s.room_id,
|
||||
"type": s.type,
|
||||
"state_key": s.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
if event.is_state() and is_new_state:
|
||||
if not backfilled and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
for prev_state_id, _ in event.prev_state:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
keyvalues={
|
||||
"event_id": prev_state_id,
|
||||
}
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
self._update_min_depth_for_room_txn(
|
||||
txn,
|
||||
event.room_id,
|
||||
event.depth
|
||||
)
|
||||
|
||||
self._handle_prev_events(
|
||||
txn,
|
||||
outlier=outlier,
|
||||
event_id=event.event_id,
|
||||
prev_events=event.prev_events,
|
||||
room_id=event.room_id,
|
||||
)
|
||||
|
||||
have_persisted = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
keyvalues={"event_id": event.event_id},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
metadata_json = encode_canonical_json(
|
||||
event.internal_metadata.get_dict()
|
||||
)
|
||||
|
||||
# If we have already persisted this event, we don't need to do any
|
||||
# more processing.
|
||||
# The processing above must be done on every call to persist event,
|
||||
# since they might not have happened on previous calls. For example,
|
||||
# if we are persisting an event that we had persisted as an outlier,
|
||||
# but is no longer one.
|
||||
if have_persisted:
|
||||
if not outlier:
|
||||
sql = (
|
||||
"UPDATE event_json SET internal_metadata = ?"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(metadata_json.decode("UTF-8"), event.event_id,)
|
||||
)
|
||||
|
||||
sql = (
|
||||
"UPDATE events SET outlier = 0"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(event.event_id,)
|
||||
)
|
||||
return
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
self._store_room_member_txn(txn, event)
|
||||
elif event.type == EventTypes.Feedback:
|
||||
self._store_feedback_txn(txn, event)
|
||||
elif event.type == EventTypes.Name:
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
|
||||
event_dict = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
]
|
||||
}
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json.decode("UTF-8"),
|
||||
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
content = encode_canonical_json(
|
||||
event.content
|
||||
).decode("UTF-8")
|
||||
|
||||
vals = {
|
||||
"topological_ordering": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"type": event.type,
|
||||
"room_id": event.room_id,
|
||||
"content": content,
|
||||
"processed": True,
|
||||
"outlier": outlier,
|
||||
"depth": event.depth,
|
||||
}
|
||||
|
||||
if stream_ordering is not None:
|
||||
vals["stream_ordering"] = stream_ordering
|
||||
|
||||
unrec = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in vals.keys() and k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
"signatures",
|
||||
"hashes",
|
||||
"prev_events",
|
||||
]
|
||||
}
|
||||
|
||||
vals["unrecognized_keys"] = encode_canonical_json(
|
||||
unrec
|
||||
).decode("UTF-8")
|
||||
|
||||
try:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"events",
|
||||
vals,
|
||||
or_replace=(not outlier),
|
||||
or_ignore=bool(outlier),
|
||||
)
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to persist, probably duplicate: %s",
|
||||
event.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise _RollbackButIsFineException("_persist_event")
|
||||
|
||||
if context.rejected:
|
||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
||||
|
||||
if event.is_state():
|
||||
vals = {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
}
|
||||
|
||||
# TODO: How does this work with backfilling?
|
||||
if hasattr(event, "replaces_state"):
|
||||
vals["prev_state"] = event.replaces_state
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"state_events",
|
||||
vals,
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
if is_new_state and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
for e_id, h in event.prev_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": event.room_id,
|
||||
"is_state": 1,
|
||||
},
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
for hash_alg, hash_base64 in event.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_event_content_hash_txn(
|
||||
txn, event.event_id, hash_alg, hash_bytes,
|
||||
)
|
||||
|
||||
for prev_event_id, prev_hashes in event.prev_events:
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_event_hash_txn(
|
||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
||||
)
|
||||
|
||||
for auth_id, _ in event.auth_events:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"auth_id": auth_id,
|
||||
},
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
||||
self._store_event_reference_hash_txn(
|
||||
txn, event.event_id, ref_alg, ref_hash_bytes
|
||||
)
|
||||
|
||||
def _store_redaction(self, txn, event):
|
||||
# invalidate the cache for the redacted event
|
||||
self._get_event_cache.pop(event.redacts)
|
||||
txn.execute(
|
||||
"INSERT OR IGNORE INTO redactions "
|
||||
"(event_id, redacts) VALUES (?,?)",
|
||||
(event.event_id, event.redacts)
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen",
|
||||
keylen=4,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||
del_sql = (
|
||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
sql = (
|
||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||
"WHERE c.room_id = ? "
|
||||
) % {
|
||||
"redacted": del_sql,
|
||||
}
|
||||
|
||||
if event_type and state_key is not None:
|
||||
sql += " AND s.type = ? AND s.state_key = ? "
|
||||
args = (room_id, event_type, state_key)
|
||||
elif event_type:
|
||||
sql += " AND s.type = ?"
|
||||
args = (room_id, event_type)
|
||||
else:
|
||||
args = (room_id, )
|
||||
|
||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||
|
||||
events = yield self._parse_events(results)
|
||||
defer.returnValue(events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_name_and_aliases(self, room_id):
|
||||
del_sql = (
|
||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
sql = (
|
||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||
"WHERE c.room_id = ? "
|
||||
) % {
|
||||
"redacted": del_sql,
|
||||
}
|
||||
|
||||
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
|
||||
sql += " OR s.type = 'm.room.aliases')"
|
||||
args = (room_id,)
|
||||
|
||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||
|
||||
events = yield self._parse_events(results)
|
||||
|
||||
name = None
|
||||
aliases = []
|
||||
|
||||
for e in events:
|
||||
if e.type == 'm.room.name':
|
||||
if 'name' in e.content:
|
||||
name = e.content['name']
|
||||
elif e.type == 'm.room.aliases':
|
||||
if 'aliases' in e.content:
|
||||
aliases.extend(e.content['aliases'])
|
||||
|
||||
defer.returnValue((name, aliases))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_min_token(self):
|
||||
row = yield self._execute(
|
||||
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
||||
)
|
||||
|
||||
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
||||
self.min_token = min(self.min_token, -1)
|
||||
|
||||
logger.debug("min_token is: %s", self.min_token)
|
||||
|
||||
defer.returnValue(self.min_token)
|
||||
|
||||
def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
|
||||
return self._simple_insert(
|
||||
now = int(self._clock.time_msec())
|
||||
key = (user.to_string(), access_token, device_id, ip)
|
||||
|
||||
try:
|
||||
last_seen = self.client_ip_last_seen.get(*key)
|
||||
except KeyError:
|
||||
last_seen = None
|
||||
|
||||
# Rate-limited inserts
|
||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||
defer.returnValue(None)
|
||||
|
||||
self.client_ip_last_seen.prefill(*key + (now,))
|
||||
|
||||
yield self._simple_insert(
|
||||
"user_ips",
|
||||
{
|
||||
"user": user.to_string(),
|
||||
|
@ -517,8 +112,9 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
"device_id": device_id,
|
||||
"ip": ip,
|
||||
"user_agent": user_agent,
|
||||
"last_seen": int(self._clock.time_msec()),
|
||||
}
|
||||
"last_seen": now,
|
||||
},
|
||||
desc="insert_client_ip",
|
||||
)
|
||||
|
||||
def get_user_ip_and_agents(self, user):
|
||||
|
@ -528,38 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
retcols=[
|
||||
"device_id", "access_token", "ip", "user_agent", "last_seen"
|
||||
],
|
||||
)
|
||||
|
||||
def have_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Returns:
|
||||
dict: Has an entry for each event id we already have seen. Maps to
|
||||
the rejected reason string if we rejected the event, else maps to
|
||||
None.
|
||||
"""
|
||||
if not event_ids:
|
||||
return defer.succeed({})
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.event_id, reason FROM events as e "
|
||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||
"WHERE e.event_id = ?"
|
||||
)
|
||||
|
||||
res = {}
|
||||
for event_id in event_ids:
|
||||
txn.execute(sql, (event_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
_, rejected = row
|
||||
res[event_id] = rejected
|
||||
|
||||
return res
|
||||
|
||||
return self.runInteraction(
|
||||
"have_events", f,
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import synapse.metrics
|
|||
from twisted.internet import defer
|
||||
|
||||
from collections import namedtuple, OrderedDict
|
||||
import functools
|
||||
import simplejson as json
|
||||
import sys
|
||||
import time
|
||||
|
@ -38,6 +39,8 @@ transaction_logger = logging.getLogger("synapse.storage.txn")
|
|||
|
||||
metrics = synapse.metrics.get_metrics_for("synapse.storage")
|
||||
|
||||
sql_scheduling_timer = metrics.register_distribution("schedule_time")
|
||||
|
||||
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
||||
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
||||
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
|
||||
|
@ -50,14 +53,57 @@ cache_counter = metrics.register_cache(
|
|||
)
|
||||
|
||||
|
||||
# TODO(paul):
|
||||
# * more generic key management
|
||||
# * consider other eviction strategies - LRU?
|
||||
def cached(max_entries=1000):
|
||||
class Cache(object):
|
||||
|
||||
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
|
||||
if lru:
|
||||
self.cache = LruCache(max_size=max_entries)
|
||||
self.max_entries = None
|
||||
else:
|
||||
self.cache = OrderedDict()
|
||||
self.max_entries = max_entries
|
||||
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
|
||||
caches_by_name[name] = self.cache
|
||||
|
||||
def get(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
if keyargs in self.cache:
|
||||
cache_counter.inc_hits(self.name)
|
||||
return self.cache[keyargs]
|
||||
|
||||
cache_counter.inc_misses(self.name)
|
||||
raise KeyError()
|
||||
|
||||
def prefill(self, *args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
if self.max_entries is not None:
|
||||
while len(self.cache) >= self.max_entries:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
self.cache[keyargs] = value
|
||||
|
||||
def invalidate(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
self.cache.pop(keyargs, None)
|
||||
|
||||
|
||||
def cached(max_entries=1000, num_args=1, lru=False):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
The function is presumed to take one additional argument, which is used as
|
||||
the key for the cache. Cache hits are served directly from the cache;
|
||||
The function is presumed to take zero or more arguments, which are used in
|
||||
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||
misses use the function body to generate the value.
|
||||
|
||||
The wrapped function has an additional member, a callable called
|
||||
|
@ -68,33 +114,27 @@ def cached(max_entries=1000):
|
|||
calling the calculation function.
|
||||
"""
|
||||
def wrap(orig):
|
||||
cache = OrderedDict()
|
||||
name = orig.__name__
|
||||
|
||||
caches_by_name[name] = cache
|
||||
|
||||
def prefill(key, value):
|
||||
while len(cache) > max_entries:
|
||||
cache.popitem(last=False)
|
||||
|
||||
cache[key] = value
|
||||
cache = Cache(
|
||||
name=orig.__name__,
|
||||
max_entries=max_entries,
|
||||
keylen=num_args,
|
||||
lru=lru,
|
||||
)
|
||||
|
||||
@functools.wraps(orig)
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(self, key):
|
||||
if key in cache:
|
||||
cache_counter.inc_hits(name)
|
||||
defer.returnValue(cache[key])
|
||||
def wrapped(self, *keyargs):
|
||||
try:
|
||||
defer.returnValue(cache.get(*keyargs))
|
||||
except KeyError:
|
||||
ret = yield orig(self, *keyargs)
|
||||
|
||||
cache_counter.inc_misses(name)
|
||||
ret = yield orig(self, key)
|
||||
prefill(key, ret)
|
||||
defer.returnValue(ret)
|
||||
cache.prefill(*keyargs + (ret,))
|
||||
|
||||
def invalidate(key):
|
||||
cache.pop(key, None)
|
||||
defer.returnValue(ret)
|
||||
|
||||
wrapped.invalidate = invalidate
|
||||
wrapped.prefill = prefill
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.prefill = cache.prefill
|
||||
return wrapped
|
||||
|
||||
return wrap
|
||||
|
@ -240,6 +280,8 @@ class SQLBaseStore(object):
|
|||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
def inner_func(txn, *args, **kwargs):
|
||||
with LoggingContext("runInteraction") as context:
|
||||
current_context.copy_to(context)
|
||||
|
@ -252,6 +294,7 @@ class SQLBaseStore(object):
|
|||
|
||||
name = "%s-%x" % (desc, txn_id, )
|
||||
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
transaction_logger.debug("[TXN START] {%s}", name)
|
||||
try:
|
||||
return func(LoggingTransaction(txn, name), *args, **kwargs)
|
||||
|
@ -314,7 +357,8 @@ class SQLBaseStore(object):
|
|||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
|
||||
def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
|
||||
desc="_simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
|
@ -323,7 +367,7 @@ class SQLBaseStore(object):
|
|||
or_replace : bool; if True performs an INSERT OR REPLACE
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_insert",
|
||||
desc,
|
||||
self._simple_insert_txn, table, values, or_replace=or_replace,
|
||||
or_ignore=or_ignore,
|
||||
)
|
||||
|
@ -347,7 +391,7 @@ class SQLBaseStore(object):
|
|||
txn.execute(sql, values.values())
|
||||
return txn.lastrowid
|
||||
|
||||
def _simple_upsert(self, table, keyvalues, values):
|
||||
def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
|
||||
"""
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
|
@ -356,7 +400,7 @@ class SQLBaseStore(object):
|
|||
Returns: A deferred
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_upsert",
|
||||
desc,
|
||||
self._simple_upsert_txn, table, keyvalues, values
|
||||
)
|
||||
|
||||
|
@ -392,7 +436,7 @@ class SQLBaseStore(object):
|
|||
txn.execute(sql, allvalues.values())
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
|
@ -404,12 +448,15 @@ class SQLBaseStore(object):
|
|||
allow_none : If true, return None instead of failing if the SELECT
|
||||
statement returns no rows
|
||||
"""
|
||||
return self._simple_selectupdate_one(
|
||||
table, keyvalues, retcols=retcols, allow_none=allow_none
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_txn,
|
||||
table, keyvalues, retcols, allow_none,
|
||||
)
|
||||
|
||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it."
|
||||
|
||||
|
@ -419,7 +466,7 @@ class SQLBaseStore(object):
|
|||
retcol : string giving the name of the column to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_one_onecol",
|
||||
desc,
|
||||
self._simple_select_one_onecol_txn,
|
||||
table, keyvalues, retcol, allow_none=allow_none,
|
||||
)
|
||||
|
@ -455,7 +502,8 @@ class SQLBaseStore(object):
|
|||
|
||||
return [r[0] for r in txn.fetchall()]
|
||||
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol):
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||
desc="_simple_select_onecol"):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
|
@ -468,12 +516,13 @@ class SQLBaseStore(object):
|
|||
Deferred: Results in a list
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_onecol",
|
||||
desc,
|
||||
self._simple_select_onecol_txn,
|
||||
table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols):
|
||||
def _simple_select_list(self, table, keyvalues, retcols,
|
||||
desc="_simple_select_list"):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
|
@ -484,7 +533,7 @@ class SQLBaseStore(object):
|
|||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_list",
|
||||
desc,
|
||||
self._simple_select_list_txn,
|
||||
table, keyvalues, retcols
|
||||
)
|
||||
|
@ -516,7 +565,7 @@ class SQLBaseStore(object):
|
|||
return self.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
retcols=None):
|
||||
desc="_simple_update_one"):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
|
@ -534,56 +583,76 @@ class SQLBaseStore(object):
|
|||
get-and-set. This can be used to implement compare-and-set by putting
|
||||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
|
||||
retcols=retcols)
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_one_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
txn.execute(select_sql, keyvalues.values())
|
||||
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
if allow_none:
|
||||
return None
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
return dict(zip(retcols, row))
|
||||
|
||||
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
||||
retcols=None, allow_none=False):
|
||||
retcols=None, allow_none=False,
|
||||
desc="_simple_selectupdate_one"):
|
||||
""" Combined SELECT then UPDATE."""
|
||||
if retcols:
|
||||
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
if updatevalues:
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
ret = None
|
||||
if retcols:
|
||||
txn.execute(select_sql, keyvalues.values())
|
||||
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
if allow_none:
|
||||
return None
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
ret = dict(zip(retcols, row))
|
||||
|
||||
if updatevalues:
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
ret = self._simple_select_one_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
retcols=retcols,
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
if updatevalues:
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
updatevalues=updatevalues,
|
||||
)
|
||||
|
||||
return ret
|
||||
return self.runInteraction("_simple_selectupdate_one", func)
|
||||
return self.runInteraction(desc, func)
|
||||
|
||||
def _simple_delete_one(self, table, keyvalues):
|
||||
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
|
||||
"""Executes a DELETE query on the named table, expecting to delete a
|
||||
single row.
|
||||
|
||||
|
@ -602,9 +671,9 @@ class SQLBaseStore(object):
|
|||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "more than one row matched")
|
||||
return self.runInteraction("_simple_delete_one", func)
|
||||
return self.runInteraction(desc, func)
|
||||
|
||||
def _simple_delete(self, table, keyvalues):
|
||||
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
|
||||
"""Executes a DELETE query on the named table.
|
||||
|
||||
Args:
|
||||
|
@ -612,7 +681,7 @@ class SQLBaseStore(object):
|
|||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
|
||||
return self.runInteraction("_simple_delete", self._simple_delete_txn)
|
||||
return self.runInteraction(desc, self._simple_delete_txn)
|
||||
|
||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
|
@ -782,6 +851,13 @@ class SQLBaseStore(object):
|
|||
return result[0] if result else None
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
something went wrong.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Table(object):
|
||||
""" A base class used to store information about a particular table.
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
|
@ -48,6 +48,7 @@ class DirectoryStore(SQLBaseStore):
|
|||
{"room_alias": room_alias.to_string()},
|
||||
"room_id",
|
||||
allow_none=True,
|
||||
desc="get_association_from_room_alias",
|
||||
)
|
||||
|
||||
if not room_id:
|
||||
|
@ -58,6 +59,7 @@ class DirectoryStore(SQLBaseStore):
|
|||
"room_alias_servers",
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"server",
|
||||
desc="get_association_from_room_alias",
|
||||
)
|
||||
|
||||
if not servers:
|
||||
|
@ -87,6 +89,7 @@ class DirectoryStore(SQLBaseStore):
|
|||
"room_alias": room_alias.to_string(),
|
||||
"room_id": room_id,
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
except sqlite3.IntegrityError:
|
||||
raise SynapseError(
|
||||
|
@ -100,16 +103,22 @@ class DirectoryStore(SQLBaseStore):
|
|||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
}
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_alias(self, room_alias):
|
||||
return self.runInteraction(
|
||||
room_id = yield self.runInteraction(
|
||||
"delete_room_alias",
|
||||
self._delete_room_alias_txn,
|
||||
room_alias,
|
||||
)
|
||||
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
defer.returnValue(room_id)
|
||||
|
||||
def _delete_room_alias_txn(self, txn, room_alias):
|
||||
cursor = txn.execute(
|
||||
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
||||
|
@ -134,9 +143,11 @@ class DirectoryStore(SQLBaseStore):
|
|||
|
||||
return room_id
|
||||
|
||||
@cached()
|
||||
def get_aliases_for_room(self, room_id):
|
||||
return self._simple_select_onecol(
|
||||
"room_aliases",
|
||||
{"room_id": room_id},
|
||||
"room_alias",
|
||||
desc="get_aliases_for_room",
|
||||
)
|
||||
|
|
|
@ -429,3 +429,15 @@ class EventFederationStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
return events[:limit]
|
||||
|
||||
def clean_room_for_join(self, room_id):
|
||||
return self.runInteraction(
|
||||
"clean_room_for_join",
|
||||
self._clean_room_for_join_txn,
|
||||
room_id,
|
||||
)
|
||||
|
||||
def _clean_room_for_join_txn(self, txn, room_id):
|
||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||
|
||||
txn.execute(query, (room_id,))
|
||||
|
|
395
synapse/storage/events.py
Normal file
395
synapse/storage/events.py
Normal file
|
@ -0,0 +1,395 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from _base import SQLBaseStore, _RollbackButIsFineException
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventsStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, context, backfilled=False,
|
||||
is_new_state=True, current_state=None):
|
||||
stream_ordering = None
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
yield self.min_token_deferred
|
||||
self.min_token -= 1
|
||||
stream_ordering = self.min_token
|
||||
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
self.get_room_events_max_id.invalidate()
|
||||
except _RollbackButIsFineException:
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False,
|
||||
allow_none=False):
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
event_id (str): The event_id of the event to fetch
|
||||
check_redacted (bool): If True, check if event has been redacted
|
||||
and redact it.
|
||||
get_prev_content (bool): If True and event is a state event,
|
||||
include the previous states content in the unsigned field.
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
allow_none (bool): If True, return None if no event found, if
|
||||
False throw an exception.
|
||||
|
||||
Returns:
|
||||
Deferred : A FrozenEvent.
|
||||
"""
|
||||
event = yield self.runInteraction(
|
||||
"get_event", self._get_event_txn,
|
||||
event_id,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
if not event and not allow_none:
|
||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||
stream_ordering=None, is_new_state=True,
|
||||
current_state=None):
|
||||
|
||||
# Remove the any existing cache entries for the event_id
|
||||
self._get_event_cache.pop(event.event_id)
|
||||
|
||||
# We purposefully do this first since if we include a `current_state`
|
||||
# key, we *want* to update the `current_state_events` table
|
||||
if current_state:
|
||||
txn.execute(
|
||||
"DELETE FROM current_state_events WHERE room_id = ?",
|
||||
(event.room_id,)
|
||||
)
|
||||
|
||||
for s in current_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": s.event_id,
|
||||
"room_id": s.room_id,
|
||||
"type": s.type,
|
||||
"state_key": s.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
if event.is_state() and is_new_state:
|
||||
if not backfilled and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
for prev_state_id, _ in event.prev_state:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
keyvalues={
|
||||
"event_id": prev_state_id,
|
||||
}
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
self._update_min_depth_for_room_txn(
|
||||
txn,
|
||||
event.room_id,
|
||||
event.depth
|
||||
)
|
||||
|
||||
self._handle_prev_events(
|
||||
txn,
|
||||
outlier=outlier,
|
||||
event_id=event.event_id,
|
||||
prev_events=event.prev_events,
|
||||
room_id=event.room_id,
|
||||
)
|
||||
|
||||
have_persisted = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
keyvalues={"event_id": event.event_id},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
metadata_json = encode_canonical_json(
|
||||
event.internal_metadata.get_dict()
|
||||
)
|
||||
|
||||
# If we have already persisted this event, we don't need to do any
|
||||
# more processing.
|
||||
# The processing above must be done on every call to persist event,
|
||||
# since they might not have happened on previous calls. For example,
|
||||
# if we are persisting an event that we had persisted as an outlier,
|
||||
# but is no longer one.
|
||||
if have_persisted:
|
||||
if not outlier:
|
||||
sql = (
|
||||
"UPDATE event_json SET internal_metadata = ?"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(metadata_json.decode("UTF-8"), event.event_id,)
|
||||
)
|
||||
|
||||
sql = (
|
||||
"UPDATE events SET outlier = 0"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(event.event_id,)
|
||||
)
|
||||
return
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
self._store_room_member_txn(txn, event)
|
||||
elif event.type == EventTypes.Feedback:
|
||||
self._store_feedback_txn(txn, event)
|
||||
elif event.type == EventTypes.Name:
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
|
||||
event_dict = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
]
|
||||
}
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json.decode("UTF-8"),
|
||||
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
content = encode_canonical_json(
|
||||
event.content
|
||||
).decode("UTF-8")
|
||||
|
||||
vals = {
|
||||
"topological_ordering": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"type": event.type,
|
||||
"room_id": event.room_id,
|
||||
"content": content,
|
||||
"processed": True,
|
||||
"outlier": outlier,
|
||||
"depth": event.depth,
|
||||
}
|
||||
|
||||
if stream_ordering is not None:
|
||||
vals["stream_ordering"] = stream_ordering
|
||||
|
||||
unrec = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in vals.keys() and k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
"signatures",
|
||||
"hashes",
|
||||
"prev_events",
|
||||
]
|
||||
}
|
||||
|
||||
vals["unrecognized_keys"] = encode_canonical_json(
|
||||
unrec
|
||||
).decode("UTF-8")
|
||||
|
||||
try:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"events",
|
||||
vals,
|
||||
or_replace=(not outlier),
|
||||
or_ignore=bool(outlier),
|
||||
)
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to persist, probably duplicate: %s",
|
||||
event.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise _RollbackButIsFineException("_persist_event")
|
||||
|
||||
if context.rejected:
|
||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
||||
|
||||
if event.is_state():
|
||||
vals = {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
}
|
||||
|
||||
# TODO: How does this work with backfilling?
|
||||
if hasattr(event, "replaces_state"):
|
||||
vals["prev_state"] = event.replaces_state
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"state_events",
|
||||
vals,
|
||||
)
|
||||
|
||||
if is_new_state and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
)
|
||||
|
||||
for e_id, h in event.prev_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": event.room_id,
|
||||
"is_state": 1,
|
||||
},
|
||||
)
|
||||
|
||||
for hash_alg, hash_base64 in event.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_event_content_hash_txn(
|
||||
txn, event.event_id, hash_alg, hash_bytes,
|
||||
)
|
||||
|
||||
for prev_event_id, prev_hashes in event.prev_events:
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_event_hash_txn(
|
||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
||||
)
|
||||
|
||||
for auth_id, _ in event.auth_events:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"auth_id": auth_id,
|
||||
},
|
||||
)
|
||||
|
||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
||||
self._store_event_reference_hash_txn(
|
||||
txn, event.event_id, ref_alg, ref_hash_bytes
|
||||
)
|
||||
|
||||
def _store_redaction(self, txn, event):
|
||||
# invalidate the cache for the redacted event
|
||||
self._get_event_cache.pop(event.redacts)
|
||||
txn.execute(
|
||||
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
|
||||
(event.event_id, event.redacts)
|
||||
)
|
||||
|
||||
def have_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Returns:
|
||||
dict: Has an entry for each event id we already have seen. Maps to
|
||||
the rejected reason string if we rejected the event, else maps to
|
||||
None.
|
||||
"""
|
||||
if not event_ids:
|
||||
return defer.succeed({})
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.event_id, reason FROM events as e "
|
||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||
"WHERE e.event_id = ?"
|
||||
)
|
||||
|
||||
res = {}
|
||||
for event_id in event_ids:
|
||||
txn.execute(sql, (event_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
_, rejected = row
|
||||
res[event_id] = rejected
|
||||
|
||||
return res
|
||||
|
||||
return self.runInteraction(
|
||||
"have_events", f,
|
||||
)
|
|
@ -1,47 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class FeedbackStore(SQLBaseStore):
|
||||
|
||||
def _store_feedback_txn(self, txn, event):
|
||||
self._simple_insert_txn(txn, "feedback", {
|
||||
"event_id": event.event_id,
|
||||
"feedback_type": event.content["type"],
|
||||
"room_id": event.room_id,
|
||||
"target_event_id": event.content["target_event_id"],
|
||||
"sender": event.user_id,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_feedback_for_event(self, event_id):
|
||||
sql = (
|
||||
"SELECT events.* FROM events INNER JOIN feedback "
|
||||
"ON events.event_id = feedback.event_id "
|
||||
"WHERE feedback.target_event_id = ? "
|
||||
)
|
||||
|
||||
rows = yield self._execute_and_decode("get_feedback_for_event", sql, event_id)
|
||||
|
||||
defer.returnValue(
|
||||
[
|
||||
(yield self._parse_events(r))
|
||||
for r in rows
|
||||
]
|
||||
)
|
|
@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
|
|||
},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
|
||||
defer.returnValue(json.loads(def_json))
|
||||
|
|
|
@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
{"media_id": media_id},
|
||||
("media_type", "media_length", "upload_name", "created_ts"),
|
||||
allow_none=True,
|
||||
desc="get_local_media",
|
||||
)
|
||||
|
||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||
|
@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
"upload_name": upload_name,
|
||||
"media_length": media_length,
|
||||
"user_id": user_id.to_string(),
|
||||
}
|
||||
},
|
||||
desc="store_local_media",
|
||||
)
|
||||
|
||||
def get_local_media_thumbnails(self, media_id):
|
||||
|
@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length",
|
||||
)
|
||||
),
|
||||
desc="get_local_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_local_thumbnail(self, media_id, thumbnail_width,
|
||||
|
@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
"thumbnail_method": thumbnail_method,
|
||||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
}
|
||||
},
|
||||
desc="store_local_thumbnail",
|
||||
)
|
||||
|
||||
def get_cached_remote_media(self, origin, media_id):
|
||||
|
@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
"filesystem_id",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_cached_remote_media",
|
||||
)
|
||||
|
||||
def store_cached_remote_media(self, origin, media_id, media_type,
|
||||
|
@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
"created_ts": time_now_ms,
|
||||
"upload_name": upload_name,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
},
|
||||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
||||
def get_remote_media_thumbnails(self, origin, media_id):
|
||||
|
@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
||||
)
|
||||
),
|
||||
desc="get_remote_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
||||
|
@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
|
|||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
},
|
||||
desc="store_remote_media_thumbnail",
|
||||
)
|
||||
|
|
|
@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
|
|||
return self._simple_insert(
|
||||
table="presence",
|
||||
values={"user_id": user_localpart},
|
||||
desc="create_presence",
|
||||
)
|
||||
|
||||
def has_presence_state(self, user_localpart):
|
||||
|
@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
|
|||
keyvalues={"user_id": user_localpart},
|
||||
retcols=["user_id"],
|
||||
allow_none=True,
|
||||
desc="has_presence_state",
|
||||
)
|
||||
|
||||
def get_presence_state(self, user_localpart):
|
||||
|
@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
|
|||
table="presence",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcols=["state", "status_msg", "mtime"],
|
||||
desc="get_presence_state",
|
||||
)
|
||||
|
||||
def set_presence_state(self, user_localpart, new_state):
|
||||
|
@ -45,7 +48,7 @@ class PresenceStore(SQLBaseStore):
|
|||
updatevalues={"state": new_state["state"],
|
||||
"status_msg": new_state["status_msg"],
|
||||
"mtime": self._clock.time_msec()},
|
||||
retcols=["state"],
|
||||
desc="set_presence_state",
|
||||
)
|
||||
|
||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||
|
@ -53,6 +56,7 @@ class PresenceStore(SQLBaseStore):
|
|||
table="presence_allow_inbound",
|
||||
values={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
desc="allow_presence_visible",
|
||||
)
|
||||
|
||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||
|
@ -60,6 +64,7 @@ class PresenceStore(SQLBaseStore):
|
|||
table="presence_allow_inbound",
|
||||
keyvalues={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
desc="disallow_presence_visible",
|
||||
)
|
||||
|
||||
def is_presence_visible(self, observed_localpart, observer_userid):
|
||||
|
@ -69,6 +74,7 @@ class PresenceStore(SQLBaseStore):
|
|||
"observer_user_id": observer_userid},
|
||||
retcols=["observed_user_id"],
|
||||
allow_none=True,
|
||||
desc="is_presence_visible",
|
||||
)
|
||||
|
||||
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
||||
|
@ -77,6 +83,7 @@ class PresenceStore(SQLBaseStore):
|
|||
values={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid,
|
||||
"accepted": False},
|
||||
desc="add_presence_list_pending",
|
||||
)
|
||||
|
||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||
|
@ -85,6 +92,7 @@ class PresenceStore(SQLBaseStore):
|
|||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
updatevalues={"accepted": True},
|
||||
desc="set_presence_list_accepted",
|
||||
)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
|
@ -96,6 +104,7 @@ class PresenceStore(SQLBaseStore):
|
|||
table="presence_list",
|
||||
keyvalues=keyvalues,
|
||||
retcols=["observed_user_id", "accepted"],
|
||||
desc="get_presence_list",
|
||||
)
|
||||
|
||||
def del_presence_list(self, observer_localpart, observed_userid):
|
||||
|
@ -103,4 +112,5 @@ class PresenceStore(SQLBaseStore):
|
|||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
desc="del_presence_list",
|
||||
)
|
||||
|
|
|
@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
|
|||
return self._simple_insert(
|
||||
table="profiles",
|
||||
values={"user_id": user_localpart},
|
||||
desc="create_profile",
|
||||
)
|
||||
|
||||
def get_profile_displayname(self, user_localpart):
|
||||
|
@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
|
|||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
|
||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||
|
@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
|
|||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"displayname": new_displayname},
|
||||
desc="set_profile_displayname",
|
||||
)
|
||||
|
||||
def get_profile_avatar_url(self, user_localpart):
|
||||
|
@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
|
|||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
||||
|
@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
|
|||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"avatar_url": new_avatar_url},
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
|
|
@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
|
|||
results = yield self._simple_select_list(
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name},
|
||||
PushRuleEnableTable.fields
|
||||
PushRuleEnableTable.fields,
|
||||
desc="get_push_rules_enabled_for_user",
|
||||
)
|
||||
defer.returnValue(
|
||||
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
||||
|
@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
|
|||
"""
|
||||
yield self._simple_delete_one(
|
||||
PushRuleTable.table_name,
|
||||
{'user_name': user_name, 'rule_id': rule_id}
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
desc="delete_push_rule",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
|
|||
yield self._simple_upsert(
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
{'enabled': enabled}
|
||||
{'enabled': enabled},
|
||||
desc="set_push_rule_enabled",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -114,7 +114,9 @@ class PusherStore(SQLBaseStore):
|
|||
ts=pushkey_ts,
|
||||
lang=lang,
|
||||
data=data
|
||||
))
|
||||
),
|
||||
desc="add_pusher",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("create_pusher with failed: %s", e)
|
||||
raise StoreError(500, "Problem creating pusher.")
|
||||
|
@ -123,7 +125,8 @@ class PusherStore(SQLBaseStore):
|
|||
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
|
||||
yield self._simple_delete_one(
|
||||
PushersTable.table_name,
|
||||
dict(app_id=app_id, pushkey=pushkey)
|
||||
{"app_id": app_id, "pushkey": pushkey},
|
||||
desc="delete_pusher_by_app_id_pushkey",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -131,7 +134,8 @@ class PusherStore(SQLBaseStore):
|
|||
yield self._simple_update_one(
|
||||
PushersTable.table_name,
|
||||
{'app_id': app_id, 'pushkey': pushkey},
|
||||
{'last_token': last_token}
|
||||
{'last_token': last_token},
|
||||
desc="update_pusher_last_token",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -140,7 +144,8 @@ class PusherStore(SQLBaseStore):
|
|||
yield self._simple_update_one(
|
||||
PushersTable.table_name,
|
||||
{'app_id': app_id, 'pushkey': pushkey},
|
||||
{'last_token': last_token, 'last_success': last_success}
|
||||
{'last_token': last_token, 'last_success': last_success},
|
||||
desc="update_pusher_last_token_and_success",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -148,7 +153,8 @@ class PusherStore(SQLBaseStore):
|
|||
yield self._simple_update_one(
|
||||
PushersTable.table_name,
|
||||
{'app_id': app_id, 'pushkey': pushkey},
|
||||
{'failing_since': failing_since}
|
||||
{'failing_since': failing_since},
|
||||
desc="update_pusher_failing_since",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
|
|||
|
||||
from synapse.api.errors import StoreError, Codes
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
|
||||
class RegistrationStore(SQLBaseStore):
|
||||
|
@ -39,7 +39,10 @@ class RegistrationStore(SQLBaseStore):
|
|||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
"""
|
||||
row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
|
||||
row = yield self._simple_select_one(
|
||||
"users", {"name": user_id}, ["id"],
|
||||
desc="add_access_token_to_user",
|
||||
)
|
||||
if not row:
|
||||
raise StoreError(400, "Bad user ID supplied.")
|
||||
row_id = row["id"]
|
||||
|
@ -48,7 +51,8 @@ class RegistrationStore(SQLBaseStore):
|
|||
{
|
||||
"user_id": row_id,
|
||||
"token": token
|
||||
}
|
||||
},
|
||||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -91,6 +95,11 @@ class RegistrationStore(SQLBaseStore):
|
|||
"get_user_by_id", self.cursor_to_dict, query, user_id
|
||||
)
|
||||
|
||||
@cached()
|
||||
# TODO(paul): Currently there's no code to invalidate this cache. That
|
||||
# means if/when we ever add internal ways to invalidate access tokens or
|
||||
# change whether a user is a server admin, those will need to invoke
|
||||
# store.get_user_by_token.invalidate(token)
|
||||
def get_user_by_token(self, token):
|
||||
"""Get a user from the given access token.
|
||||
|
||||
|
@ -115,6 +124,7 @@ class RegistrationStore(SQLBaseStore):
|
|||
keyvalues={"name": user.to_string()},
|
||||
retcol="admin",
|
||||
allow_none=True,
|
||||
desc="is_server_admin",
|
||||
)
|
||||
|
||||
defer.returnValue(res if res else False)
|
||||
|
|
|
@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
|
|||
"event_id": event_id,
|
||||
"reason": reason,
|
||||
"last_check": self._clock.time_msec(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def get_rejection_reason(self, event_id):
|
||||
|
@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
|
|||
"event_id": event_id,
|
||||
},
|
||||
allow_none=True,
|
||||
desc="get_rejection_reason",
|
||||
)
|
||||
|
|
|
@ -15,11 +15,9 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from sqlite3 import IntegrityError
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore, Table
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
@ -27,8 +25,9 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OpsLevel = collections.namedtuple("OpsLevel", (
|
||||
"ban_level", "kick_level", "redact_level")
|
||||
OpsLevel = collections.namedtuple(
|
||||
"OpsLevel",
|
||||
("ban_level", "kick_level", "redact_level",)
|
||||
)
|
||||
|
||||
|
||||
|
@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
|
|||
StoreError if the room could not be stored.
|
||||
"""
|
||||
try:
|
||||
yield self._simple_insert(RoomsTable.table_name, dict(
|
||||
room_id=room_id,
|
||||
creator=room_creator_user_id,
|
||||
is_public=is_public
|
||||
))
|
||||
except IntegrityError:
|
||||
raise StoreError(409, "Room ID in use.")
|
||||
yield self._simple_insert(
|
||||
RoomsTable.table_name,
|
||||
{
|
||||
"room_id": room_id,
|
||||
"creator": room_creator_user_id,
|
||||
"is_public": is_public,
|
||||
},
|
||||
desc="store_room",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
|
@ -66,9 +67,11 @@ class RoomStore(SQLBaseStore):
|
|||
Returns:
|
||||
A namedtuple containing the room information, or an empty list.
|
||||
"""
|
||||
query = RoomsTable.select_statement("room_id=?")
|
||||
return self._execute(
|
||||
"get_room", RoomsTable.decode_single_result, query, room_id,
|
||||
return self._simple_select_one(
|
||||
table=RoomsTable.table_name,
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=RoomsTable.fields,
|
||||
desc="get_room",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -143,7 +146,7 @@ class RoomStore(SQLBaseStore):
|
|||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"topic": event.content["topic"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _store_room_name_txn(self, txn, event):
|
||||
|
@ -158,8 +161,45 @@ class RoomStore(SQLBaseStore):
|
|||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_name_and_aliases(self, room_id):
|
||||
del_sql = (
|
||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
class RoomsTable(Table):
|
||||
sql = (
|
||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||
"WHERE c.room_id = ? "
|
||||
) % {
|
||||
"redacted": del_sql,
|
||||
}
|
||||
|
||||
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
|
||||
sql += " OR s.type = 'm.room.aliases')"
|
||||
args = (room_id,)
|
||||
|
||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||
|
||||
events = yield self._parse_events(results)
|
||||
|
||||
name = None
|
||||
aliases = []
|
||||
|
||||
for e in events:
|
||||
if e.type == 'm.room.name':
|
||||
if 'name' in e.content:
|
||||
name = e.content['name']
|
||||
elif e.type == 'm.room.aliases':
|
||||
if 'aliases' in e.content:
|
||||
aliases.extend(e.content['aliases'])
|
||||
|
||||
defer.returnValue((name, aliases))
|
||||
|
||||
|
||||
class RoomsTable(object):
|
||||
table_name = "rooms"
|
||||
|
||||
fields = [
|
||||
|
@ -167,5 +207,3 @@ class RoomsTable(Table):
|
|||
"is_public",
|
||||
"creator"
|
||||
]
|
||||
|
||||
EntryType = collections.namedtuple("RoomEntry", fields)
|
||||
|
|
|
@ -212,7 +212,8 @@ class RoomMemberStore(SQLBaseStore):
|
|||
return self._simple_select_onecol(
|
||||
"room_hosts",
|
||||
{"room_id": room_id},
|
||||
"host"
|
||||
"host",
|
||||
desc="get_joined_hosts_for_room",
|
||||
)
|
||||
|
||||
def _get_members_by_dict(self, where_dict):
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -82,7 +84,7 @@ class StateStore(SQLBaseStore):
|
|||
if context.current_state is None:
|
||||
return
|
||||
|
||||
state_events = context.current_state
|
||||
state_events = dict(context.current_state)
|
||||
|
||||
if event.is_state():
|
||||
state_events[(event.type, event.state_key)] = event
|
||||
|
@ -122,3 +124,33 @@ class StateStore(SQLBaseStore):
|
|||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||
del_sql = (
|
||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
sql = (
|
||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||
"WHERE c.room_id = ? "
|
||||
) % {
|
||||
"redacted": del_sql,
|
||||
}
|
||||
|
||||
if event_type and state_key is not None:
|
||||
sql += " AND s.type = ? AND s.state_key = ? "
|
||||
args = (room_id, event_type, state_key)
|
||||
elif event_type:
|
||||
sql += " AND s.type = ?"
|
||||
args = (room_id, event_type)
|
||||
else:
|
||||
args = (room_id, )
|
||||
|
||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||
|
||||
events = yield self._parse_events(results)
|
||||
defer.returnValue(events)
|
||||
|
|
|
@ -35,7 +35,7 @@ what sort order was used:
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, cached
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.util.logutils import log_function
|
||||
|
@ -413,12 +413,32 @@ class StreamStore(SQLBaseStore):
|
|||
"get_recent_events_for_room", get_recent_events_for_room_txn
|
||||
)
|
||||
|
||||
@cached(num_args=0)
|
||||
def get_room_events_max_id(self):
|
||||
return self.runInteraction(
|
||||
"get_room_events_max_id",
|
||||
self._get_room_events_max_id_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_min_token(self):
|
||||
row = yield self._execute(
|
||||
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
||||
)
|
||||
|
||||
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
||||
self.min_token = min(self.min_token, -1)
|
||||
|
||||
logger.debug("min_token is: %s", self.min_token)
|
||||
|
||||
defer.returnValue(self.min_token)
|
||||
|
||||
def get_next_stream_id(self):
|
||||
with self._next_stream_id_lock:
|
||||
i = self._next_stream_id
|
||||
self._next_stream_id += 1
|
||||
return i
|
||||
|
||||
def _get_room_events_max_id_txn(self, txn):
|
||||
txn.execute(
|
||||
"SELECT MAX(stream_ordering) as m FROM events"
|
||||
|
|
|
@ -46,15 +46,19 @@ class TransactionStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
def _get_received_txn_response(self, txn, transaction_id, origin):
|
||||
where_clause = "transaction_id = ? AND origin = ?"
|
||||
query = ReceivedTransactionsTable.select_statement(where_clause)
|
||||
result = self._simple_select_one_txn(
|
||||
txn,
|
||||
table=ReceivedTransactionsTable.table_name,
|
||||
keyvalues={
|
||||
"transaction_id": transaction_id,
|
||||
"origin": origin,
|
||||
},
|
||||
retcols=ReceivedTransactionsTable.fields,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
txn.execute(query, (transaction_id, origin))
|
||||
|
||||
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
|
||||
|
||||
if results and results[0].response_code:
|
||||
return (results[0].response_code, results[0].response_json)
|
||||
if result and result.response_code:
|
||||
return result["response_code"], result["response_json"]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
|
@ -90,12 +90,16 @@ class LruCache(object):
|
|||
def cache_len():
|
||||
return len(cache)
|
||||
|
||||
def cache_contains(key):
|
||||
return key in cache
|
||||
|
||||
self.sentinel = object()
|
||||
self.get = cache_get
|
||||
self.set = cache_set
|
||||
self.setdefault = cache_set_default
|
||||
self.pop = cache_pop
|
||||
self.len = cache_len
|
||||
self.contains = cache_contains
|
||||
|
||||
def __getitem__(self, key):
|
||||
result = self.get(key, self.sentinel)
|
||||
|
@ -114,3 +118,6 @@ class LruCache(object):
|
|||
|
||||
def __len__(self):
|
||||
return self.len()
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.contains(key)
|
||||
|
|
|
@ -16,6 +16,10 @@
|
|||
import random
|
||||
import string
|
||||
|
||||
_string_with_symbols = (
|
||||
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
||||
)
|
||||
|
||||
|
||||
def origin_from_ucid(ucid):
|
||||
return ucid.split("@", 1)[1]
|
||||
|
@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
|
|||
|
||||
def random_string(length):
|
||||
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|
||||
|
||||
|
||||
def random_string_with_symbols(length):
|
||||
return ''.join(
|
||||
random.choice(_string_with_symbols) for _ in xrange(length)
|
||||
)
|
||||
|
|
|
@ -17,7 +17,79 @@
|
|||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import cached
|
||||
from synapse.storage._base import Cache, cached
|
||||
|
||||
|
||||
class CacheTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cache = Cache("test")
|
||||
|
||||
def test_empty(self):
|
||||
failed = False
|
||||
try:
|
||||
self.cache.get("foo")
|
||||
except KeyError:
|
||||
failed = True
|
||||
|
||||
self.assertTrue(failed)
|
||||
|
||||
def test_hit(self):
|
||||
self.cache.prefill("foo", 123)
|
||||
|
||||
self.assertEquals(self.cache.get("foo"), 123)
|
||||
|
||||
def test_invalidate(self):
|
||||
self.cache.prefill("foo", 123)
|
||||
self.cache.invalidate("foo")
|
||||
|
||||
failed = False
|
||||
try:
|
||||
self.cache.get("foo")
|
||||
except KeyError:
|
||||
failed = True
|
||||
|
||||
self.assertTrue(failed)
|
||||
|
||||
def test_eviction(self):
|
||||
cache = Cache("test", max_entries=2)
|
||||
|
||||
cache.prefill(1, "one")
|
||||
cache.prefill(2, "two")
|
||||
cache.prefill(3, "three") # 1 will be evicted
|
||||
|
||||
failed = False
|
||||
try:
|
||||
cache.get(1)
|
||||
except KeyError:
|
||||
failed = True
|
||||
|
||||
self.assertTrue(failed)
|
||||
|
||||
cache.get(2)
|
||||
cache.get(3)
|
||||
|
||||
def test_eviction_lru(self):
|
||||
cache = Cache("test", max_entries=2, lru=True)
|
||||
|
||||
cache.prefill(1, "one")
|
||||
cache.prefill(2, "two")
|
||||
|
||||
# Now access 1 again, thus causing 2 to be least-recently used
|
||||
cache.get(1)
|
||||
|
||||
cache.prefill(3, "three")
|
||||
|
||||
failed = False
|
||||
try:
|
||||
cache.get(2)
|
||||
except KeyError:
|
||||
failed = True
|
||||
|
||||
self.assertTrue(failed)
|
||||
|
||||
cache.get(1)
|
||||
cache.get(3)
|
||||
|
||||
|
||||
class CacheDecoratorTestCase(unittest.TestCase):
|
||||
|
|
|
@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
self.mock_txn.rowcount = 1
|
||||
self.mock_txn.fetchone.return_value = ("Old Value",)
|
||||
|
||||
ret = yield self.datastore._simple_update_one(
|
||||
ret = yield self.datastore._simple_selectupdate_one(
|
||||
table="tablename",
|
||||
keyvalues={"keycol": "TheKey"},
|
||||
updatevalues={"columname": "New Value"},
|
||||
|
|
|
@ -44,7 +44,7 @@ class RoomStoreTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_room(self):
|
||||
self.assertObjectHasAttributes(
|
||||
self.assertDictContainsSubset(
|
||||
{"room_id": self.room.to_string(),
|
||||
"creator": self.u_creator.to_string(),
|
||||
"is_public": True},
|
||||
|
|
Loading…
Reference in a new issue